def main():
    """Decode a test set with an enhancement front-end plus E2E ASR back-end.

    Pipeline per utterance: enhance_model -> fbank feature model (+CMVN) ->
    asr_model.recognize (beam search, optional RNN/FSRNN/FST language model),
    then dump all hypotheses into a kaldi/espnet-style ``{'utts': ...}`` JSON.

    NOTE(review): relies on names not defined in this file (``opt``,
    ``recog_loader``, ``lm``, ``extlm``, ``fsrnn``, ``NgramFstLM``,
    ``EnhanceModel``, ``FbankModel``, ``E2E``, ``load_labeldict`` ...) —
    presumably imported/configured by the surrounding script; verify there.
    Checkpoint, RNNLM and CMVN paths are hard-coded below and override the
    usual ``opt.resume`` lookup — TODO: make these configurable again.
    """
    # Setup a model
    model_path = None
    # Hard-coded checkpoint locations used instead of opt.resume/opt.works_dir.
    enhance_model_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/enhance_fbank_train_table_2/model.loss.best'
    #enhance_model_path = '/usr/home/wudamu/Desktop/other_data/model.loss.best.base'
    #asr_mode_path = '/usr/home/wudamu/Desktop/other_data/model.acc.best'
    asr_mode_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/asr_clean_train_table3/model.acc.best'
    # The fbank (feature extraction) weights are taken from the ASR checkpoint.
    feat_model_path = asr_mode_path
    #if opt.resume:
        #model_path = os.path.join(opt.works_dir, opt.resume)
        #if os.path.isfile(model_path):
            #package = torch.load(model_path, map_location=lambda storage, loc: storage)
            #enhance_model = EnhanceModel.load_model(model_path, 'enhance_state_dict', opt)
            #feat_model = FbankModel.load_model(model_path, 'fbank_state_dict', opt)
            #asr_model = E2E.load_model(model_path, 'asr_state_dict', opt)
        #else:
            #raise Exception("no checkpoint found at {}".format(opt.resume))
    #else:
        #raise Exception("no checkpoint found at {}".format(opt.resume))
    # opt.resume is only used as a flag here; the actual paths are the
    # hard-coded ones above — NOTE(review): confirm this is intentional.
    if opt.resume:
        #model_path = os.path.join(opt.works_dir, opt.resume)
        #package = torch.load(model_path, map_location=lambda storage, loc: storage)
        enhance_model = EnhanceModel.load_model(enhance_model_path, 'enhance_state_dict', opt)
        feat_model = FbankModel.load_model(feat_model_path, 'fbank_state_dict', opt)
        asr_model = E2E.load_model(asr_mode_path, 'asr_state_dict', opt)
    else:
        raise Exception("no checkpoint found at {}".format(opt.resume))

    def cpu_loader(storage, location):
        # torch.load map_location hook: always deserialize onto CPU.
        return storage

    if opt.lmtype == 'rnnlm':
        # read rnnlm
        if opt.rnnlm:
            rnnlm = lm.ClassifierWithState(
                #lm.RNNLM(len(opt.char_list), 650, 650))
                lm.RNNLM(len(opt.char_list), 300, 650))
            # Hard-coded override of the RNNLM checkpoint path (clobbers opt.rnnlm).
            opt.rnnlm = "/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/rnnlm_train_shi/rnnlm.model.best"
            rnnlm.load_state_dict(
                torch.load(opt.rnnlm, map_location=cpu_loader))
            if len(opt.gpu_ids) > 0:
                rnnlm = rnnlm.cuda()
            print('load RNNLM from {}'.format(opt.rnnlm))
            rnnlm.eval()
        else:
            rnnlm = None
        if opt.word_rnnlm:
            # Optional word-level LM, combined with the char LM when both exist.
            if not opt.word_dict:
                logging.error(
                    'word dictionary file is not specified for the word RNNLM.')
                sys.exit(1)
            word_dict = load_labeldict(opt.word_dict)
            char_dict = {x: i for i, x in enumerate(opt.char_list)}
            word_rnnlm = lm.ClassifierWithState(lm.RNNLM(len(word_dict), 650))
            word_rnnlm.load_state_dict(
                torch.load(opt.word_rnnlm, map_location=cpu_loader))
            word_rnnlm.eval()
            if rnnlm is not None:
                # Word LM rescores on word boundaries, char LM in between.
                rnnlm = lm.ClassifierWithState(
                    extlm.MultiLevelLM(word_rnnlm.predictor,
                                       rnnlm.predictor, word_dict, char_dict))
            else:
                # Word LM only, with look-ahead on partial words.
                rnnlm = lm.ClassifierWithState(
                    extlm.LookAheadWordLM(word_rnnlm.predictor,
                                          word_dict, char_dict))
        fstlm = None
    elif opt.lmtype == 'fsrnnlm':
        # Fast-slow RNNLM variant.
        if opt.rnnlm:
            rnnlm = lm.ClassifierWithState(
                fsrnn.FSRNNLM(len(opt.char_list), 300, opt.fast_layers,
                              opt.fast_cell_size, opt.slow_cell_size,
                              opt.zoneout_keep_h, opt.zoneout_keep_c))
            rnnlm.load_state_dict(
                torch.load(opt.rnnlm, map_location=cpu_loader))
            if len(opt.gpu_ids) > 0:
                rnnlm = rnnlm.cuda()
            print('load fsrnn from {}'.format(opt.rnnlm))
            rnnlm.eval()
        else:
            rnnlm = None
            print('not load fsrnn from {}'.format(opt.rnnlm))
        fstlm = None
    elif opt.lmtype == 'fstlm':
        # N-gram LM compiled as an FST; 20 presumably is a cache/order knob —
        # NOTE(review): confirm against NgramFstLM's signature.
        if opt.fstlm_path:
            fstlm = NgramFstLM(opt.fstlm_path, opt.nn_char_map_file, 20)
        else:
            fstlm = None
        rnnlm = None
    else:
        # No LM fusion during decoding.
        rnnlm = None
        fstlm = None

    #fbank_cmvn_file = os.path.join(opt.exp_path, 'fbank_cmvn.npy')
    # Hard-coded CMVN statistics for the fbank features.
    fbank_cmvn_file = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/decode_asr_train_table3_1/decode_clean/fbank_cmvn.npy'
    if os.path.exists(fbank_cmvn_file):
        fbank_cmvn = np.load(fbank_cmvn_file)
        fbank_cmvn = torch.FloatTensor(fbank_cmvn)
    else:
        raise Exception("no found at {}".format(fbank_cmvn_file))
    #enhance_cmvn_file ='/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/result_enhance_fbank/enhance_cmvn.npy'
    #enhance_cmvn = np.load(enhance_cmvn_file)
    #enhance_cmvn = torch.FloatTensor(enhance_cmvn)

    # Inference only: disable autograd globally for the decode loop.
    torch.set_grad_enabled(False)
    new_json = {}
    for i, (data) in enumerate(recog_loader, start=0):
        # Batch size is effectively 1 here: utt_ids[0] is used as the key.
        utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
        #utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
        name = utt_ids[0]
        #ss = torch.max(inputs)
        print(name)
        # Enhancement front-end, then fbank features normalized with CMVN.
        enhance_outputs = enhance_model(inputs, log_inputs, input_sizes)
        #print(enhance_outputs)
        #aa = torch.max(enhance_outputs)
        #feats = feat_model(enhance_outputs, fbank_cmvn)
        enhance_feat = feat_model(enhance_outputs, fbank_cmvn)
        # Beam search over the enhanced features.
        nbest_hyps = asr_model.recognize(enhance_feat, opt, opt.char_list,
                                         rnnlm=rnnlm, fstlm=fstlm)
        #nbest_hyps = asr_model.recognize(enhance_outputs, opt, opt.char_list, rnnlm=rnnlm, fstlm=fstlm)
        # get 1best and remove sos
        y_hat = nbest_hyps[0]['yseq'][1:]
        print(y_hat)
        ##y_true = map(int, targets[0].split())
        y_true = targets
        # print out decoding result
        seq_hat = [opt.char_list[int(idx)] for idx in y_hat]
        seq_true = [opt.char_list[int(idx)] for idx in y_true]
        seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
        seq_true_text = "".join(seq_true).replace('<space>', ' ')
        logging.info("groundtruth[%s]: " + seq_true_text, name)
        logging.info("prediction [%s]: " + seq_hat_text, name)
        # copy old json info
        new_json[name] = dict()
        new_json[name]['utt2spk'] = spk_ids[0]
        # added recognition results to json
        logging.debug("dump token id")
        out_dic = dict()
        out_dic['name'] = 'target1'
        out_dic['text'] = seq_true_text
        out_dic['token'] = " ".join(seq_true)
        out_dic['tokenid'] = " ".join([str(int(idx)) for idx in y_true])
        # TODO(karita) make consistent to chainer as idx[0] not idx
        out_dic['rec_tokenid'] = " ".join([str(int(idx)) for idx in y_hat])
        #logger.debug("dump token")
        out_dic['rec_token'] = " ".join(seq_hat)
        #logger.debug("dump text")
        out_dic['rec_text'] = seq_hat_text
        new_json[name]['output'] = [out_dic]
        # TODO(nelson): Modify this part when saving more than 1 hyp is enabled
        # add n-best recognition results with scores
        if opt.beam_size > 1 and len(nbest_hyps) > 1:
            # NOTE(review): inner loop reuses the outer loop variable ``i``;
            # harmless here (outer ``i`` is re-bound each iteration) but confusing.
            for i, hyp in enumerate(nbest_hyps):
                y_hat = hyp['yseq'][1:]
                seq_hat = [opt.char_list[int(idx)] for idx in y_hat]
                seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
                new_json[name]['rec_tokenid' + '[' + '{:05d}'.format(i) + ']'] = " ".join([str(idx) for idx in y_hat])
                new_json[name]['rec_token' + '[' + '{:05d}'.format(i) + ']'] = " ".join(seq_hat)
                new_json[name]['rec_text' + '[' + '{:05d}'.format(i) + ']'] = seq_hat_text
                new_json[name]['score' + '[' + '{:05d}'.format(i) + ']'] = float(hyp['score'])

    # TODO(watanabe) fix character coding problems when saving it
    # Written once after the loop, in binary mode with explicit UTF-8 encode.
    with open(opt.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_json
            }, indent=4, sort_keys=True).encode('utf_8'))
