class ASR: """ This class is a high-level wrapper for K2 acoustic models that simplifies inference: reading models, computing posteriors, decoding, alignments, etc. Currently it will only work with the Conformer model with a very specific HMM topology. It could be the basis for a more generic entry point to Snow(Ice?)fall. """ def __init__( self, lang_dir: Pathlike, scripted_model_path: Optional[Pathlike] = None, model_dir: Optional[Pathlike] = None, average_epochs: Sequence[int] = (7, 8, 9), device: torch.device = 'cpu', sampling_rate: int = 16000, ): if isinstance(device, str): self.device = torch.device(device) self.sampling_rate = sampling_rate self.extractor = Fbank(FbankConfig(num_mel_bins=80)) self.lexicon = Lexicon(lang_dir) phone_ids = self.lexicon.phone_symbols() self.P = create_bigram_phone_lm(phone_ids) if model_dir is not None: # Read model from regular checkpoints, assume it's a Conformer self.model = Conformer(num_features=80, num_classes=len(phone_ids) + 1, num_decoder_layers=0) self.P.scores = torch.zeros_like(self.P.scores) self.model.P_scores = torch.nn.Parameter(self.P.scores.clone(), requires_grad=False) average_checkpoint(filenames=[ model_dir / f'epoch-{n}.pt' for n in average_epochs ], model=self.model) elif scripted_model_path is not None: # Read model from a serialized TorchScript module, no assumptions needed self.model = torch.jit.load(scripted_model_path) else: raise ValueError( "One of scripted_model_path or model_dir needs to be provided." ) # Freeze the params by default. for p in self.model.parameters(): p.requires_grad_(False) self.compiler = MmiTrainingGraphCompiler(lexicon=self.lexicon, device=self.device) self.HLG = k2.Fsa.from_dict(torch.load(lang_dir / 'HLG.pt')).to( self.device) def compute_features(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor: if isinstance(cuts, (Cut, MixedCut)): cuts = CutSet.from_cuts([cuts]) assert cuts[ 0].sampling_rate == self.sampling_rate, f'{cuts[0].sampling_rate} != {self.sampling_rate}' otf = OnTheFlyFeatures(self.extractor) # feats: (batch, seq_len, n_feats) feats, _ = otf(cuts) return feats def compute_posteriors(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor: """ Run the forward pass of the acoustic model and return a tensor representing a batch of phone posteriorgrams. """ # Extract feats # (batch, seq_len, num_feats) if isinstance(cuts, (Cut, MixedCut)): cuts = CutSet.from_cuts([cuts]) assert cuts[ 0].sampling_rate == self.sampling_rate, f'{cuts[0].sampling_rate} != {self.sampling_rate}' otf = OnTheFlyFeatures(self.extractor) # feats: (batch, seq_len, n_feats) feats, _ = otf(cuts) # feats: (batch, n_feats, seq_len) feats = feats.permute(0, 2, 1) # Compute AM posteriors # posteriors: (batch, n_phones, ~seq_len / 4) posteriors, _, _ = self.model(feats) # returns: (batch, ~seq_len / 4, n_phones) return posteriors.permute(0, 2, 1) def decode( self, cuts: Union[AnyCut, CutSet]) -> List[Tuple[List[str], List[str]]]: """ Perform decoding with an n-gram language model (HLG graph). Doesn't support rescoring at this time. """ if isinstance(cuts, (Cut, MixedCut)): cuts = CutSet.from_cuts([cuts]) word_results = [] # Hacky way to get batch quickly... we may need to improve on this. 
batch = K2SpeechRecognitionDataset(cuts, input_strategy=OnTheFlyFeatures( self.extractor), check_inputs=False)[list(cuts.ids)] features = batch['inputs'].permute(0, 2, 1).to( self.device) # (B, T, F) -> (B, F, T) supervision_segments, texts = encode_supervisions( batch['supervisions']) # Forward pass through the acoustic model posteriors, _, _ = self.model(features) posteriors = posteriors.permute(0, 2, 1) # (B, F, T) -> (B, T, F) # Wrapping into k2 "dense FSA" (representing PPG as a dense graph) dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments) # The actual decoding starts here: # First, we intersect the HLG and the PPG # with default pruning/beam search params from snowfall # The result is a batch of graphs (lattices) lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8, 30, 10000) # ... then we find the shortest paths in the lattices ... best_paths = k2.shortest_path(lattices, use_double_scores=True) # ... and convert them to words with a convenience wrapper from snowfall hyps = get_texts(best_paths, torch.arange(len(texts))) # Here we read out the words from the best path graphs for i in range(len(texts)): hyp_words = [self.lexicon.words.get(x) for x in hyps[i]] ref_words = texts[i].split(' ') word_results.append((ref_words, hyp_words)) return word_results def align(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor: """ Perform forced alignment and return a tensor that represents a batch of frame-level alignments: >>> alignments = torch.tensor([ ... [0, 0, 0, 1, 57, 57, 35, 35, 35, ...], ... [...], ... ... ... ]) :return: an int32 tensor with shape ``(batch_size, num_frames)``. """ # Extract feats # (batch, seq_len, num_feats) if isinstance(cuts, (Cut, MixedCut)): cuts = CutSet.from_cuts([cuts]) assert cuts[ 0].sampling_rate == self.sampling_rate, f'{cuts[0].sampling_rate} != {self.sampling_rate}' cuts = cuts.map_supervisions(self.normalize_text) otf = OnTheFlyFeatures(self.extractor) feats, _ = otf(cuts) feats = feats.permute(0, 2, 1) texts = [' '.join(s.text for s in cut.supervisions) for cut in cuts] # Compute AM posteriors # (batch, seq_len ~/ 4, num_phones) posteriors, _, _ = self.model(feats) # Note: we are using "dummy" supervisions so that the aligner also considers # the padding area. We can adjust that behaviour if needed by passing actual # supervision segments, but then we will have a ragged tensor (will need to # pad the alignments themselves). sups = self.dummy_supervisions(feats) posteriors_fsa = k2.DenseFsaVec(posteriors.permute(0, 2, 1), sups) # Intersection with ground truth transcript graphs num, den = self.compiler.compile(texts, self.P) alignment = k2.intersect_dense(num, posteriors_fsa, output_beam=10.0) best_path = k2.shortest_path(alignment, use_double_scores=True) # Retrieve sequences of phone IDs per frame # (batch, seq_len ~/ 4) -- dtype int32 (num phone labels) frame_labels = torch.stack( [best_path[i].labels[:-1] for i in range(best_path.shape[0])]) return frame_labels def align_ctm(self, cuts: Union[CutSet, AnyCut]) -> List[List[AlignmentItem]]: """ Perform forced alignment and parse the phones into a CTM-like format: >>> [[0.0, 0.12, 'SIL'], [0.12, 0.2, 'AH0'], ...] """ # TODO: I am not sure that this method is extracting the alignment 100% correctly: # need to revise... # TODO: when K2/Snowfall has a standard way of indicating what is silence, # or we update the model, update the constants below. 
EPS = 0 SIL = 1 non_speech = {EPS, SIL} def to_s(n: int) -> float: FRAME_SHIFT = 0.04 # 0.01 * 4 subsampling return round(n * FRAME_SHIFT, ndigits=3) if isinstance(cuts, (Cut, MixedCut)): cuts = CutSet.from_cuts([cuts]) # Uppercase and remove punctuation cuts = cuts.map_supervisions(self.normalize_text) alignments = self.align(cuts).tolist() ctm_alis = [] for cut, alignment in zip(cuts, alignments): # First we determine the silence regions at the beginning and the end: # we assume that every SIL and <eps> before the first phone, and after the last phone, # are representing silence. first_speech_idx = [ idx for idx, s in enumerate(alignment) if s not in non_speech ][0] last_speech_idx = [ idx for idx, s in reversed(list(enumerate(alignment))) if s not in non_speech ][0] speech_ali = alignment[first_speech_idx:last_speech_idx] ctm_ali = [ AlignmentItem(start=0.0, duration=to_s(first_speech_idx), symbol=self.lexicon.phones[SIL]) ] # Then, we iterate over the speech region: since the K2 model uses 2-state HMM # topology that allows blank (<eps>) to follow a phone symbol, we treat <eps> # as continuation of the "previous" phone. # TODO: I think this implementation is wrong in that it merges repeating phones... # Will fix. # TODO: I think it could be simplified by using some smart semi-ring and FSA operations... start = first_speech_idx prev_s = speech_ali[0] curr_s = speech_ali[0] cntr = 1 for s in speech_ali[1:]: curr_s = s if s != EPS else curr_s if curr_s != prev_s: ctm_ali.append( AlignmentItem(start=to_s(start), duration=to_s(cntr), symbol=self.lexicon.phones[prev_s])) start = start + cntr prev_s = curr_s cntr = 1 else: cntr += 1 if cntr: ctm_ali.append( AlignmentItem(start=to_s(start), duration=to_s(cntr), symbol=self.lexicon.phones[prev_s])) speech_end_timestamp = to_s(last_speech_idx) if speech_end_timestamp > cut.duration: logging.warning( f"speech_end_timestamp <= cut.duration. Skipping cut {cut.id}" ) ctm_alis.append(None) continue ctm_ali.append( AlignmentItem(start=speech_end_timestamp, duration=round(cut.duration - speech_end_timestamp, ndigits=8), symbol=self.lexicon.phones[SIL])) ctm_alis.append(ctm_ali) return ctm_alis def plot_alignments(self, cut: AnyCut): import matplotlib.pyplot as plt feats = self.compute_features(cut) phone_ids = self.align(cut) fig, axes = plt.subplots(2, squeeze=True, sharey=True, figsize=(10, 14)) axes[0].imshow(np.flipud(feats[0].T)) axes[1].imshow( torch.nn.functional.one_hot( phone_ids.repeat_interleave(4).to(torch.int64)).T) return fig, axes def plot_posteriors(self, cut: AnyCut): import matplotlib.pyplot as plt feats = self.compute_features(cut) posteriors = self.compute_posteriors(cut) fig, axes = plt.subplots(2, squeeze=True, sharey=True, figsize=(10, 14)) axes[0].imshow(np.flipud(feats[0].T)) axes[1].imshow(posteriors[0].exp().repeat_interleave(4, 1)) return fig, axes @staticmethod def dummy_supervisions(feats): def size_after_conv(size, num_layers=2): for i in range(num_layers): size = (size - 1) // 2 return size return torch.tensor([[ i, size_after_conv(2, num_layers=2), size_after_conv(feats.shape[2] - 2, num_layers=2) ] for i in range(feats.size(0))], dtype=torch.int32).clamp(min=0) @staticmethod def normalize_text(supervision): text = re.sub(r'[^\w\s]', '', supervision.text.upper()) return fastcopy(supervision, text=text)
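

# A minimal usage sketch of the ASR wrapper above (not part of the original
# recipe). The model and manifest paths below are hypothetical; any Lhotse
# CutSet recorded at 16 kHz and a lang dir containing HLG.pt should work.
def _example_asr_usage():
    from pathlib import Path
    asr = ASR(lang_dir=Path('data/lang_nosp'),
              scripted_model_path='exp/conformer_mmi_scripted.pt',  # hypothetical path
              device='cuda')
    cuts = CutSet.from_json('exp/data/cuts_test-clean.json.gz')  # hypothetical path
    # decode() accepts a single Cut or a whole CutSet and returns
    # (reference words, hypothesis words) pairs.
    for ref_words, hyp_words in asr.decode(cuts):
        print('ref:', ' '.join(ref_words))
        print('hyp:', ' '.join(hyp_words))
    # CTM-like forced alignment (a list of AlignmentItem objects per cut).
    return asr.align_ctm(cuts)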
def run(rank, world_size, args):
    '''
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is passed
        automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate
    use_pruned_intersect = args.use_pruned_intersect

    fix_random_seed(42)
    if world_size > 1:
        setup_dist(rank, world_size, args.master_port)

    suffix = ''
    if args.context_window is not None and args.context_window > 0:
        suffix = f'ac{args.context_window}'
    giga_subset = f'giga{args.subset}'
    exp_dir = Path(
        f'exp-{model_type}-mmi-att-sa-vgg-normlayer-{giga_subset}-{suffix}')

    setup_logger(f'{exp_dir}/log/log-train-{rank}')
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    if not Path(lang_dir / f'P_{args.subset}.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P_{args.subset}.fst.txt')
        with open(lang_dir / f'P_{args.subset}.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels

        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / f'P_{args.subset}.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / f'P_{args.subset}.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()

    gigaspeech = GigaSpeechAsrDataModule(args)
    train_dl = gigaspeech.train_dataloaders()
    valid_dl = gigaspeech.valid_dataloaders()

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    if use_pruned_intersect:
        logging.info('Use pruned intersect for den_lats')
    else:
        logging.info("Don't use pruned intersect for den_lats")

    logging.info("About to create model")

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True,
            is_espnet_structure=True)
    elif model_type == "contextnet":
        model = ContextNet(
            num_features=80,
            num_classes=len(phone_ids) + 1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    if args.torchscript:
        logging.info('Applying TorchScript to model...')
        model = torch.jit.script(model)

    model.to(device)
    describe(model)

    if world_size > 1:
        model = DDP(model, device_ids=[rank])

    # Now for the alignment model, if any
    if args.use_ali_model:
        ali_model = TdnnLstm1b(
            num_features=80,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4)

        ali_model_fname = Path(
            f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
        assert ali_model_fname.is_file(), \
            f'ali model filename {ali_model_fname} does not exist!'
        ali_model.load_state_dict(
            torch.load(ali_model_fname, map_location='cpu')['state_dict'])
        ali_model.to(device)

        ali_model.eval()
        ali_model.requires_grad_(False)
        logging.info(f'Use ali_model: {ali_model_fname}')
    else:
        ali_model = None
        logging.info('No ali_model')

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=args.lr_factor,
                     warm_step=args.warm_step,
                     weight_decay=args.weight_decay)

    scaler = GradScaler(enabled=args.amp)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path, model=model,
                               optimizer=optimizer, scaler=scaler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate))

        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            ali_model=ali_model,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
            scaler=scaler)

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path, optimizer=None,
                            scheduler=None, scaler=None, model=model,
                            epoch=epoch, learning_rate=curr_learning_rate,
                            objf=objf, valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank,
                            torchscript=args.torchscript_epoch != -1
                            and epoch >= args.torchscript_epoch)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path, current_epoch=epoch,
                               learning_rate=curr_learning_rate, objf=objf,
                               best_objf=best_objf, valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch, local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path, optimizer=optimizer,
                        scheduler=None, scaler=scaler, model=model,
                        epoch=epoch, learning_rate=curr_learning_rate,
                        objf=objf, valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank,
                        torchscript=args.torchscript_epoch != -1
                        and epoch >= args.torchscript_epoch)
        epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename, model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate, objf=objf,
                           best_objf=best_objf, valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch, local_rank=rank)

    logging.warning('Done')
    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
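

# For context, a minimal sketch of how a `run` function like the one above is
# typically launched, as the docstring's reference to `mp.spawn()` suggests.
# This is not the recipe's actual `main()`; it assumes `get_parser()` exposes a
# `--world-size` option.
def _example_main():
    import torch.multiprocessing as mp
    args = get_parser().parse_args()
    world_size = args.world_size
    if world_size > 1:
        # mp.spawn passes the process index (rank) as the first argument to `run`.
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)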
def run(rank, world_size, args):
    '''
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is passed
        automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate

    fix_random_seed(42)
    setup_dist(rank, world_size, args.master_port)

    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
    setup_logger(f'{exp_dir}/log/log-train-{rank}')
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None
    #  tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') if args.tensorboard and rank == 0 else None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)
    P = P.to(device)

    mls = MLSAsrDataModule(args)
    train_dl = mls.train_dataloaders()
    valid_dl = mls.valid_dataloaders()

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "contextnet":
        model = ContextNet(
            num_features=80,
            num_classes=len(phone_ids) + 1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    model = DDP(model, device_ids=[rank])

    # Now for the alignment model, if any
    if args.use_ali_model:
        ali_model = TdnnLstm1b(
            num_features=80,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4)

        ali_model_fname = Path(
            f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
        assert ali_model_fname.is_file(), \
            f'ali model filename {ali_model_fname} does not exist!'
        ali_model.load_state_dict(
            torch.load(ali_model_fname, map_location='cpu')['state_dict'])
        ali_model.to(device)

        ali_model.eval()
        ali_model.requires_grad_(False)
        logging.info(f'Use ali_model: {ali_model_fname}')
    else:
        ali_model = None
        logging.info('No ali_model')

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=args.lr_factor,
                     warm_step=args.warm_step,
                     weight_decay=args.weight_decay)

    scaler = GradScaler(enabled=args.amp)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path, model=model,
                               optimizer=optimizer, scaler=scaler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate))

        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            ali_model=ali_model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
            scaler=scaler)

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path, optimizer=None,
                            scheduler=None, scaler=None, model=model,
                            epoch=epoch, learning_rate=curr_learning_rate,
                            objf=objf, valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path, current_epoch=epoch,
                               learning_rate=curr_learning_rate, objf=objf,
                               best_objf=best_objf, valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch, local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path, optimizer=optimizer,
                        scheduler=None, scaler=scaler, model=model,
                        epoch=epoch, learning_rate=curr_learning_rate,
                        objf=objf, valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank)
        epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename, model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate, objf=objf,
                           best_objf=best_objf, valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch, local_rank=rank)

    logging.warning('Done')
    torch.distributed.barrier()
    cleanup_dist()
def main():
    args = get_parser().parse_args()

    exp_dir = Path('exp-lstm-adam-mmi-bigram-musan-dist')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)
    phone_ids = lexicon.phone_symbols()
    P = create_bigram_phone_lm(phone_ids)

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')
    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False)

    checkpoint = os.path.join(exp_dir, f'epoch-{args.epoch}.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    assert P.requires_grad is False
    P.scores = model.P_scores.cpu()
    print_transition_probabilities(P, lexicon.phones, phone_ids,
                                   filename='model_P_scores.txt')

    P.set_scores_stochastic_(model.P_scores)
    print_transition_probabilities(P, lexicon.phones, phone_ids,
                                   filename='P_scores.txt')

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(lexicon.phones)
        first_word_disambig_id = find_first_disambig_symbol(lexicon.words)
        HLG = compile_HLG(L=L, G=G, H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    logging.debug("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=40000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test, batch_size=None,
                                          sampler=sampler, num_workers=1)

    #  if not torch.cuda.is_available():
    #      logging.error('No GPU detected!')
    #      sys.exit(-1)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)
    logging.debug("About to decode")
    results = decode(dataloader=test_dl, model=model, device=device,
                     HLG=HLG, symbols=lexicon.words)
    s = ''
    for ref, hyp in results:
        s += f'ref={ref}\n'
        s += f'hyp={hyp}\n'
    logging.info(s)

    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ['sub', 'ins', 'del', 'total']
    }
    total_words = sum(len(ref) for ref, _ in results)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
    )
def main():
    args = get_parser().parse_args()

    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank, world_size=args.world_size,
               master_port=args.master_port)

    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = f'exp-lstm-adam-mmi-bigram-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = args.local_rank
    device = torch.device('cuda', device_id)
    phone_ids = lexicon.phone_symbols()

    if not Path(lang_dir / 'P.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P.fst.txt')
        with open(lang_dir / 'P.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels

        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / 'P.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / 'P.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionDataset(
        cuts_train,
        cut_transforms=[
            CutConcatenate(),
            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))
        ])
    train_sampler = SingleCutSampler(
        cuts_train,
        max_frames=40000,
        shuffle=True,
    )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train, sampler=train_sampler,
                                           batch_size=None, num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=12000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate, sampler=valid_sampler,
                                           batch_size=None, num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                              momentum=momentum, weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=lr_schedule_gamma)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path, model=model,
                               optimizer=optimizer, scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)

        # LR scheduler can hold multiple learning rates for multiple parameter groups;
        # For now we report just the first LR which we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                             global_batch_idx_train)
        tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path, model=model,
                            optimizer=None, scheduler=None, epoch=epoch,
                            learning_rate=curr_learning_rate, objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path, current_epoch=epoch,
                               learning_rate=curr_learning_rate, objf=objf,
                               best_objf=best_objf, valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path, model=model, optimizer=optimizer,
                        scheduler=lr_scheduler, epoch=epoch,
                        learning_rate=curr_learning_rate, objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename, model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate, objf=objf,
                           best_objf=best_objf, valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
def main():
    args = get_parser().parse_args()

    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank, world_size=args.world_size)
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = f'exp-lstm-adam-mmi-bigram-musan-dist'
    setup_logger('{}/log/log-train'.format(exp_dir),
                 use_console=args.local_rank == 0)
    tb_writer = SummaryWriter(
        log_dir=f'{exp_dir}/tensorboard') if args.local_rank == 0 else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = args.local_rank
    device = torch.device('cuda', device_id)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)
    P = P.to(device)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if not args.bucketing_sampler:
        # We don't mix concatenating the cuts and bucketing
        # Here we insert concatenation before mixing so that the
        # noises from Musan are mixed onto almost-zero-energy
        # padding frames.
        transforms = [CutConcatenate(duration_factor=1)] + transforms
    train = K2SpeechRecognitionDataset(cuts_train, cut_transforms=transforms)
    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train, max_frames=40000,
                                         shuffle=True, num_buckets=30)
    else:
        logging.info('Using regular sampler with cut concatenation.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_frames=30000,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train, sampler=train_sampler,
                                           batch_size=None, num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    # Note: we explicitly set world_size to 1 to disable the auto-detection of
    #       distributed training inside the sampler. This way, every GPU will
    #       perform the computation on the full dev set. It is a bit wasteful,
    #       but unfortunately loss aggregation between multiple processes with
    #       torch.distributed.all_reduce() tends to hang indefinitely inside
    #       NCCL after ~3000 steps. With the current approach, we can still report
    #       the loss on the full validation set.
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=90000,
                                     world_size=1, rank=0)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate, sampler=valid_sampler,
                                           batch_size=None, num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                              momentum=momentum, weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=lr_schedule_gamma)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path, model=model,
                               optimizer=optimizer, scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    if args.world_size > 1:
        logging.info(
            'Using DistributedDataParallel in training. '
            'The reported loss, num_frames, etc. for training steps include '
            'only the batches seen in the master process (the actual loss '
            'includes batches from all GPUs, and the actual num_frames is '
            f'approx. {args.world_size}x larger.)')
        # For now do not sync BatchNorm across GPUs due to NCCL hanging in all_gather...
        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[args.local_rank],
                    output_device=args.local_rank)

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)

        # LR scheduler can hold multiple learning rates for multiple parameter groups;
        # For now we report just the first LR which we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path, model=model,
                            optimizer=None, scheduler=None, epoch=epoch,
                            learning_rate=curr_learning_rate, objf=objf,
                            local_rank=args.local_rank, valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path, current_epoch=epoch,
                               learning_rate=curr_learning_rate, objf=objf,
                               best_objf=best_objf, valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path, model=model, optimizer=optimizer,
                        scheduler=lr_scheduler, epoch=epoch,
                        learning_rate=curr_learning_rate, objf=objf,
                        local_rank=args.local_rank, valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename, model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate, objf=objf,
                           best_objf=best_objf, valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
    cleanup_dist()
def main():
    args = get_parser().parse_args()

    exp_dir = Path('exp-lstm-adam-mmi-bigram-musan-dist')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)
    phone_ids = lexicon.phone_symbols()

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')
    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)

    checkpoint = os.path.join(exp_dir, f'epoch-{args.epoch}.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(lexicon.phones)
        first_word_disambig_id = find_first_disambig_symbol(lexicon.words)
        HLG = compile_HLG(L=L, G=G, H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    logging.debug("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=40000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test, batch_size=None,
                                          sampler=sampler, num_workers=1)

    #  if not torch.cuda.is_available():
    #      logging.error('No GPU detected!')
    #      sys.exit(-1)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)
    logging.debug("About to decode")
    results = decode(dataloader=test_dl, model=model, device=device,
                     HLG=HLG, symbols=lexicon.words)

    test_set = 'test-clean'
    recog_path = exp_dir / f'recogs-{test_set}.txt'
    store_transcripts(path=recog_path, texts=results)
    logging.info(f'The transcripts are stored in {recog_path}')

    # The following prints out WERs, per-word error statistics and aligned
    # ref/hyp pairs.
    errs_filename = exp_dir / f'errs-{test_set}.txt'
    with open(errs_filename, 'w') as f:
        wer = write_error_stats(f, f'{test_set}', results)
    logging.info(f'The error stats are stored in {errs_filename}')
def run(rank, world_size, args):
    '''
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is passed
        automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate

    fix_random_seed(42)
    setup_dist(rank, world_size, args.master_port)

    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa')
    setup_logger(f'{exp_dir}/log/log-train-{rank}')
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None
    #  tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') if args.tensorboard and rank == 0 else None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)
    P = P.to(device)

    librispeech = LibriSpeechAsrDataModule(args)
    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    model = DDP(model, device_ids=[rank])

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path, model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate))

        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
        )

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path, optimizer=None,
                            scheduler=None, model=model, epoch=epoch,
                            learning_rate=curr_learning_rate, objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path, current_epoch=epoch,
                               learning_rate=curr_learning_rate, objf=objf,
                               best_objf=best_objf, valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch, local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path, optimizer=optimizer,
                        scheduler=None, model=model, epoch=epoch,
                        learning_rate=curr_learning_rate, objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank)
        epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename, model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate, objf=objf,
                           best_objf=best_objf, valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch, local_rank=rank)

    logging.warning('Done')
    torch.distributed.barrier()
    # NOTE: The training process is very likely to hang at this point.
    # If you press ctrl + c, your GPU memory will not be freed.
    # To free your GPU memory, you can run:
    #
    #   $ ps aux | grep multi
    #
    # And it will print something like below:
    #
    #   kuangfa+ 430518 98.9 0.6 57074236 3425732 pts/21 Rl Apr02 639:01 /root/fangjun/py38/bin/python3 -c from multiprocessing.spawn
    #
    # You can kill the process manually by:
    #
    #   $ kill -9 430518
    #
    # And you will see that your GPU is now not occupied anymore.
    cleanup_dist()