# --- Code example #1 (file: model.py, project: sxfx0258/rnn-transducer) ---
    def __init__(self, config):
        """Build the RNN-T model: encoder, decoder, joint network and loss.

        Optionally warm-starts the decoder from a pretrained language
        model and the encoder from a CTC-pretrained checkpoint, and can
        tie the decoder embedding with the joint projection layer.
        """
        super(Transducer, self).__init__()
        self.config = config

        # Acoustic encoder and label decoder (prediction network).
        self.encoder = build_encoder(config)
        self.decoder = build_decoder(config)

        # Warm-start the decoder from a pretrained LM, if the file exists.
        if config.lm_pre_train:
            lm_path = os.path.join(home_dir, config.lm_model_path)
            if os.path.exists(lm_path):
                print('load language model')
                self.decoder.load_state_dict(torch.load(lm_path), strict=False)

        # Warm-start the encoder from a CTC-pretrained checkpoint.
        if config.ctc_pre_train:
            ctc_path = os.path.join(home_dir, config.ctc_model_path)
            if os.path.exists(ctc_path):
                print('load ctc pretrain model')
                self.encoder.load_state_dict(torch.load(ctc_path),
                                             strict=False)

        # Joint network combining encoder and decoder representations.
        self.joint = JointNet(input_size=config.joint.input_size,
                              inner_dim=config.joint.inner_size,
                              vocab_size=config.vocab_size)

        # Optionally tie the joint projection to the decoder embedding.
        if config.share_embedding:
            emb_size = self.decoder.embedding.weight.size()
            proj_size = self.joint.project_layer.weight.size()
            assert emb_size == proj_size, '%d != %d' % (emb_size[1],
                                                        proj_size[1])
            self.joint.project_layer.weight = self.decoder.embedding.weight

        self.crit = RNNTLoss()
# --- Code example #2 ---
    def __init__(self, config):
        """Set up the CTC pre-training module: encoder, projection, loss."""
        super(Pre_encoder, self).__init__()
        self.config = config

        self.encoder = build_encoder(config)
        # Linear projection 800 -> 2664; presumably maps encoder features
        # onto the output vocabulary — confirm dims against the config.
        self.project_layer = nn.Linear(800, 2664)

        self.crit = CTCLoss()
# --- Code example #3 ---
 def __init__(self, config):
     """Build the BRNN-CTC model's layers.

     Creates the encoder, a tanh bottleneck into the joint inner
     dimension, and two vocabulary-sized projection heads (base labels
     and run-length labels).
     """
     super(BRNNCTC, self).__init__()
     self.config = config
     self.encoder = build_encoder(config)
     # Bottleneck from encoder output size into the joint inner size.
     self.forward_layer = nn.Linear(config.enc.output_size,
                                    config.joint.inner_size,
                                    bias=True)
     self.tanh = nn.Tanh()
     # Projection head for the base labels.
     self.base_project_layer = nn.Linear(config.joint.inner_size,
                                         config.vocab_size,
                                         bias=True)
     # Projection head for run-length labels:
     # 0 is blank, A:1~max_rle, C:max_rle+1~2*max_rle, ...
     self.rle_project_layer = nn.Linear(config.joint.inner_size,
                                        config.vocab_size,
                                        bias=True)
# --- Code example #4 ---
    def __init__(self, config):
        """Build the RLE transducer: encoder, decoder and joint network."""
        super(Transducer, self).__init__()
        self.config = config
        self.encoder = build_encoder(config)  # acoustic encoder
        self.decoder = build_decoder(config)  # label decoder
        # Joint network with run-length-encoding support (max_rle).
        self.joint = JointNet(
            input_size=config.joint.input_size,
            inner_dim=config.joint.inner_size,
            vocab_size=config.vocab_size,
            max_rle=config.max_rle,
        )

        # Optionally tie the base projection weights to the embedding.
        if config.share_embedding:
            emb = self.decoder.embedding.weight
            proj = self.joint.base_project_layer.weight
            assert emb.size() == proj.size(), '%d != %d' % (emb.size(1),
                                                            proj.size(1))
            self.joint.base_project_layer.weight = emb
# --- Code example #5 ---
    def __init__(self, config):
        """Build the transducer: encoder, decoder, joint network, loss.

        Optionally ties the joint projection layer's weight to the
        decoder embedding (weight sharing).
        """
        super(Transducer, self).__init__()
        self.config = config
        self.encoder = build_encoder(config)  # acoustic encoder
        self.decoder = build_decoder(config)  # label decoder
        # Joint network; e.g. joint (640,512,4232), enc (160,320), dec (512,320).
        self.joint = JointNet(input_size=config.joint.input_size,
                              inner_dim=config.joint.inner_size,
                              vocab_size=config.vocab_size)

        # Optionally tie projection weights to the embedding matrix.
        if config.share_embedding:
            emb = self.decoder.embedding.weight
            proj = self.joint.project_layer.weight
            assert emb.size() == proj.size(), '%d != %d' % (emb.size(1),
                                                            proj.size(1))
            self.joint.project_layer.weight = emb

        self.crit = RNNTLoss()
# --- Code example #6 ---
    def __init__(self, config):
        """Build a transducer with an auxiliary projection layer.

        Like the plain transducer, but adds a 320->213 linear layer on
        the encoder side and names the loss ``rnnt_crit``.
        """
        super(Transducer, self).__init__()
        self.config = config

        self.encoder = build_encoder(config)
        # Auxiliary linear projection (320 -> 213); presumably for a
        # CTC-style head over encoder features — confirm against caller.
        self.project_layer = nn.Linear(320, 213)
        self.decoder = build_decoder(config)
        # Joint network combining encoder and decoder representations.
        self.joint = JointNet(input_size=config.joint.input_size,
                              inner_dim=config.joint.inner_size,
                              vocab_size=config.vocab_size)

        # Optionally tie the joint projection to the decoder embedding.
        if config.share_embedding:
            emb = self.decoder.embedding.weight
            proj = self.joint.project_layer.weight
            assert emb.size() == proj.size(), '%d != %d' % (emb.size(1),
                                                            proj.size(1))
            self.joint.project_layer.weight = emb

        self.rnnt_crit = RNNTLoss()
# --- Code example #7 ---
    # Whether to resume training from a pretrained checkpoint
    continue_train = True
    epochs = 100
    batch_size = 64
    learning_rate = 0.0001
    device = torch.device('cuda')

    opt = parser.parse_args()
    configfile = open(opt.config)

    config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))
    # ==========================================
    # NETWORK SETTING
    # ==========================================
    # load model
    model = build_encoder(config.model)
    if continue_train:
        print('load ctc pretrain model')
        ctc_path = os.path.join(home_dir, 'ctc_model/44_0.1983_enecoder_model')
        model.load_state_dict(torch.load(ctc_path), strict=False)

    print(model)
    model = model.cuda(device)

    # Data loading
    ctc_loss = torch.nn.CTCLoss()
    train_dataset = AudioDataset(config.data, 'train')
    training_data = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=batch_size,
                                                shuffle=config.data.shuffle,
                                                num_workers=32,