def main():
    """Train an LSTM acoustic model with frame-level cross-entropy.

    Reads the training YAML (``-train_config``) and the data YAML
    (``-data_config``), builds a ``SpeechDataset`` + ``ChunkDataloader``,
    optionally estimates a global mean/variance normalization, and trains an
    ``LSTMStack``/``NnetAM`` model with Adam (amsgrad). With ``-hvd true`` the
    run is distributed with Horovod. After every epoch a checkpoint
    ``<exp_dir>/model.<epoch>.tar`` holding the model state, optimizer state
    and epoch index is written (by rank 0 only under Horovod).

    Raises:
        FileNotFoundError: if ``-resume_from_model`` names a missing file.
    """

    def str2bool(value):
        # argparse's ``type=bool`` maps every non-empty string, including
        # "False", to True; parse the usual spellings explicitly instead.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1", "on"):
            return True
        if lowered in ("false", "f", "no", "n", "0", "off", ""):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument("-exp_dir")
    parser.add_argument("-dataPath", default='', type=str,
                        help="path of data files")
    parser.add_argument("-train_config")
    parser.add_argument("-data_config")
    parser.add_argument("-lr", default=0.0001, type=float,
                        help="Override the LR in the config")
    parser.add_argument("-batch_size", default=32, type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-data_loader_threads", default=0, type=int,
                        help="number of workers for data loading")
    parser.add_argument("-max_grad_norm", default=5, type=float,
                        help="max_grad_norm for gradient clipping")
    parser.add_argument("-sweep_size", default=200, type=float,
                        help="process n hours of data per sweep (default:200)")
    parser.add_argument("-num_epochs", default=1, type=int,
                        help="number of training epochs (default:1)")
    parser.add_argument("-global_mvn", default=False, type=str2bool,
                        help="if apply global mean and variance normalization")
    parser.add_argument("-resume_from_model", type=str,
                        help="the model from which you want to resume training")
    parser.add_argument("-dropout", type=float, help="set the dropout ratio")
    # The "aneal" spelling of the flag names is kept so existing command
    # lines keep working.
    parser.add_argument("-aneal_lr_epoch", default=2, type=int,
                        help="start to anneal the learning rate from this epoch")
    parser.add_argument("-aneal_lr_ratio", default=0.5, type=float,
                        help="the ratio to anneal the learning rate")
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('-hvd', default=False, type=str2bool,
                        help="whether to use horovod for training")
    args = parser.parse_args()

    with open(args.train_config) as f:
        config = yaml.safe_load(f)
    config["sweep_size"] = args.sweep_size

    with open(args.data_config) as f:
        data = yaml.safe_load(f)
    config["source_paths"] = list(data['clean_source'].values())
    if 'dir_noise' in data:
        config["dir_noise_paths"] = list(data['dir_noise'].values())
    if 'rir' in data:
        config["rir_paths"] = list(data['rir'].values())
    config['data_path'] = args.dataPath

    print("Experiment starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    # Initialize Horovod
    if args.hvd:
        import horovod.torch as hvd
        hvd.init()
        th.cuda.set_device(hvd.local_rank())
        print("Run experiments with world size {}".format(hvd.size()))

    # exist_ok avoids a race when several ranks create the directory at once.
    os.makedirs(args.exp_dir, exist_ok=True)

    trainset = SpeechDataset(config)
    # Shard the data across workers exactly when Horovod training is enabled.
    train_dataloader = ChunkDataloader(trainset,
                                       batch_size=args.batch_size,
                                       distributed=args.hvd,
                                       num_workers=args.data_loader_threads)

    if args.global_mvn:
        transform = GlobalMeanVarianceNormalization()
        print("Estimating global mean and variance of feature vectors...")
        transform.learn_mean_and_variance_from_train_loader(
            trainset,
            trainset.stream_idx_for_transform,
            n_sample_to_use=2000)
        trainset.transform = transform
        print("Global mean and variance transform trained successfully!")

        with open(os.path.join(args.exp_dir, "transform.pkl"), 'wb') as f:
            pickle.dump(transform, f, pickle.HIGHEST_PROTOCOL)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(train_dataloader)))

    # create model
    model_config = config["model_config"]
    lstm = LSTMStack(model_config["feat_dim"], model_config["hidden_size"],
                     model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(lstm, model_config["hidden_size"] * 2,
                   model_config["label_size"])

    # Start training
    th.backends.cudnn.enabled = True
    if th.cuda.is_available():
        model.cuda()

    # optimizer
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    if args.hvd:
        # Broadcast parameters and optimizer state from rank 0 to all other
        # processes.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        # Add Horovod Distributed Optimizer
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())

    # criterion
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    start_epoch = 0
    if args.resume_from_model:
        if not os.path.isfile(args.resume_from_model):
            raise FileNotFoundError(
                "ERROR: model file {} does not exist!".format(
                    args.resume_from_model))
        # Load onto the CPU; load_state_dict copies the tensors onto the
        # devices of the live model/optimizer, so ranks do not all allocate
        # on the GPU the checkpoint was saved from.
        checkpoint = th.load(args.resume_from_model, map_location="cpu")
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The checkpoint is written after its epoch completed, so continue
        # with the next epoch instead of repeating (and re-annealing) it.
        start_epoch = checkpoint['epoch'] + 1
        print("=> loaded checkpoint '{}' ".format(args.resume_from_model))

    model.train()
    for epoch in range(start_epoch, args.num_epochs):
        # anneal learning rate once per epoch after -aneal_lr_epoch
        if epoch > args.aneal_lr_epoch:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= args.aneal_lr_ratio

        run_train_epoch(model, optimizer, criterion, train_dataloader, epoch,
                        args)

        # save model (a single writer under Horovod)
        if not args.hvd or hvd.rank() == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            output_file = os.path.join(args.exp_dir,
                                       'model.' + str(epoch) + '.tar')
            th.save(checkpoint, output_file)
