Exemplo n.º 1
0
def main():
    """Decode a noisy-speech test set end to end.

    Pipeline per utterance: speech enhancement -> fbank feature extraction
    (with global CMVN) -> attention/CTC E2E recognition with an optional
    language model (char RNNLM, word RNNLM, FSRNN LM, or n-gram FST LM),
    then dump all hypotheses as a Kaldi-style JSON to ``opt.result_label``.

    Relies on module-level globals: ``opt`` (options namespace),
    ``recog_loader`` (data loader), the model classes ``EnhanceModel`` /
    ``FbankModel`` / ``E2E``, and the LM helper modules ``lm`` / ``extlm`` /
    ``fsrnn`` plus ``NgramFstLM`` and ``load_labeldict``.
    """
    # Checkpoint paths are hard-coded for this particular experiment.
    enhance_model_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/enhance_fbank_train_table_2/model.loss.best'
    asr_mode_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/asr_clean_train_table3/model.acc.best'
    feat_model_path = asr_mode_path

    if opt.resume:
        enhance_model = EnhanceModel.load_model(enhance_model_path,
                                                'enhance_state_dict', opt)
        feat_model = FbankModel.load_model(feat_model_path, 'fbank_state_dict',
                                           opt)
        asr_model = E2E.load_model(asr_mode_path, 'asr_state_dict', opt)
    else:
        raise Exception("no checkpoint found at {}".format(opt.resume))

    def cpu_loader(storage, location):
        # torch.load map_location hook: keep every tensor on CPU.
        return storage

    if opt.lmtype == 'rnnlm':
        # Character-level RNNLM (embedding 300, hidden 650), optionally
        # combined with a word-level RNNLM below.
        if opt.rnnlm:
            rnnlm = lm.ClassifierWithState(
                lm.RNNLM(len(opt.char_list), 300, 650))
            # NOTE(review): opt.rnnlm is overwritten with a hard-coded
            # checkpoint path before loading — the CLI value is ignored.
            opt.rnnlm = "/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/rnnlm_train_shi/rnnlm.model.best"
            rnnlm.load_state_dict(
                torch.load(opt.rnnlm, map_location=cpu_loader))
            if len(opt.gpu_ids) > 0:
                rnnlm = rnnlm.cuda()
            print('load RNNLM from {}'.format(opt.rnnlm))
            rnnlm.eval()
        else:
            rnnlm = None

        if opt.word_rnnlm:
            if not opt.word_dict:
                logging.error(
                    'word dictionary file is not specified for the word RNNLM.'
                )
                sys.exit(1)

            word_dict = load_labeldict(opt.word_dict)
            char_dict = {x: i for i, x in enumerate(opt.char_list)}
            word_rnnlm = lm.ClassifierWithState(lm.RNNLM(len(word_dict), 650))
            word_rnnlm.load_state_dict(
                torch.load(opt.word_rnnlm, map_location=cpu_loader))
            word_rnnlm.eval()

            if rnnlm is not None:
                # Multi-level LM: word LM on top of the char-level LM.
                rnnlm = lm.ClassifierWithState(
                    extlm.MultiLevelLM(word_rnnlm.predictor, rnnlm.predictor,
                                       word_dict, char_dict))
            else:
                # Word LM only, with character look-ahead.
                rnnlm = lm.ClassifierWithState(
                    extlm.LookAheadWordLM(word_rnnlm.predictor, word_dict,
                                          char_dict))
        fstlm = None

    elif opt.lmtype == 'fsrnnlm':
        # Fast-slow RNN language model.
        if opt.rnnlm:
            rnnlm = lm.ClassifierWithState(
                fsrnn.FSRNNLM(len(opt.char_list), 300, opt.fast_layers,
                              opt.fast_cell_size, opt.slow_cell_size,
                              opt.zoneout_keep_h, opt.zoneout_keep_c))
            rnnlm.load_state_dict(
                torch.load(opt.rnnlm, map_location=cpu_loader))
            if len(opt.gpu_ids) > 0:
                rnnlm = rnnlm.cuda()
            print('load fsrnn from {}'.format(opt.rnnlm))
            rnnlm.eval()
        else:
            rnnlm = None
            print('not load fsrnn from {}'.format(opt.rnnlm))
        fstlm = None

    elif opt.lmtype == 'fstlm':
        # N-gram FST language model.
        if opt.fstlm_path:
            fstlm = NgramFstLM(opt.fstlm_path, opt.nn_char_map_file, 20)
        else:
            fstlm = None
        rnnlm = None
    else:
        rnnlm = None
        fstlm = None

    # Global fbank CMVN statistics (precomputed on the training set).
    fbank_cmvn_file = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/decode_asr_train_table3_1/decode_clean/fbank_cmvn.npy'
    if os.path.exists(fbank_cmvn_file):
        fbank_cmvn = np.load(fbank_cmvn_file)
        fbank_cmvn = torch.FloatTensor(fbank_cmvn)
    else:
        raise Exception("no found at {}".format(fbank_cmvn_file))

    torch.set_grad_enabled(False)  # pure inference, no autograd bookkeeping
    new_json = {}
    for data in recog_loader:
        utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
        name = utt_ids[0]
        print(name)

        # Enhance the noisy input, then extract CMVN-normalized fbank feats.
        enhance_outputs = enhance_model(inputs, log_inputs, input_sizes)
        enhance_feat = feat_model(enhance_outputs, fbank_cmvn)

        nbest_hyps = asr_model.recognize(enhance_feat,
                                         opt,
                                         opt.char_list,
                                         rnnlm=rnnlm,
                                         fstlm=fstlm)
        # get 1best and remove sos
        y_hat = nbest_hyps[0]['yseq'][1:]
        print(y_hat)
        y_true = targets

        # print out decoding result
        seq_hat = [opt.char_list[int(idx)] for idx in y_hat]
        seq_true = [opt.char_list[int(idx)] for idx in y_true]
        seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
        seq_true_text = "".join(seq_true).replace('<space>', ' ')
        # BUGFIX: pass the transcript as a lazy logging argument instead of
        # concatenating it into the format string — a stray '%' in the text
        # would otherwise raise a string-formatting error.
        logging.info("groundtruth[%s]: %s", name, seq_true_text)
        logging.info("prediction [%s]: %s", name, seq_hat_text)

        # copy old json info
        new_json[name] = dict()
        new_json[name]['utt2spk'] = spk_ids[0]

        # added recognition results to json
        logging.debug("dump token id")
        out_dic = dict()
        out_dic['name'] = 'target1'
        out_dic['text'] = seq_true_text
        out_dic['token'] = " ".join(seq_true)
        out_dic['tokenid'] = " ".join([str(int(idx)) for idx in y_true])
        out_dic['rec_tokenid'] = " ".join([str(int(idx)) for idx in y_hat])
        out_dic['rec_token'] = " ".join(seq_hat)
        out_dic['rec_text'] = seq_hat_text

        new_json[name]['output'] = [out_dic]

        # add n-best recognition results with scores
        # (rank renamed from `i` — the original shadowed the outer loop index)
        if opt.beam_size > 1 and len(nbest_hyps) > 1:
            for rank, hyp in enumerate(nbest_hyps):
                y_hat = hyp['yseq'][1:]
                seq_hat = [opt.char_list[int(idx)] for idx in y_hat]
                seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
                suffix = '[' + '{:05d}'.format(rank) + ']'
                new_json[name]['rec_tokenid' + suffix] = " ".join(
                    [str(idx) for idx in y_hat])
                new_json[name]['rec_token' + suffix] = " ".join(seq_hat)
                new_json[name]['rec_text' + suffix] = seq_hat_text
                new_json[name]['score' + suffix] = float(hyp['score'])

    # Write bytes explicitly to avoid platform-dependent text encoding.
    with open(opt.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_json
            }, indent=4, sort_keys=True).encode('utf_8'))
