def test1(self):
    s = '''
        0 4 1 1
        0 1 1 1
        1 2 0 2
        1 3 0 3
        1 4 0 2
        2 7 0 4
        3 7 0 5
        4 6 1 2
        4 6 0 3
        4 8 1 3
        4 9 -1 2
        5 9 -1 4
        6 9 -1 3
        7 9 -1 5
        8 9 -1 6
        9
    '''
    fsa = k2.Fsa.from_str(s)
    prop = fsa.properties
    self.assertFalse(prop & k2.fsa_properties.EPSILON_FREE)
    dest = k2.remove_epsilon(fsa)
    prop = dest.properties
    self.assertTrue(prop & k2.fsa_properties.EPSILON_FREE)
    log_semiring = False
    self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))
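For orientation, here is a minimal self-contained sketch of the pattern these tests exercise: build an FSA containing epsilon arcs (label 0), call `k2.remove_epsilon`, and verify the `EPSILON_FREE` property bit. The FSA string below is arbitrary and not taken from any of the tests.

import k2

# Arc format: src_state dest_state label score; the last line is the final state.
s = '''
    0 1 0 0.5
    0 1 1 1.5
    1 2 -1 2.0
    2
'''
fsa = k2.Fsa.from_str(s)
assert not (fsa.properties & k2.fsa_properties.EPSILON_FREE)

dest = k2.remove_epsilon(fsa)
assert dest.properties & k2.fsa_properties.EPSILON_FREE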
def test1(self):
    if not torch.cuda.is_available():
        return
    if not k2.with_cuda:
        return
    device = torch.device('cuda', 0)
    s = '''
        0 1 0 1 1
        1 2 0 2 1
        2 3 0 3 1
        3 4 4 4 1
        3 5 -1 5 1
        4 5 -1 6 1
        5
    '''
    fsa = k2.Fsa.from_str(s, aux_label_names=['foo']).to(device)
    filler = 2
    fsa.foo_filler = filler
    print("Before removing epsilons: ", fsa)
    prop = fsa.properties
    self.assertFalse(prop & k2.fsa_properties.EPSILON_FREE)
    dest = k2.remove_epsilon(fsa)
    prop = dest.properties
    self.assertTrue(prop & k2.fsa_properties.EPSILON_FREE)
    log_semiring = False
    self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))
    print("After removing epsilons: ", dest)
    assert torch.where(dest.foo.values == filler)[0].numel() == 0
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))

    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
def test_autograd(self):
    if not torch.cuda.is_available():
        return
    if not k2.with_cuda:
        return
    device = torch.device('cuda', 0)
    s = '''
        0 1 0 0.1
        0 1 1 0.2
        1 2 -1 0.3
        2
    '''
    fsa = k2.Fsa.from_str(s).to(device).requires_grad_(True)
    ans = k2.remove_epsilon(fsa)
    print("ans = ", ans)
    # arc map is [[1] [0 2] [2]]
    scale = torch.tensor([10, 20, 30]).to(device)
    (ans.scores * scale).sum().backward()
    expected_grad = torch.empty_like(fsa.scores)
    expected_grad[0] = scale[1]
    expected_grad[1] = scale[0]
    expected_grad[2] = scale[1] + scale[2]
    print("fsa.grad = ", fsa.grad)
    print("expected_grad = ", expected_grad)
    assert torch.all(torch.eq(fsa.grad, expected_grad))
def test1(self):
    if not torch.cuda.is_available():
        return
    if not k2.with_cuda:
        return
    device = torch.device('cuda', 0)
    s = '''
        0 1 0 1 1
        1 2 0 2 1
        2 3 0 3 1
        3 4 4 4 1
        3 5 -1 5 1
        4 5 -1 6 1
        5
    '''
    fsa = k2.Fsa.from_str(s, num_aux_labels=1).to(device)
    print(fsa.aux_labels)
    prop = fsa.properties
    self.assertFalse(prop & k2.fsa_properties.EPSILON_FREE)
    dest = k2.remove_epsilon(fsa)
    prop = dest.properties
    self.assertTrue(prop & k2.fsa_properties.EPSILON_FREE)
    log_semiring = False
    self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))

    # just make sure that it runs.
    dest2 = k2.remove_epsilon_and_add_self_loops(fsa)
    dest3 = k2.remove_epsilon(dest2)
    self.assertTrue(
        k2.is_rand_equivalent(dest,
                              dest3,
                              log_semiring,
                              treat_epsilons_specially=False))
    self.assertFalse(
        k2.is_rand_equivalent(dest,
                              dest2,
                              log_semiring,
                              treat_epsilons_specially=False,
                              npath=10000))
    self.assertTrue(
        k2.is_rand_equivalent(dest,
                              dest2,
                              log_semiring,
                              treat_epsilons_specially=True))
def test_autograd(self):
    s = '''
        0 1 0 0.1
        0 1 1 0.2
        1 2 -1 0.3
        2
    '''
    src = k2.Fsa.from_str(s).requires_grad_(True)
    scores_copy = src.scores.detach().clone().requires_grad_(True)

    src.attr1 = "hello"
    src.attr2 = "k2"
    float_attr = torch.tensor([0.1, 0.2, 0.3],
                              dtype=torch.float32,
                              requires_grad=True)
    src.float_attr = float_attr.detach().clone().requires_grad_(True)
    src.int_attr = torch.tensor([1, 2, 3], dtype=torch.int32)
    src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50], [60, 70]])

    dest = k2.remove_epsilon(src)
    # arc map is [[1] [0 2] [2]]

    assert dest.attr1 == src.attr1
    assert dest.attr2 == src.attr2

    expected_int_attr = k2.RaggedTensor([[2], [1, 3], [3]])
    assert dest.int_attr == expected_int_attr

    expected_ragged_attr = k2.RaggedTensor([[30, 40, 50], [10, 20, 60, 70],
                                            [60, 70]])
    assert dest.ragged_attr == expected_ragged_attr

    expected_float_attr = torch.empty_like(dest.float_attr)
    expected_float_attr[0] = float_attr[1]
    expected_float_attr[1] = float_attr[0] + float_attr[2]
    expected_float_attr[2] = float_attr[2]
    assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

    expected_scores = torch.empty_like(dest.scores)
    expected_scores[0] = scores_copy[1]
    expected_scores[1] = scores_copy[0] + scores_copy[2]
    expected_scores[2] = scores_copy[2]
    assert torch.all(torch.eq(dest.scores, expected_scores))

    scale = torch.tensor([10, 20, 30]).to(float_attr)
    (dest.float_attr * scale).sum().backward()
    (expected_float_attr * scale).sum().backward()
    assert torch.all(torch.eq(src.float_attr.grad, float_attr.grad))

    (dest.scores * scale).sum().backward()
    (expected_scores * scale).sum().backward()
    assert torch.all(torch.eq(src.scores.grad, scores_copy.grad))
def test_autograd(self):
    s = '''
        0 1 0 0.1
        0 1 1 0.2
        1 2 -1 0.3
        2
    '''
    fsa = k2.Fsa.from_str(s).requires_grad_(True)
    ans = k2.remove_epsilon(fsa)
    # arc map is [[1] [0 2] [2]]
    scale = torch.tensor([10, 20, 30])
    (ans.scores * scale).sum().backward()
    expected_grad = torch.empty_like(fsa.scores)
    expected_grad[0] = scale[1]
    expected_grad[1] = scale[0]
    expected_grad[2] = scale[1] + scale[2]
    assert torch.all(torch.eq(fsa.grad, expected_grad))
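The hand-written `expected_grad` in these autograd tests follows directly from the arc map `[[1] [0 2] [2]]`: each output arc's score is formed from the listed input arcs, so an input arc's gradient is the sum of the upstream gradients of every output arc that references it. The following plain-PyTorch sketch (the arc map is written out as a Python list; nothing here is a k2 API) reproduces the expected tensor:

import torch

arc_map = [[1], [0, 2], [2]]           # output arc i was built from these input arcs
scale = torch.tensor([10., 20., 30.])  # upstream gradient w.r.t. the output scores

expected_grad = torch.zeros(3)         # one entry per input arc
for out_arc, in_arcs in enumerate(arc_map):
    for in_arc in in_arcs:
        expected_grad[in_arc] += scale[out_arc]

print(expected_grad)  # tensor([20., 10., 50.]) == [scale[1], scale[0], scale[1] + scale[2]]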
def test_composition_equivalence(self):
    index = _generate_fsa_vec()
    index = k2.arc_sort(k2.connect(k2.remove_epsilon(index)))

    src = _generate_fsa_vec()

    replace = k2.replace_fsa(src, index, 1)
    replace = k2.top_sort(replace)

    f_fsa = _construct_f(src)
    f_fsa = k2.arc_sort(f_fsa)
    intersect = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
    intersect = k2.invert(intersect)
    intersect = k2.top_sort(intersect)
    delattr(intersect, 'aux_labels')

    assert k2.is_rand_equivalent(replace,
                                 intersect,
                                 log_semiring=True,
                                 delta=1e-3)
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(L_inv,
                                       transcripts_with_self_loops,
                                       treat_epsilons_specially=False)
    # Don't call invert_() above because we want to return phone IDs,
    # which is the `aux_labels` of transcripts_lexicon
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
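A hedged usage sketch of `get_hierarchical_targets`; the checkpoint path `L_inv.pt` and the word IDs below are illustrative assumptions, not taken from any recipe in this section.

import torch
import k2

# Hypothetical: an inverted lexicon (word labels, phone aux_labels) previously
# saved with torch.save(L_inv.as_dict(), ...). The path is an assumption.
L_inv = k2.Fsa.from_dict(torch.load('data/lang_nosp/L_inv.pt'))

ys = [[23, 8, 17], [42, 5]]  # word-level transcripts, one list per utterance
phone_targets = get_hierarchical_targets(ys, L_inv)  # list of 1-D phone-ID tensors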
def test_random(self):
    while True:
        fsa = k2.random_fsa(max_symbol=20,
                            min_num_arcs=50,
                            max_num_arcs=500)
        fsa = k2.arc_sort(k2.connect(k2.remove_epsilon(fsa)))
        prob = fsa.properties
        # we need non-deterministic fsa
        if not prob & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC:
            break

    log_semiring = False

    # test weight pushing tropical
    dest_max = k2.determinize(
        fsa, k2.DeterminizeWeightPushingType.kTropicalWeightPushing)
    self.assertTrue(
        k2.is_rand_equivalent(fsa, dest_max, log_semiring, delta=1e-3))

    # test weight pushing log
    dest_log = k2.determinize(
        fsa, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    self.assertTrue(
        k2.is_rand_equivalent(fsa, dest_log, log_semiring, delta=1e-3))
def intersect(self, lats: Fsa) -> 'Nbest':
    '''Intersect this Nbest object with a lattice and get the 1-best path
    from the resulting FsaVec.

    Caution:
        We assume FSAs in `self.fsa` don't have epsilon self-loops.
        We also assume `self.fsa.labels` and `lats.labels` are token IDs.

    Args:
        lats:
            An FsaVec. It can be the return value of
            :func:`whole_lattice_rescoring`.
    Returns:
        Return a new Nbest. This new Nbest shares the same shape with
        `self`, while its `fsa` is the 1-best path from intersecting
        `self.fsa` and `lats`.
    '''
    assert self.fsa.device == lats.device, \
        f'{self.fsa.device} vs {lats.device}'
    assert len(lats.shape) == 3, f'{lats.shape}'
    assert lats.arcs.dim0() == self.shape.dim0(), \
        f'{lats.arcs.dim0()} vs {self.shape.dim0()}'

    lats = k2.arc_sort(lats)  # no-op if lats is already arc sorted

    fsas_with_epsilon_loops = k2.add_epsilon_self_loops(self.fsa)

    path_to_seq_map = self.shape.row_ids(1)

    ans_lats = k2.intersect_device(a_fsas=lats,
                                   b_fsas=fsas_with_epsilon_loops,
                                   b_to_a_map=path_to_seq_map,
                                   sorted_match_a=True)

    one_best = k2.shortest_path(ans_lats, use_double_scores=True)

    one_best = k2.remove_epsilon(one_best)

    return Nbest(fsa=one_best, shape=self.shape)
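A brief usage sketch of the method above; the names `nbest` and `rescored_lats` are assumptions for illustration, not taken from the surrounding code.

# Hypothetical: `nbest` is an Nbest built from n-best paths of a lattice, and
# `rescored_lats` is an FsaVec such as the output of whole_lattice_rescoring.
# Both must be on the same device, and `rescored_lats` must have 3 axes.
best_paths = nbest.intersect(rescored_lats)  # a new Nbest with the same shape
tot_scores = best_paths.fsa.scores           # scores of the epsilon-free 1-best paths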
def compile_HLG(L: Fsa, G: Fsa, H: Fsa, labels_disambig_id_start: int,
                aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa
    ``G``. Involves arc sorting, intersection, determinization, removal of
    disambiguation symbols and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        H:
            An ``Fsa`` that represents a specific topology used to convert
            the network outputs to a sequence of phones. Typically, it's a
            CTC topology fst, in which when 0 appears on the left side, it
            represents the blank symbol; when it appears on the right side,
            it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the words vocabulary.

    Returns:
        An ``Fsa`` representing the decoding graph ``HLG``, with an
        ``lm_scores`` attribute holding a copy of its scores.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    HLG = k2.compose(H, LG, inner_labels='phones')

    logging.info("Connecting HLG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)
    logging.info(
        f'HLG is arc sorted: {(HLG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )

    # Attach a new attribute `lm_scores` so that we can recover
    # the `am_scores` later.
    # The scores on an arc consist of two parts:
    #     scores = am_scores + lm_scores
    # NOTE: we assume that both kinds of scores are in log-space.
    HLG.lm_scores = HLG.scores.clone()

    return HLG
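A sketch of how `compile_HLG` might be invoked. The file names, the construction of `H` via `k2.ctc_topo`, and `max_phone_id` are assumptions typical of k2 recipes rather than definitions taken from this snippet; `find_first_disambig_symbol` is the helper used in the training scripts below.

# Hypothetical invocation sketch; paths and the H construction are assumptions.
with open('data/lang_nosp/L_disambig.fst.txt') as f:
    L = k2.Fsa.from_openfst(f.read(), acceptor=False)
with open('data/lang_nosp/G.fst.txt') as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

phone_symbol_table = k2.SymbolTable.from_file('data/lang_nosp/phones.txt')
word_symbol_table = k2.SymbolTable.from_file('data/lang_nosp/words.txt')
first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)
first_word_disambig_id = find_first_disambig_symbol(word_symbol_table)

# Assumes disambiguation symbols directly follow the real phones.
max_phone_id = first_phone_disambig_id - 1
H = k2.ctc_topo(max_phone_id)

HLG = compile_HLG(L, G, H,
                  labels_disambig_id_start=first_phone_disambig_id,
                  aux_labels_disambig_id_start=first_word_disambig_id)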
def run(rank, world_size, args):
    '''
    Args:
        rank:
            It is a value between 0 and `world_size-1`, which is
            passed automatically by `mp.spawn()` in :func:`main`.
            The node with rank 0 is responsible for saving checkpoint.
        world_size:
            Number of GPUs for DDP training.
        args:
            The return value of get_parser().parse_args()
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate
    use_pruned_intersect = args.use_pruned_intersect

    fix_random_seed(42)
    if world_size > 1:
        setup_dist(rank, world_size, args.master_port)

    suffix = ''
    if args.context_window is not None and args.context_window > 0:
        suffix = f'ac{args.context_window}'
    giga_subset = f'giga{args.subset}'
    exp_dir = Path(
        f'exp-{model_type}-mmi-att-sa-vgg-normlayer-{giga_subset}-{suffix}')

    setup_logger(f'{exp_dir}/log/log-train-{rank}')
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    if not Path(lang_dir / f'P_{args.subset}.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P_{args.subset}.fst.txt')
        with open(lang_dir / f'P_{args.subset}.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels

        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / f'P_{args.subset}.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / f'P_{args.subset}.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()

    gigaspeech = GigaSpeechAsrDataModule(args)
    train_dl = gigaspeech.train_dataloaders()
    valid_dl = gigaspeech.valid_dataloaders()

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    if use_pruned_intersect:
        logging.info('Use pruned intersect for den_lats')
    else:
        logging.info("Don't use pruned intersect for den_lats")

    logging.info("About to create model")

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True,
            is_espnet_structure=True)
    elif model_type == "contextnet":
        model = ContextNet(
            num_features=80,
            num_classes=len(phone_ids) + 1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    if args.torchscript:
        logging.info('Applying TorchScript to model...')
        model = torch.jit.script(model)

    model.to(device)
    describe(model)

    if world_size > 1:
        model = DDP(model, device_ids=[rank])

    # Now for the alignment model, if any
    if args.use_ali_model:
        ali_model = TdnnLstm1b(
            num_features=80,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4)

        ali_model_fname = Path(
            f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
        assert ali_model_fname.is_file(), \
            f'ali model filename {ali_model_fname} does not exist!'
        ali_model.load_state_dict(
            torch.load(ali_model_fname, map_location='cpu')['state_dict'])
        ali_model.to(device)

        ali_model.eval()
        ali_model.requires_grad_(False)
        logging.info(f'Use ali_model: {ali_model_fname}')
    else:
        ali_model = None
        logging.info('No ali_model')

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=args.lr_factor,
                     warm_step=args.warm_step,
                     weight_decay=args.weight_decay)

    scaler = GradScaler(enabled=args.amp)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scaler=scaler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, "
            f"valid_objf = {best_valid_objf}")

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))

        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            ali_model=ali_model,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
            scaler=scaler)

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            scaler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank,
                            torchscript=args.torchscript_epoch != -1
                            and epoch >= args.torchscript_epoch)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch,
                               local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        scaler=scaler,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank,
                        torchscript=args.torchscript_epoch != -1
                        and epoch >= args.torchscript_epoch)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch,
                           local_rank=rank)

    logging.warning('Done')
    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
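The P-preprocessing block above (map disambiguation labels to 0, remove epsilons, arc-sort, cache to disk) is repeated almost verbatim in the scripts below. A possible refactoring sketch follows; the helper name `prepare_P` is an assumption of this edit, not an existing function in these recipes.

def prepare_P(lang_dir: Path, fst_name: str, cache_name: str) -> k2.Fsa:
    # Hypothetical helper mirroring the inline code in these scripts.
    cache_path = lang_dir / cache_name
    if cache_path.is_file():
        return k2.Fsa.from_dict(torch.load(cache_path))

    with open(lang_dir / fst_name) as f:
        # P is not an acceptor: the back-off state has incoming arcs
        # with label #0 and aux_label eps.
        P = k2.Fsa.from_openfst(f.read(), acceptor=False)

    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)

    del P.aux_labels  # not needed in later computations
    # Arcs entering the back-off state have label #0; map them to epsilon (0).
    P.labels[P.labels >= first_phone_disambig_id] = 0
    P = k2.remove_epsilon(P)
    P = k2.arc_sort(P)
    torch.save(P.as_dict(), cache_path)
    return P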
def main():
    args = get_parser().parse_args()

    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank,
               world_size=args.world_size,
               master_port=args.master_port)

    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = f'exp-lstm-adam-mmi-bigram-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = args.local_rank
    device = torch.device('cuda', device_id)
    phone_ids = lexicon.phone_symbols()

    if not Path(lang_dir / 'P.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P.fst.txt')
        with open(lang_dir / 'P.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels

        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / 'P.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / 'P.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=[
                                           CutConcatenate(),
                                           CutMix(cuts=cuts_musan,
                                                  prob=0.5,
                                                  snr=(10, 20))
                                       ])
    train_sampler = SingleCutSampler(
        cuts_train,
        max_frames=40000,
        shuffle=True,
    )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=12000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=lr_schedule_gamma)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, "
            f"valid_objf = {best_valid_objf}")

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)

        # LR scheduler can hold multiple learning rates for multiple
        # parameter groups; for now we report just the first LR which
        # we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                             global_batch_idx_train)
        tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            optimizer=None,
                            scheduler=None,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=lr_scheduler,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
def test_autograd(self):
    if not torch.cuda.is_available():
        return
    if not k2.with_cuda:
        return
    devices = [torch.device('cuda', 0)]
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)
        devices.append(torch.device('cuda', 1))
    s = '''
        0 1 0 0.1
        0 1 1 0.2
        1 2 -1 0.3
        2
    '''
    for device in devices:
        src = k2.Fsa.from_str(s).to(device).requires_grad_(True)
        scores_copy = src.scores.detach().clone().requires_grad_(True)

        src.attr1 = "hello"
        src.attr2 = "k2"
        float_attr = torch.tensor([0.1, 0.2, 0.3],
                                  dtype=torch.float32,
                                  requires_grad=True,
                                  device=device)
        src.float_attr = float_attr.detach().clone().requires_grad_(True)
        src.int_attr = torch.tensor([1, 2, 3],
                                    dtype=torch.int32,
                                    device=device)
        src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                           [60, 70]]).to(device)

        dest = k2.remove_epsilon(src)
        # arc map is [[1] [0 2] [2]]

        assert dest.attr1 == src.attr1
        assert dest.attr2 == src.attr2

        expected_int_attr = k2.RaggedTensor([[2], [1, 3], [3]]).to(device)
        assert dest.int_attr == expected_int_attr

        expected_ragged_attr = k2.RaggedTensor([[30, 40, 50],
                                                [10, 20, 60, 70],
                                                [60, 70]]).to(device)
        assert dest.ragged_attr == expected_ragged_attr

        expected_float_attr = torch.empty_like(dest.float_attr)
        expected_float_attr[0] = float_attr[1]
        expected_float_attr[1] = float_attr[0] + float_attr[2]
        expected_float_attr[2] = float_attr[2]
        assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

        expected_scores = torch.empty_like(dest.scores)
        expected_scores[0] = scores_copy[1]
        expected_scores[1] = scores_copy[0] + scores_copy[2]
        expected_scores[2] = scores_copy[2]
        assert torch.all(torch.eq(dest.scores, expected_scores))

        scale = torch.tensor([10, 20, 30]).to(float_attr)
        (dest.float_attr * scale).sum().backward()
        (expected_float_attr * scale).sum().backward()
        assert torch.all(torch.eq(src.float_attr.grad, float_attr.grad))

        (dest.scores * scale).sum().backward()
        (expected_scores * scale).sum().backward()
        assert torch.all(torch.eq(src.scores.grad, scores_copy.grad))
def compile_LG(L: Fsa, G: Fsa, ctc_topo: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa
    ``G``. Involves arc sorting, intersection, determinization, removal of
    disambiguation symbols and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        ctc_topo:
            CTC topology fst, in which when 0 appears on the left side, it
            represents the blank symbol; when it appears on the right side,
            it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the words vocabulary.

    Returns:
        An ``Fsa`` representing the composed and arc-sorted decoding
        graph ``LG``.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing ctc_topo LG")
    LG = k2.compose(ctc_topo, LG, inner_labels='phones')

    logging.info("Connecting LG")
    LG = k2.connect(LG)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)
    logging.info(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
def main():
    args = get_parser().parse_args()

    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank,
               world_size=args.world_size,
               master_port=args.master_port)

    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = f'exp-lstm-adam-mmi-bigram-musan-dist'
    setup_logger('{}/log/log-train'.format(exp_dir),
                 use_console=args.local_rank == 0)
    tb_writer = SummaryWriter(
        log_dir=f'{exp_dir}/tensorboard') if args.local_rank == 0 else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = args.local_rank
    device = torch.device('cuda', device_id)
    phone_ids = lexicon.phone_symbols()

    if not Path(lang_dir / 'P.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P.fst.txt')
        with open(lang_dir / 'P.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels

        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / 'P.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / 'P.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir /
                                  'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if not args.bucketing_sampler:
        # We don't mix concatenating the cuts and bucketing
        # Here we insert concatenation before mixing so that the
        # noises from Musan are mixed onto almost-zero-energy
        # padding frames.
        transforms = [CutConcatenate(duration_factor=1)] + transforms
    train = K2SpeechRecognitionDataset(cuts_train, cut_transforms=transforms)
    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_frames=40000,
                                         shuffle=True,
                                         num_buckets=30)
    else:
        logging.info('Using regular sampler with cut concatenation.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_frames=30000,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    # Note: we explicitly set world_size to 1 to disable the auto-detection of
    # distributed training inside the sampler. This way, every GPU will
    # perform the computation on the full dev set. It is a bit wasteful,
    # but unfortunately loss aggregation between multiple processes with
    # torch.distributed.all_reduce() tends to hang indefinitely inside
    # NCCL after ~3000 steps. With the current approach, we can still report
    # the loss on the full validation set.
    valid_sampler = SingleCutSampler(cuts_dev,
                                     max_frames=90000,
                                     world_size=1,
                                     rank=0)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=lr_schedule_gamma)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, "
            f"valid_objf = {best_valid_objf}")

    if args.world_size > 1:
        logging.info(
            'Using DistributedDataParallel in training. '
            'The reported loss, num_frames, etc. for training steps include '
            'only the batches seen in the master process (the actual loss '
            'includes batches from all GPUs, and the actual num_frames is '
            f'approx. {args.world_size}x larger.')
        # For now do not sync BatchNorm across GPUs due to NCCL hanging in
        # all_gather...
        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank)

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)

        # LR scheduler can hold multiple learning rates for multiple
        # parameter groups; for now we report just the first LR which
        # we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            optimizer=None,
                            scheduler=None,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            local_rank=args.local_rank,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=lr_scheduler,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        local_rank=args.local_rank,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
    cleanup_dist()