def main():
    """Sequence-train an LSTM acoustic model with Horovod and SGD.

    Builds a ``SpeechDataset``/``SeqDataloader`` from ``-config`` and
    ``-data``, loads the seed model from ``-seed_model``, and sets up a
    non-determinizing ``MappedLatticeFasterRecognizer`` from the graph in
    ``-den_dir`` plus the transition model and silence phone ids. These are
    handed to ``run_train_epoch`` together with the log-prior read from
    ``-prior_path``. Rank 0 writes ``<exp_dir>/model.se.<epoch>.tar`` after
    every epoch. A missing input file is reported on stderr and the process
    exits with status 1.
    """

    def require_file(path, label):
        # Report a missing input and abort with a failure status so that
        # calling scripts notice the error.
        if path is None or not os.path.isfile(path):
            sys.stderr.write('ERROR: The %s %s does not exist!\n' %
                             (label, path))
            sys.exit(1)

    parser = argparse.ArgumentParser()
    parser.add_argument("-config")
    parser.add_argument("-data", help="data yaml file")
    parser.add_argument("-data_path", default='', type=str,
                        help="path of data files")
    parser.add_argument("-seed_model", help="the seed neural network model")
    parser.add_argument("-exp_dir", help="the directory to save the outputs")
    parser.add_argument("-transform",
                        help="feature transformation matrix or mvn statistics")
    parser.add_argument("-criterion", type=str,
                        choices=["mmi", "mpfe", "smbr"],
                        help="set the sequence training criterion")
    parser.add_argument(
        "-trans_model",
        help="the HMM transition model, used for lattice generation")
    parser.add_argument(
        "-prior_path",
        help="the prior for decoder, usually named as final.occs in kaldi setup")
    parser.add_argument(
        "-den_dir",
        help="the decoding graph directory to find HCLG and words.txt files")
    parser.add_argument("-lr", type=float, help="set the learning rate")
    parser.add_argument("-ce_ratio", default=0.1, type=float,
                        help="the ratio for ce regularization")
    parser.add_argument("-momentum", default=0, type=float,
                        help="set the momentum")
    parser.add_argument("-batch_size", default=32, type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-data_loader_threads", default=0, type=int,
                        help="number of workers for data loading")
    parser.add_argument("-max_grad_norm", default=5, type=float,
                        help="max_grad_norm for gradient clipping")
    parser.add_argument("-sweep_size", default=100, type=float,
                        help="process n hours of data per sweep (default:100)")
    parser.add_argument("-num_epochs", default=1, type=int,
                        help="number of training epochs (default:1)")
    parser.add_argument('-print_freq', default=10, type=int, metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('-save_freq', default=1000, type=int, metavar='N',
                        help='save model frequency (default: 1000)')
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)
    config['data_path'] = args.data_path
    config["sweep_size"] = args.sweep_size

    print("pytorch version:{}".format(th.__version__))

    with open(args.data) as f:
        data = yaml.safe_load(f)
    config["source_paths"] = list(data['clean_source'].values())

    print("Experiment starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    # Initialize Horovod
    hvd.init()
    th.cuda.set_device(hvd.local_rank())
    print("Run experiments with world size {}".format(hvd.size()))

    dataset = SpeechDataset(config)
    transform = None
    if args.transform is not None and os.path.isfile(args.transform):
        with open(args.transform, 'rb') as f:
            transform = pickle.load(f)
    dataset.transform = transform

    train_dataloader = SeqDataloader(dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.data_loader_threads,
                                     distributed=True,
                                     test_only=False)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(train_dataloader)))

    # exist_ok avoids a race when several ranks create the directory at once.
    os.makedirs(args.exp_dir, exist_ok=True)

    # create model
    model_config = config["model_config"]
    lstm = LSTMStack(model_config["feat_dim"], model_config["hidden_size"],
                     model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(lstm, model_config["hidden_size"] * 2,
                   model_config["label_size"])
    model.cuda()

    # setup the optimizer
    optimizer = th.optim.SGD(model.parameters(), lr=args.lr,
                             momentum=args.momentum)

    # Broadcast parameters and optimizer state from rank 0 to all other
    # processes.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Add Horovod Distributed Optimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    require_file(args.seed_model, "model file")
    checkpoint = th.load(args.seed_model)
    state_dict = checkpoint['model']
    # Checkpoints saved through nn.DataParallel prefix every key with
    # "module."; strip it only when it is actually present so that plain
    # checkpoints (e.g. saved under Horovod) still load.
    prefix = "module."
    if all(k.startswith(prefix) for k in state_dict):
        from collections import OrderedDict
        state_dict = OrderedDict(
            (k[len(prefix):], v) for k, v in state_dict.items())
    model.load_state_dict(state_dict)
    print("=> loaded checkpoint '{}' ".format(args.seed_model))

    HCLG = args.den_dir + "/HCLG.fst"
    words_txt = args.den_dir + "/words.txt"
    silence_phones = args.den_dir + "/phones/silence.csl"

    require_file(HCLG, "HCLG file")
    require_file(words_txt, "words.txt file")
    require_file(silence_phones, "silence phone file")

    # silence.csl holds a single colon-separated line of integer phone ids.
    with open(silence_phones) as f:
        silence_ids = [int(i) for i in f.readline().strip().split(':')]

    require_file(args.trans_model, "trans_model")
    trans_model = kaldi_hmm.TransitionModel()
    with kaldi_util.io.xopen(args.trans_model) as ki:
        trans_model.read(ki.stream(), ki.binary)

    # now we can setup the decoder
    decoder_opts = LatticeFasterDecoderOptions()
    decoder_opts.beam = config["decoder_config"]["beam"]
    decoder_opts.lattice_beam = config["decoder_config"]["lattice_beam"]
    decoder_opts.max_active = config["decoder_config"]["max_active"]
    acoustic_scale = config["decoder_config"]["acoustic_scale"]
    # Produce a raw state-level lattice instead of a compact lattice.
    decoder_opts.determinize_lattice = False
    asr_decoder = MappedLatticeFasterRecognizer.from_files(
        args.trans_model,
        HCLG,
        words_txt,
        acoustic_scale=acoustic_scale,
        decoder_opts=decoder_opts)

    # Normalize the state occupancy counts in the first row into a log-prior.
    prior = kaldi_util.io.read_matrix(args.prior_path).numpy()
    log_prior = th.tensor(np.log(prior[0] / np.sum(prior[0])), dtype=th.float)
    log_prior = log_prior.cuda()

    model.train()
    for epoch in range(args.num_epochs):
        run_train_epoch(model, optimizer, log_prior, train_dataloader, epoch,
                        asr_decoder, trans_model, silence_ids, args)

        # save model (rank 0 is the single writer)
        if hvd.rank() == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            output_file = args.exp_dir + '/model.se.' + str(epoch) + '.tar'
            th.save(checkpoint, output_file)
def main():
    """Decode an evaluation set with an LSTM acoustic model into lattices.

    Builds an "Eval" ``SpeechDataset`` from ``-data_path``, loads the model
    checkpoint ``-model_path``, and converts the network outputs into
    prior-normalized scores (output minus the log of the normalized
    ``-prior_path`` counts). These scores are decoded with a
    ``MappedLatticeFasterRecognizer`` built from ``-graph_dir`` and
    ``-trans_model``. Each utterance's text and per-frame likelihood are
    printed, and its compact lattice is written to the Kaldi archive
    ``-out_file``. A missing input file is reported on stderr and the process
    exits with status 1.
    """

    def require_file(path, label):
        # Report a missing input and abort with a failure status so that
        # calling scripts notice the error.
        if path is None or not os.path.isfile(path):
            sys.stderr.write('ERROR: The %s %s does not exist!\n' %
                             (label, path))
            sys.exit(1)

    parser = argparse.ArgumentParser()
    parser.add_argument("-config")
    parser.add_argument("-model_path")
    parser.add_argument("-data_path")
    parser.add_argument("-prior_path",
                        help="the path to load the final.occs file")
    parser.add_argument("-out_file",
                        help="write out the log-probs to this file")
    parser.add_argument("-transform",
                        help="feature transformation matrix or mvn statistics")
    parser.add_argument(
        "-trans_model",
        help="the HMM transition model, used for lattice generation")
    parser.add_argument("-graph_dir", help="the decoding graph directory")
    parser.add_argument("-batch_size", default=32, type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-sweep_size", default=200, type=float,
                        help="process n hours of data per sweep (default:200)")
    parser.add_argument("-data_loader_threads", default=4, type=int,
                        help="number of workers for data loading")
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)
    config["sweep_size"] = args.sweep_size
    # Evaluation reads a single "Eval" source pointing at -data_path.
    config["source_paths"] = [{"type": "Eval", "wav": args.data_path}]

    print("job starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    transform = None
    if args.transform is not None and os.path.isfile(args.transform):
        with open(args.transform, 'rb') as f:
            transform = pickle.load(f)

    dataset = SpeechDataset(config)
    test_dataloader = SeqDataloader(dataset,
                                    batch_size=args.batch_size,
                                    test_only=True,
                                    global_mvn=True,
                                    transform=transform)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(test_dataloader)))

    # create model
    model_config = config["model_config"]
    lstm = LSTMStack(model_config["feat_dim"], model_config["hidden_size"],
                     model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(lstm, model_config["hidden_size"] * 2,
                   model_config["label_size"])
    model.cuda()

    require_file(args.model_path, "model file")
    checkpoint = th.load(args.model_path, map_location='cuda:0')
    state_dict = checkpoint['model']
    # Checkpoints saved through nn.DataParallel prefix every key with
    # "module."; strip it only when every key carries it.
    prefix = "module."
    if all(k.startswith(prefix) for k in state_dict):
        from collections import OrderedDict
        state_dict = OrderedDict(
            (k[len(prefix):], v) for k, v in state_dict.items())
    model.load_state_dict(state_dict)
    print("=> loaded checkpoint '{}' ".format(args.model_path))

    HCLG = args.graph_dir + "/HCLG.fst"
    words_txt = args.graph_dir + "/words.txt"
    require_file(HCLG, "HCLG file")
    require_file(words_txt, "words.txt file")

    require_file(args.trans_model, "trans_model")
    trans_model = kaldi_hmm.TransitionModel()
    with kaldi_util.io.xopen(args.trans_model) as ki:
        trans_model.read(ki.stream(), ki.binary)

    # Normalize the occupancy counts in the first row into a log-prior.
    prior = read_matrix(args.prior_path).numpy()
    log_prior = th.tensor(np.log(prior[0] / np.sum(prior[0])), dtype=th.float)

    # now we can setup the decoder
    decoder_opts = LatticeFasterDecoderOptions()
    decoder_opts.beam = config["decoder_config"]["beam"]
    decoder_opts.lattice_beam = config["decoder_config"]["lattice_beam"]
    decoder_opts.max_active = config["decoder_config"]["max_active"]
    acoustic_scale = config["decoder_config"]["acoustic_scale"]
    # Produce a compact (determinized) lattice.
    decoder_opts.determinize_lattice = True
    asr_decoder = MappedLatticeFasterRecognizer.from_files(
        args.trans_model,
        HCLG,
        words_txt,
        acoustic_scale=acoustic_scale,
        decoder_opts=decoder_opts)

    model.eval()
    with th.no_grad():
        with kaldi_util.table.CompactLatticeWriter(
                "ark:" + args.out_file) as lat_out:
            for data in test_dataloader:
                feat = data["x"]
                num_frs = data["num_frs"]
                utt_ids = data["utt_ids"]

                x = feat.to(th.float32).cuda()
                prediction = model(x)

                # Decode only the unpadded frames of each utterance.
                for j in range(len(num_frs)):
                    loglikes = prediction[j, :, :].data.cpu()
                    loglikes_j = loglikes[:num_frs[j], :] - log_prior

                    decoder_out = asr_decoder.decode(
                        kaldi_matrix.Matrix(loglikes_j.numpy()))

                    key = utt_ids[j][0]
                    print(key, decoder_out["text"])
                    print("Log-like per-frame for utterance {} is {}".format(
                        key, decoder_out["likelihood"] / num_frs[j]))

                    # save lattice
                    lat_out[key] = decoder_out["lattice"]
def main():
    """Dump (optionally prior-normalized) network outputs to a Kaldi archive.

    Builds an "Eval" ``SpeechDataset`` from ``-data``, loads the model
    checkpoint ``-model_path``, and writes each utterance's unpadded output
    matrix to ``-out_file`` through ``MatrixWriter``. When ``-prior_path`` is
    given, the log of the normalized prior counts is subtracted first. With
    ``-frame_subsampling_factor f > 1``, only input frames 0, f, 2f, ... are
    fed to the model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-config")
    parser.add_argument("-model_path")
    parser.add_argument("-data")
    parser.add_argument("-data_path", default='', type=str,
                        help="path of data files")
    parser.add_argument("-prior_path", default=None,
                        help="the path to load the final.occs file")
    parser.add_argument("-transform",
                        help="feature transformation matrix or mvn statistics")
    parser.add_argument("-out_file",
                        help="write out the log-probs to this file")
    parser.add_argument("-batch_size", default=32, type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-sweep_size", default=200, type=float,
                        help="process n hours of data per sweep (default:200)")
    parser.add_argument("-frame_subsampling_factor", default=1, type=int,
                        help="the factor to subsample the features")
    parser.add_argument("-data_loader_threads", default=4, type=int,
                        help="number of workers for data loading")
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)
    config["sweep_size"] = args.sweep_size
    # Evaluation reads a single "Eval" source pointing at -data.
    config["source_paths"] = [{"type": "Eval", "wav": args.data}]
    config["data_path"] = args.data_path

    print("job starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    transform = None
    if args.transform is not None and os.path.isfile(args.transform):
        with open(args.transform, 'rb') as f:
            transform = pickle.load(f)

    dataset = SpeechDataset(config)
    print(transform)
    test_dataloader = SeqDataloader(dataset,
                                    batch_size=args.batch_size,
                                    test_only=True,
                                    global_mvn=True,
                                    transform=transform)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(test_dataloader)))

    # create model
    model_config = config["model_config"]
    lstm = LSTMStack(model_config["feat_dim"], model_config["hidden_size"],
                     model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(lstm, model_config["hidden_size"] * 2,
                   model_config["label_size"])
    model.cuda()

    if not os.path.isfile(args.model_path):
        raise FileNotFoundError(
            "ERROR: model file {} does not exist!".format(args.model_path))
    checkpoint = th.load(args.model_path, map_location='cuda:0')
    state_dict = checkpoint['model']
    # Checkpoints saved through nn.DataParallel prefix every key with
    # "module."; strip it only when every key carries it.
    prefix = "module."
    if all(k.startswith(prefix) for k in state_dict):
        from collections import OrderedDict
        state_dict = OrderedDict(
            (k[len(prefix):], v) for k, v in state_dict.items())
    model.load_state_dict(state_dict)
    print("=> loaded checkpoint '{}' ".format(args.model_path))

    log_prior = None
    if args.prior_path:
        # Normalize the occupancy counts in the first row into a log-prior.
        prior = read_matrix(args.prior_path).numpy()
        log_prior = th.tensor(np.log(prior[0] / np.sum(prior[0])),
                              dtype=th.float)

    factor = args.frame_subsampling_factor

    model.eval()
    with th.no_grad():
        with MatrixWriter("ark:" + args.out_file) as llout:
            for i, data in enumerate(test_dataloader):
                feat = data["x"]
                num_frs = data["num_frs"]
                utt_ids = data["utt_ids"]

                x = feat.to(th.float32)
                if factor > 1:
                    # Keep frames 0, f, 2f, ... along the time axis
                    # (same selection as x.unfold(1, 1, f).squeeze(-1)).
                    x = x[:, ::factor, :]
                x = x.cuda()
                prediction = model(x)

                # save only unpadded part for each utt in batch
                for j in range(len(num_frs)):
                    num_out = int(num_frs[j])
                    if factor > 1:
                        # An utterance of n input frames keeps ceil(n / f)
                        # subsampled frames; the rest is padding.
                        num_out = (num_out + factor - 1) // factor
                    loglikes = prediction[j, :, :].data.cpu()
                    loglikes_j = loglikes[:num_out, :]
                    # Compare with None: truth-testing a multi-element
                    # tensor raises a RuntimeError.
                    if log_prior is not None:
                        loglikes_j = loglikes_j - log_prior
                    llout[utt_ids[j][0]] = loglikes_j

                print("Process batch [{}/{}]".format(i + 1,
                                                    len(test_dataloader)))
def main():
    """Train an LSTM acoustic model with the Kaldi chain (LF-MMI) objective.

    Builds a distributed ``SpeechDataset``/``SeqDataloader`` from ``-config``
    and ``-data`` and optionally loads ``-seed_model``. It then loads the
    chain transition model, tree and ``den.fst`` from ``-chain_dir`` and an
    aligner from ``-ali_dir``/``-lang_dir``. The model is trained with Adam
    (amsgrad) under Horovod. Rank 0 writes
    ``<exp_dir>/chain.model.<epoch>.tar`` after every epoch. A missing chain
    transition model is reported on stderr and the process exits with
    status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-config")
    parser.add_argument("-data", help="data yaml file")
    parser.add_argument("-dataPath", default='', type=str,
                        help="path of data files")
    parser.add_argument("-seed_model", default='',
                        help="the seed neural network model")
    parser.add_argument("-exp_dir", help="the directory to save the outputs")
    parser.add_argument("-transform",
                        help="feature transformation matrix or mvn statistics")
    parser.add_argument(
        "-ali_dir",
        help="the directory to load trans_model and tree used for alignments")
    parser.add_argument("-lang_dir", help="the lexicon directory to load L.fst")
    parser.add_argument(
        "-chain_dir",
        help="the directory to load trans_model, tree and den.fst for chain model")
    parser.add_argument("-lr", type=float, help="set the learning rate")
    parser.add_argument("-xent_regularize", default=0, type=float,
                        help="cross-entropy regularization weight")
    parser.add_argument("-momentum", default=0, type=float,
                        help="set the momentum")
    parser.add_argument("-weight_decay", default=1e-4, type=float,
                        help="set the L2 regularization weight")
    parser.add_argument("-batch_size", default=32, type=int,
                        help="Override the batch size in the config")
    parser.add_argument("-data_loader_threads", default=0, type=int,
                        help="number of workers for data loading")
    parser.add_argument("-max_grad_norm", default=5, type=float,
                        help="max_grad_norm for gradient clipping")
    parser.add_argument("-sweep_size", default=100, type=float,
                        help="process n hours of data per sweep (default:100)")
    parser.add_argument("-num_epochs", default=1, type=int,
                        help="number of training epochs (default:1)")
    parser.add_argument(
        "-anneal_lr_epoch", default=2, type=int,
        help="start to anneal the learning rate from this epoch")
    parser.add_argument("-anneal_lr_ratio", default=0.5, type=float,
                        help="the ratio to anneal the learning rate ratio")
    parser.add_argument('-print_freq', default=10, type=int, metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('-save_freq', default=1000, type=int, metavar='N',
                        help='save model frequency (default: 1000)')
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)
    config["sweep_size"] = args.sweep_size

    print("pytorch version:{}".format(th.__version__))

    with open(args.data) as f:
        data = yaml.safe_load(f)
    config["source_paths"] = list(data['clean_source'].values())
    if 'dir_noise' in data:
        config["dir_noise_paths"] = list(data['dir_noise'].values())
    if 'rir' in data:
        config["rir_paths"] = list(data['rir'].values())
    config['data_path'] = args.dataPath

    print("Experiment starts with config {}".format(
        json.dumps(config, sort_keys=True, indent=4)))

    # Initialize Horovod
    hvd.init()
    th.cuda.set_device(hvd.local_rank())
    print("Run experiments with world size {}".format(hvd.size()))

    dataset = SpeechDataset(config)
    transform = None
    if args.transform is not None and os.path.isfile(args.transform):
        with open(args.transform, 'rb') as f:
            transform = pickle.load(f)
    dataset.transform = transform

    train_dataloader = SeqDataloader(dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.data_loader_threads,
                                     distributed=True,
                                     test_only=False)

    print("Data loader set up successfully!")
    print("Number of minibatches: {}".format(len(train_dataloader)))

    # exist_ok avoids a race when several ranks create the directory at once.
    os.makedirs(args.exp_dir, exist_ok=True)

    # create model
    model_config = config["model_config"]
    lstm = LSTMStack(model_config["feat_dim"], model_config["hidden_size"],
                     model_config["num_layers"], model_config["dropout"], True)
    model = NnetAM(lstm, model_config["hidden_size"] * 2,
                   model_config["label_size"])
    model.cuda()

    # setup the optimizer
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    # Broadcast parameters and optimizer state from rank 0 to all other
    # processes.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Add Horovod Distributed Optimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    # The seed model is optional: an empty or missing path trains from the
    # freshly initialized weights.
    if os.path.isfile(args.seed_model):
        checkpoint = th.load(args.seed_model)
        state_dict = checkpoint['model']
        # Checkpoints saved through nn.DataParallel prefix every key with
        # "module."; strip it only when every key carries it.
        prefix = "module."
        if all(k.startswith(prefix) for k in state_dict):
            from collections import OrderedDict
            state_dict = OrderedDict(
                (k[len(prefix):], v) for k, v in state_dict.items())
        model.load_state_dict(state_dict)
        print("=> loaded checkpoint '{}' ".format(args.seed_model))

    ali_model = args.ali_dir + "/final.mdl"
    ali_tree = args.ali_dir + "/tree"
    L_fst = args.lang_dir + "/L.fst"
    disambig = args.lang_dir + "/phones/disambig.int"

    den_fst = kaldi_fst.StdVectorFst.read(args.chain_dir + "/den.fst")
    chain_model_path = args.chain_dir + "/0.trans_mdl"
    chain_tree_path = args.chain_dir + "/tree"

    if not os.path.isfile(chain_model_path):
        sys.stderr.write('ERROR: The trans_model %s does not exist!\n' %
                         (chain_model_path))
        sys.exit(1)
    chain_trans_model = kaldi_hmm.TransitionModel()
    with kaldi_util.io.xopen(chain_model_path) as ki:
        chain_trans_model.read(ki.stream(), ki.binary)

    chain_tree = kaldi_tree.ContextDependency()
    with kaldi_util.io.xopen(chain_tree_path) as ki:
        chain_tree.read(ki.stream(), ki.binary)

    # chain supervision options
    supervision_opts = kaldi_chain.SupervisionOptions()
    supervision_opts.convert_to_pdfs = True
    supervision_opts.frame_subsampling_factor = 3
    supervision_opts.left_tolerance = 5
    supervision_opts.right_tolerance = 5

    # chain training options
    chain_opts = kaldi_chain.ChainTrainingOptions()
    chain_opts.leaky_hmm_coefficient = 1e-4
    chain_opts.xent_regularize = args.xent_regularize

    # setup the aligner
    aligner = kaldi_align.MappedAligner.from_files(ali_model,
                                                   ali_tree,
                                                   L_fst,
                                                   None,
                                                   disambig,
                                                   None,
                                                   beam=10,
                                                   transition_scale=1.0,
                                                   self_loop_scale=0.1,
                                                   acoustic_scale=0.1)
    den_graph = kaldi_chain.DenominatorGraph(den_fst,
                                             model_config["label_size"])

    model.train()
    for epoch in range(args.num_epochs):
        # anneal learning rate once per epoch after -anneal_lr_epoch
        if epoch > args.anneal_lr_epoch:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= args.anneal_lr_ratio

        run_train_epoch(model, optimizer, train_dataloader, epoch,
                        chain_trans_model, chain_tree, supervision_opts,
                        aligner, den_graph, chain_opts, args)

        # save model (rank 0 is the single writer)
        if hvd.rank() == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            output_file = (args.exp_dir + '/chain.model.' + str(epoch) +
                           '.tar')
            th.save(checkpoint, output_file)