def __init__(self, args):
    """Build the E2E encoder/CTC/attention/decoder stack from ``args``.

    Args:
        args: option namespace; reads fbank_dim, odim, etype, elayers,
            eunits, eprojs, subsample, subsample_type, dropout_rate,
            atype/adim/aheads/aconv_*, dlayers/dunits, lsm_type/lsm_weight,
            labeldist, mtlalpha, verbose, char_list, and (optionally)
            fusion/rnnlm/lmtype for LM fusion.

    Raises:
        SystemExit: if ``args.atype`` names an unknown attention type.
    """
    super(E2E, self).__init__()
    self.opt = args
    idim = args.fbank_dim
    odim = args.odim
    self.etype = args.etype
    self.verbose = args.verbose
    self.char_list = args.char_list
    ##self.outdir = args.outdir
    self.mtlalpha = args.mtlalpha

    # below means the last number becomes eos/sos ID
    # note that sos/eos IDs are identical
    self.sos = odim - 1
    self.eos = odim - 1

    # subsample info
    # +1 means input (+1) and layers outputs (args.elayers)
    # FIX: np.int was removed in NumPy 1.24 — use the builtin int dtype.
    subsample = np.ones(args.elayers + 1, dtype=int)
    if args.etype == 'blstmp' or args.etype == 'cnnblstmp':
        ss = args.subsample.split("_")
        for j in range(min(args.elayers + 1, len(ss))):
            subsample[j] = int(ss[j])
    else:
        logging.warning(
            'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
    logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
    self.subsample = subsample

    # label smoothing info
    if args.lsm_type:
        logging.info("Use label smoothing with " + args.lsm_type)
        ##labeldist = label_smoothing_dist(odim, args.lsm_type, args.char_list, transcript=args.train_text)
        labeldist = args.labeldist
    else:
        labeldist = None

    # encoder
    self.enc = Encoder(args.etype, idim, args.elayers, args.eunits,
                       args.eprojs, self.subsample, args.subsample_type,
                       args.dropout_rate)
    # ctc
    self.ctc = CTC(odim, args.eprojs, args.dropout_rate)
    # attention
    if args.atype == 'noatt':
        self.att = NoAtt()
    elif args.atype == 'dot':
        self.att = AttDot(args.eprojs, args.dunits, args.adim)
    elif args.atype == 'add':
        self.att = AttAdd(args.eprojs, args.dunits, args.adim)
    elif args.atype == 'location':
        self.att = AttLoc(args.eprojs, args.dunits, args.adim,
                          args.aconv_chans, args.aconv_filts, 'softmax')
    elif args.atype == 'location2d':
        self.att = AttLoc2D(args.eprojs, args.dunits, args.adim,
                            args.awin, args.aconv_chans, args.aconv_filts)
    elif args.atype == 'location_recurrent':
        self.att = AttLocRec(args.eprojs, args.dunits, args.adim,
                             args.aconv_chans, args.aconv_filts)
    elif args.atype == 'coverage':
        self.att = AttCov(args.eprojs, args.dunits, args.adim)
    elif args.atype == 'coverage_location':
        self.att = AttCovLoc(args.eprojs, args.dunits, args.adim,
                             args.aconv_chans, args.aconv_filts)
    elif args.atype == 'multi_head_dot':
        self.att = AttMultiHeadDot(args.eprojs, args.dunits,
                                   args.aheads, args.adim, args.adim)
    elif args.atype == 'multi_head_add':
        self.att = AttMultiHeadAdd(args.eprojs, args.dunits,
                                   args.aheads, args.adim, args.adim)
    elif args.atype == 'multi_head_loc':
        self.att = AttMultiHeadLoc(args.eprojs, args.dunits,
                                   args.aheads, args.adim, args.adim,
                                   args.aconv_chans, args.aconv_filts)
    elif args.atype == 'multi_head_multi_res_loc':
        self.att = AttMultiHeadMultiResLoc(args.eprojs, args.dunits,
                                           args.aheads, args.adim, args.adim,
                                           args.aconv_chans, args.aconv_filts)
    else:
        logging.error(
            "Error: need to specify an appropriate attention architecture")
        # FIX: exit with a non-zero status on this error path (was sys.exit()).
        sys.exit(1)

    # rnnlm (optional LM fusion).
    # FIX: the original left ``fusion``/``model_unit``/``space_loss_weight``
    # unassigned on the deep_fusion/cold_fusion path, which raised NameError
    # at the Decoder(...) call below. Initialize every name up front so all
    # paths are well-defined; defaults match the original non-fusion branch.
    rnnlm = None
    fusion = None
    model_unit = 'char'
    space_loss_weight = 0.0
    try:
        if args.fusion == 'deep_fusion' or args.fusion == 'cold_fusion':
            # NOTE(review): pass the fusion mode through to the decoder;
            # presumably what the unassigned variable was meant to hold.
            fusion = args.fusion
            if args.lmtype == 'rnnlm' and args.rnnlm:
                rnnlm = lm.ClassifierWithState(
                    lm.RNNLM(len(args.char_list), 300, 650))
                rnnlm.load_state_dict(torch.load(
                    args.rnnlm, map_location=lambda storage, loc: storage))
                print('load rnnlm from ', args.rnnlm)
                rnnlm.eval()
                # LM is frozen: used for fusion, not trained jointly.
                for p in rnnlm.parameters():
                    p.requires_grad_(False)
            elif args.lmtype == 'fsrnnlm' and args.rnnlm:
                rnnlm = lm.ClassifierWithState(
                    fsrnn.FSRNNLM(len(args.char_list), 300, args.fast_layers,
                                  args.fast_cell_size, args.slow_cell_size,
                                  args.zoneout_keep_h, args.zoneout_keep_c))
                rnnlm.load_state_dict(torch.load(
                    args.rnnlm, map_location=lambda storage, loc: storage))
                print('load rnnlm from ', args.rnnlm)
                rnnlm.eval()
                for p in rnnlm.parameters():
                    p.requires_grad_(False)
    # FIX: was a bare ``except:`` which also swallowed KeyboardInterrupt/
    # SystemExit; keep the best-effort fallback but log what went wrong.
    except Exception:
        logging.warning('LM fusion setup failed; continuing without fusion',
                        exc_info=True)
        rnnlm = None
        fusion = None
        model_unit = 'char'
        space_loss_weight = 0.0

    # decoder
    self.dec = Decoder(args.eprojs, odim, args.dlayers, args.dunits,
                       self.sos, self.eos, self.att, self.verbose,
                       self.char_list, labeldist, args.lsm_weight,
                       fusion, rnnlm, model_unit, space_loss_weight)
    # weight initialization
    self.init_like_chainer()