def __init__(self, config):
    """Transducer (RNN-T) model: encoder + prediction network + joint net.

    Optionally warm-starts the decoder from a pretrained language model
    and the encoder from a CTC-pretrained checkpoint, and can tie the
    joint projection weights to the decoder embedding.

    Args:
        config: attribute-style config with `lm_pre_train`, `ctc_pre_train`,
            `lm_model_path`, `ctc_model_path`, `joint.*`, `vocab_size`,
            and `share_embedding` fields.
    """
    super(Transducer, self).__init__()
    self.config = config

    # Acoustic encoder (transcription network).
    self.encoder = build_encoder(config)
    # Prediction network (decoder).
    self.decoder = build_decoder(config)

    if config.lm_pre_train:
        lm_path = os.path.join(home_dir, config.lm_model_path)
        if os.path.exists(lm_path):
            print('load language model')
            # map_location='cpu' so a GPU-saved checkpoint also loads on a
            # CPU-only host; load_state_dict copies the values onto the
            # decoder's own parameters/device regardless.
            self.decoder.load_state_dict(
                torch.load(lm_path, map_location='cpu'), strict=False)

    if config.ctc_pre_train:
        ctc_path = os.path.join(home_dir, config.ctc_model_path)
        if os.path.exists(ctc_path):
            print('load ctc pretrain model')
            self.encoder.load_state_dict(
                torch.load(ctc_path, map_location='cpu'), strict=False)

    # Joint network combining encoder and decoder representations.
    self.joint = JointNet(input_size=config.joint.input_size,
                          inner_dim=config.joint.inner_size,
                          vocab_size=config.vocab_size)

    if config.share_embedding:
        # Weight tying requires identical shapes. The old message printed
        # only dim 1 of each tensor (could show "320 != 320" for tensors
        # that differ in dim 0); report the full shapes instead.
        assert self.decoder.embedding.weight.size() \
            == self.joint.project_layer.weight.size(), '%s != %s' % (
                tuple(self.decoder.embedding.weight.size()),
                tuple(self.joint.project_layer.weight.size()))
        self.joint.project_layer.weight = self.decoder.embedding.weight

    # RNN-T training criterion.
    self.crit = RNNTLoss()
def __init__(self, config):
    """CTC pre-training wrapper: acoustic encoder plus a linear output head."""
    super().__init__()
    self.config = config
    # Encoder built from the shared model config.
    self.encoder = build_encoder(config)
    # CTC training criterion.
    self.crit = CTCLoss()
    # Project 800-dim encoder features onto 2664 output classes.
    # NOTE(review): both sizes are hard-coded — presumably they should
    # track config.enc.output_size / config.vocab_size; confirm.
    self.project_layer = nn.Linear(800, 2664)
def __init__(self, config):
    """Bidirectional-RNN CTC model with separate base and RLE output heads."""
    super().__init__()
    self.config = config
    self.encoder = build_encoder(config)
    hidden = config.joint.inner_size
    # Shared hidden transform applied to encoder outputs, squashed by tanh.
    self.forward_layer = nn.Linear(config.enc.output_size, hidden, bias=True)
    self.tanh = nn.Tanh()
    # Head predicting the base symbol.
    self.base_project_layer = nn.Linear(hidden, config.vocab_size, bias=True)
    # Head predicting run-length classes:
    # 0 is blank, A: 1..max_rle, C: max_rle+1..2*max_rle, ...
    self.rle_project_layer = nn.Linear(hidden, config.vocab_size, bias=True)
def __init__(self, config):
    """RNN-T model whose joint network is RLE-aware (takes max_rle)."""
    super().__init__()
    self.config = config
    # Transcription and prediction networks.
    self.encoder = build_encoder(config)
    self.decoder = build_decoder(config)
    # Joint network fusing encoder/decoder states.
    self.joint = JointNet(input_size=config.joint.input_size,
                          inner_dim=config.joint.inner_size,
                          vocab_size=config.vocab_size,
                          max_rle=config.max_rle)
    if config.share_embedding:
        # Tie the base projection to the decoder embedding (shapes must match).
        emb_w = self.decoder.embedding.weight
        proj_w = self.joint.base_project_layer.weight
        assert emb_w.size() == proj_w.size(), '%d != %d' % (
            emb_w.size(1), proj_w.size(1))
        self.joint.base_project_layer.weight = emb_w
def __init__(self, config):
    """Plain RNN-T model: encoder, decoder, joint network, RNNT loss."""
    super().__init__()
    self.config = config
    # Transcription and prediction networks.
    self.encoder = build_encoder(config)
    self.decoder = build_decoder(config)
    # Joint network; original note gave example dims
    # joint (640, 512, 4232), enc (160, 320), dec (512, 320) — confirm
    # against the active config.
    self.joint = JointNet(input_size=config.joint.input_size,
                          inner_dim=config.joint.inner_size,
                          vocab_size=config.vocab_size)
    if config.share_embedding:
        # Tie the joint projection to the decoder embedding.
        emb_w = self.decoder.embedding.weight
        proj_w = self.joint.project_layer.weight
        assert emb_w.size() == proj_w.size(), '%d != %d' % (
            emb_w.size(1), proj_w.size(1))
        self.joint.project_layer.weight = emb_w
    # RNN-T training criterion.
    self.crit = RNNTLoss()
def __init__(self, config):
    """RNN-T model with an auxiliary linear head on the encoder output."""
    super().__init__()
    self.config = config
    self.encoder = build_encoder(config)
    # Auxiliary projection of 320-dim encoder features onto 213 classes.
    # NOTE(review): sizes are hard-coded — verify against the config.
    self.project_layer = nn.Linear(320, 213)
    self.decoder = build_decoder(config)
    # Joint network fusing encoder/decoder states.
    self.joint = JointNet(input_size=config.joint.input_size,
                          inner_dim=config.joint.inner_size,
                          vocab_size=config.vocab_size)
    if config.share_embedding:
        # Tie the joint projection to the decoder embedding.
        emb_w = self.decoder.embedding.weight
        proj_w = self.joint.project_layer.weight
        assert emb_w.size() == proj_w.size(), '%d != %d' % (
            emb_w.size(1), proj_w.size(1))
        self.joint.project_layer.weight = emb_w
    # RNN-T training criterion (a CTC criterion existed here once; removed).
    self.rnnt_crit = RNNTLoss()
# 是否继续训练 continue_train = True epochs = 100 batch_size = 64 learning_rate = 0.0001 device = torch.device('cuda') opt = parser.parse_args() configfile = open(opt.config) config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader)) # ========================================== # NETWORK SETTING # ========================================== # load model model = build_encoder(config.model) if continue_train: print('load ctc pretrain model') ctc_path = os.path.join(home_dir, 'ctc_model/44_0.1983_enecoder_model') model.load_state_dict(torch.load(ctc_path), strict=False) print(model) model = model.cuda(device) # 数据提取 ctc_loss = torch.nn.CTCLoss() train_dataset = AudioDataset(config.data, 'train') training_data = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=config.data.shuffle, num_workers=32,