Exemplo n.º 2
0
    def __init__(self, args):
        """Build the E2E ASR network: encoder + CTC + attention + decoder.

        Args:
            args: experiment options namespace. Must provide at least
                ``fbank_dim``, ``odim``, ``etype``, ``elayers``, ``eunits``,
                ``eprojs``, ``subsample``, ``subsample_type``,
                ``dropout_rate``, the attention options (``atype``, ``adim``,
                ``aheads``, ``aconv_chans``, ``aconv_filts``, ``awin``),
                decoder options (``dlayers``, ``dunits``, ``lsm_type``,
                ``lsm_weight``, ``labeldist``), and optionally the LM-fusion
                options (``fusion``, ``lmtype``, ``rnnlm``, ...).
        """
        super(E2E, self).__init__()
        self.opt = args
        idim = args.fbank_dim
        odim = args.odim
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.mtlalpha = args.mtlalpha

        # The last output index serves as both sos and eos (identical IDs).
        self.sos = odim - 1
        self.eos = odim - 1

        # Per-layer subsampling factors: +1 slot for the input layer plus one
        # per encoder layer.
        # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int is the documented drop-in replacement.
        subsample = np.ones(args.elayers + 1, dtype=int)
        if args.etype == 'blstmp' or args.etype == 'cnnblstmp':
            ss = args.subsample.split("_")
            for j in range(min(args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # Label smoothing distribution (precomputed and passed in via args).
        if args.lsm_type:
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = args.labeldist
        else:
            labeldist = None

        # encoder
        self.enc = Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs,
                           self.subsample, args.subsample_type, args.dropout_rate)
        # ctc
        self.ctc = CTC(odim, args.eprojs, args.dropout_rate)
        # attention: pick the mechanism requested by args.atype
        if args.atype == 'noatt':
            self.att = NoAtt()
        elif args.atype == 'dot':
            self.att = AttDot(args.eprojs, args.dunits, args.adim)
        elif args.atype == 'add':
            self.att = AttAdd(args.eprojs, args.dunits, args.adim)
        elif args.atype == 'location':
            self.att = AttLoc(args.eprojs, args.dunits,
                              args.adim, args.aconv_chans, args.aconv_filts, 'softmax')
        elif args.atype == 'location2d':
            self.att = AttLoc2D(args.eprojs, args.dunits,
                                args.adim, args.awin, args.aconv_chans, args.aconv_filts)
        elif args.atype == 'location_recurrent':
            self.att = AttLocRec(args.eprojs, args.dunits,
                                 args.adim, args.aconv_chans, args.aconv_filts)
        elif args.atype == 'coverage':
            self.att = AttCov(args.eprojs, args.dunits, args.adim)
        elif args.atype == 'coverage_location':
            self.att = AttCovLoc(args.eprojs, args.dunits, args.adim,
                                 args.aconv_chans, args.aconv_filts)
        elif args.atype == 'multi_head_dot':
            self.att = AttMultiHeadDot(args.eprojs, args.dunits,
                                       args.aheads, args.adim, args.adim)
        elif args.atype == 'multi_head_add':
            self.att = AttMultiHeadAdd(args.eprojs, args.dunits,
                                       args.aheads, args.adim, args.adim)
        elif args.atype == 'multi_head_loc':
            self.att = AttMultiHeadLoc(args.eprojs, args.dunits,
                                       args.aheads, args.adim, args.adim,
                                       args.aconv_chans, args.aconv_filts)
        elif args.atype == 'multi_head_multi_res_loc':
            self.att = AttMultiHeadMultiResLoc(args.eprojs, args.dunits,
                                               args.aheads, args.adim, args.adim,
                                               args.aconv_chans, args.aconv_filts)
        else:
            logging.error(
                "Error: need to specify an appropriate attention archtecture")
            sys.exit()

        # rnnlm / LM-fusion setup. Defaults apply when fusion is disabled or
        # any of the fusion options is missing from args.
        rnnlm = None
        fusion = None
        model_unit = 'char'
        space_loss_weight = 0.0
        try:
            if args.fusion == 'deep_fusion' or args.fusion == 'cold_fusion':
                # BUGFIX: the original never assigned `fusion`, `model_unit`
                # or `space_loss_weight` on this branch, so the Decoder(...)
                # call below raised NameError whenever fusion was enabled.
                # NOTE(review): presumably args.fusion is the intended value
                # to forward — confirm; model_unit/space_loss_weight keep the
                # defaults above.
                fusion = args.fusion
                if args.lmtype == 'rnnlm' and args.rnnlm:
                    rnnlm = lm.ClassifierWithState(lm.RNNLM(len(args.char_list), 300, 650))
                    rnnlm.load_state_dict(torch.load(args.rnnlm, map_location=lambda storage, loc: storage))
                    print('load rnnlm from ', args.rnnlm)
                    rnnlm.eval()
                    # LM weights stay frozen during fusion training.
                    for p in rnnlm.parameters():
                        p.requires_grad_(False)
                elif args.lmtype == 'fsrnnlm' and args.rnnlm:
                    rnnlm = lm.ClassifierWithState(
                        fsrnn.FSRNNLM(len(args.char_list), 300, args.fast_layers,
                                      args.fast_cell_size, args.slow_cell_size,
                                      args.zoneout_keep_h, args.zoneout_keep_c))
                    rnnlm.load_state_dict(torch.load(args.rnnlm, map_location=lambda storage, loc: storage))
                    print('load rnnlm from ', args.rnnlm)
                    rnnlm.eval()
                    for p in rnnlm.parameters():
                        p.requires_grad_(False)
        except Exception:
            # Best-effort fallback: missing fusion/LM options (or a failed LM
            # load) disable fusion. Log it instead of the original silent
            # bare `except:`.
            logging.exception('LM fusion setup failed; continuing without it')
            rnnlm = None
            fusion = None
            model_unit = 'char'
            space_loss_weight = 0.0

        # decoder
        self.dec = Decoder(args.eprojs, odim, args.dlayers, args.dunits, self.sos, self.eos,
                           self.att, self.verbose, self.char_list, labeldist, args.lsm_weight,
                           fusion, rnnlm, model_unit, space_loss_weight)

        # weight initialization
        self.init_like_chainer()