def forward(self, state, x): """Forward neural networks.""" if state is None: h = [ to_device(x, self.zero_state(x.size(0))) for n in range(self.n_layers) ] state = {"h": h} if self.typ == "lstm": c = [ to_device(x, self.zero_state(x.size(0))) for n in range(self.n_layers) ] state = {"c": c, "h": h} h = [None] * self.n_layers emb = self.embed(x) if self.typ == "lstm": c = [None] * self.n_layers h[0], c[0] = self.rnn[0](self.dropout[0](emb), (state["h"][0], state["c"][0])) for n in range(1, self.n_layers): h[n], c[n] = self.rnn[n](self.dropout[n](h[n - 1]), (state["h"][n], state["c"][n])) state = {"c": c, "h": h} else: h[0] = self.rnn[0](self.dropout[0](emb), state["h"][0]) for n in range(1, self.n_layers): h[n] = self.rnn[n](self.dropout[n](h[n - 1]), state["h"][n]) state = {"h": h} y = self.lo(self.dropout[-1](h[-1])) return state, y
def init_state(self, init_tensor):
    """Initialize decoder states.

    Args:
        init_tensor (torch.Tensor): batch of input features (B, emb_dim / dec_dim)

    Returns:
        (tuple): batch of decoder states ([L x (B, dec_dim)], [L x (B, dec_dim)])

    """
    dtype = init_tensor.dtype
    z_list = [
        to_device(init_tensor, torch.zeros(init_tensor.size(0), self.dunits)).to(dtype)
        for _ in range(self.dlayers)
    ]
    c_list = [
        to_device(init_tensor, torch.zeros(init_tensor.size(0), self.dunits)).to(dtype)
        for _ in range(self.dlayers)
    ]

    return (z_list, c_list)
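# Hedged sketch of what init_state produces, assuming dlayers=2 and dunits=4
# (illustrative values); the zero states inherit the device and dtype of the
# reference tensor, which is what to_device plus .to(dtype) achieves above.
import torch

dlayers, dunits = 2, 4
init_tensor = torch.randn(3, 8, dtype=torch.float16)  # (B, emb_dim)
z_list = [
    torch.zeros(
        init_tensor.size(0), dunits, device=init_tensor.device, dtype=init_tensor.dtype
    )
    for _ in range(dlayers)
]
c_list = [torch.zeros_like(z_list[0]) for _ in range(dlayers)]
# -> ([2 x (3, 4)], [2 x (3, 4)]) zero states matching the input tensor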
def forward(self, state, x):
    # update state with input label x
    if state is None:  # make initial states and log-prob vectors
        self.var_word_eos = to_device(x, self.var_word_eos)
        self.var_word_unk = to_device(x, self.var_word_unk)
        wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
        wlm_logprobs = F.log_softmax(z_wlm, dim=1)
        clm_state, z_clm = self.subwordlm(None, x)
        log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
        new_node = self.lexroot
        clm_logprob = 0.0
        xi = self.space
    else:
        clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
        xi = int(x)
        if xi == self.space:  # inter-word transition
            if node is not None and node[1] >= 0:  # check if the node is word end
                w = to_device(x, torch.LongTensor([node[1]]))
            else:  # this node is not a word end, which means <unk>
                w = self.var_word_unk
            # update wordlm state and log-prob vector
            wlm_state, z_wlm = self.wordlm(wlm_state, w)
            wlm_logprobs = F.log_softmax(z_wlm, dim=1)
            new_node = self.lexroot  # move to the tree root
            clm_logprob = 0.0
        elif node is not None and xi in node[0]:  # intra-word transition
            new_node = node[0][xi]
            clm_logprob += log_y[0, xi]
        elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
            new_node = None
            clm_logprob += log_y[0, xi]
        else:  # if open_vocab flag is disabled, return zero probabilities
            log_y = to_device(
                x, torch.full((1, self.subword_dict_size), self.logzero)
            )
            return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y

        clm_state, z_clm = self.subwordlm(clm_state, x)
        log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight

    # apply word-level probabilities for <space> and <eos> labels
    if xi != self.space:
        if new_node is not None and new_node[1] >= 0:  # if new node is word end
            wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
        else:
            wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
        log_y[:, self.space] = wlm_logprob
        log_y[:, self.eos] = wlm_logprob
    else:
        log_y[:, self.space] = self.logzero
        log_y[:, self.eos] = self.logzero

    return (
        (clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)),
        log_y,
    )
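# Hedged sketch of the lexicon-tree walk assumed above: each node holds
# (successors, word_id, ...), where successors maps a subword id to a child
# node and word_id >= 0 marks a word end. The ids below are illustrative,
# not taken from a real dictionary.
lexroot = ({3: ({5: ({}, 7)}, -1)}, -1)  # subwords 3, 5 spell word id 7

node = lexroot
for xi in [3, 5]:  # intra-word transitions
    node = node[0][xi]
assert node[1] == 7  # word end reached -> rescore with the word-level LM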
def score(self, hyp, cache, init_tensor=None):
    """Forward one step.

    Args:
        hyp (dataclass): hypothesis
        cache (dict): states cache

    Returns:
        y (torch.Tensor): decoder outputs (1, dec_dim)
        state (tuple): decoder states ([L x (1, dec_dim)], [L x (1, dec_dim)])
        lm_tokens (torch.Tensor): token id for LM (1)

    """
    vy = to_device(self, torch.full((1, 1), hyp.yseq[-1], dtype=torch.long))

    str_yseq = "".join([str(x) for x in hyp.yseq])

    if str_yseq in cache:
        y, state = cache[str_yseq]
    else:
        ey = self.embed(vy)
        y, state = self.rnn_forward(ey[0], hyp.dec_state)
        cache[str_yseq] = (y, state)

    return y, state, vy[0]
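# Minimal sketch of the prefix cache used by score(): one-step decoder
# outputs are memoized under the stringified label sequence, so repeated
# expansions of the same prefix skip the forward pass. The placeholders
# below stand in for real tensors and states.
cache = {}
yseq = [0, 4, 9]
key = "".join(str(t) for t in yseq)
if key not in cache:
    cache[key] = ("y", "dec_state")  # would be the rnn_forward(...) output
y, state = cache[key]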
def final(self, state):
    wlm_state, cumsum_probs, node = state
    if node is not None and node[1] >= 0:  # check if the node is word end
        w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
    else:  # this node is not a word end, which means <unk>
        w = self.var_word_unk
    wlm_state, z_wlm = self.wordlm(wlm_state, w)
    return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
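# Self-contained sketch of the final <eos> scoring above, assuming a toy
# word-LM output z_wlm of shape (1, vocab); word_eos=2 is illustrative.
import torch
import torch.nn.functional as F

z_wlm = torch.randn(1, 5)
word_eos = 2
final_score = float(F.log_softmax(z_wlm, dim=1)[:, word_eos])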
def score(self, hyp, cache, init_tensor=None):
    """Forward one step.

    Args:
        hyp (dataclass): hypothesis
        cache (dict): states cache

    Returns:
        y (torch.Tensor): decoder outputs (1, dec_dim)
        new_state (list): decoder and attention states [L x (1, max_len, dec_dim)]
        lm_tokens (torch.Tensor): token id for LM (1)

    """
    tgt = to_device(self, torch.tensor(hyp.yseq).unsqueeze(0))
    lm_tokens = tgt[:, -1]

    str_yseq = "".join([str(x) for x in hyp.yseq])

    if str_yseq in cache:
        y, new_state = cache[str_yseq]
    else:
        tgt_mask = to_device(self, subsequent_mask(len(hyp.yseq)).unsqueeze(0))

        state = check_state(hyp.dec_state, (tgt.size(1) - 1), self.blank)

        tgt = self.embed(tgt)

        new_state = []
        for s, decoder in zip(state, self.decoders):
            tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s)
            new_state.append(tgt)

        y = self.after_norm(tgt[:, -1])

        cache[str_yseq] = (y, new_state)

    return y, new_state, lm_tokens
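# Sketch of the causal mask built above, using a plain lower-triangular
# stand-in for espnet's subsequent_mask; the length is illustrative.
import torch

length = 4
tgt_mask = torch.tril(torch.ones(length, length, dtype=torch.bool)).unsqueeze(0)
# position i may attend to positions 0..i only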
def forward(self, xs_pad, ilens, prev_states=None):
    """Encoder forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param list prev_states: batch of previous encoder hidden states (?, ...)
    :return: batch of hidden state sequences (B, Tmax, eprojs)
    :rtype: torch.Tensor
    """
    if prev_states is None:
        prev_states = [None] * len(self.enc)
    assert len(prev_states) == len(self.enc)

    current_states = []
    for module, prev_state in zip(self.enc, prev_states):
        xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
        current_states.append(states)

    # make mask to remove bias value in padded part
    mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1))

    return xs_pad.masked_fill(mask, 0.0), ilens, current_states
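# Self-contained sketch of the padding mask applied above (a stand-in for
# espnet's make_pad_mask; shapes and lengths are illustrative).
import torch

xs_pad = torch.randn(2, 4, 3)  # (B, Tmax, D)
ilens = torch.tensor([4, 2])  # valid lengths per batch entry
mask = torch.arange(xs_pad.size(1))[None, :] >= ilens[:, None]  # (B, Tmax)
xs_pad = xs_pad.masked_fill(mask.unsqueeze(-1), 0.0)  # zero out padded frames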
def batch_score(self, hyps, batch_states, cache, init_tensor=None):
    """Forward batch one step.

    Args:
        hyps (list): batch of hypotheses
        batch_states (list): decoder states [L x (B, max_len, dec_dim)]
        cache (dict): states cache

    Returns:
        batch_y (torch.Tensor): decoder output (B, dec_dim)
        batch_states (list): decoder states [L x (B, max_len, dec_dim)]
        lm_tokens (torch.Tensor): batch of token ids for LM (B)

    """
    final_batch = len(hyps)

    tokens = []
    process = []
    done = [None for _ in range(final_batch)]

    for i, hyp in enumerate(hyps):
        str_yseq = "".join([str(x) for x in hyp.yseq])

        if str_yseq in cache:
            done[i] = (*cache[str_yseq], hyp.yseq)
        else:
            tokens.append(hyp.yseq)
            process.append((str_yseq, hyp.dec_state, hyp.yseq))

    if process:
        batch = len(tokens)

        tokens = pad_sequence(tokens, self.blank)
        b_tokens = to_device(self, torch.LongTensor(tokens).view(batch, -1))

        tgt_mask = to_device(
            self,
            subsequent_mask(b_tokens.size(-1)).unsqueeze(0).expand(batch, -1, -1),
        )

        dec_state = self.init_state()
        dec_state = self.create_batch_states(
            dec_state,
            [p[1] for p in process],
            tokens,
        )

        tgt = self.embed(b_tokens)

        next_state = []
        for s, decoder in zip(dec_state, self.decoders):
            tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s)
            next_state.append(tgt)

        tgt = self.after_norm(tgt[:, -1])

    j = 0
    for i in range(final_batch):
        if done[i] is None:
            new_state = self.select_state(next_state, j)

            done[i] = (tgt[j], new_state, process[j][2])
            cache[process[j][0]] = (tgt[j], new_state)

            j += 1

    batch_states = self.create_batch_states(
        batch_states, [d[1] for d in done], [d[2] for d in done]
    )
    batch_y = torch.stack([d[0] for d in done])

    lm_tokens = to_device(
        self, torch.LongTensor([h.yseq[-1] for h in hyps]).view(final_batch)
    )

    return batch_y, batch_states, lm_tokens
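# Sketch of the blank-padding applied above before batching variable-length
# prefixes (a plain stand-in for the toolkit's pad_sequence helper; blank=0
# and the token ids are illustrative).
blank = 0
tokens = [[1, 2, 3], [4, 5]]
max_len = max(len(t) for t in tokens)
padded = [t + [blank] * (max_len - len(t)) for t in tokens]
# -> [[1, 2, 3], [4, 5, 0]], ready for torch.LongTensor(...).view(batch, -1)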
def forward(self, state, x):
    # update state with input label x
    if state is None:  # make initial states and cumulative probability vector
        self.var_word_eos = to_device(x, self.var_word_eos)
        self.var_word_unk = to_device(x, self.var_word_unk)
        self.zero_tensor = to_device(x, self.zero_tensor)
        wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
        cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
        new_node = self.lexroot
        xi = self.space
    else:
        wlm_state, cumsum_probs, node = state
        xi = int(x)
        if xi == self.space:  # inter-word transition
            if node is not None and node[1] >= 0:  # check if the node is word end
                w = to_device(x, torch.LongTensor([node[1]]))
            else:  # this node is not a word end, which means <unk>
                w = self.var_word_unk
            # update wordlm state and cumulative probability vector
            wlm_state, z_wlm = self.wordlm(wlm_state, w)
            cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
            new_node = self.lexroot  # move to the tree root
        elif node is not None and xi in node[0]:  # intra-word transition
            new_node = node[0][xi]
        elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
            new_node = None
        else:  # if open_vocab flag is disabled, return zero probabilities
            log_y = to_device(
                x, torch.full((1, self.subword_dict_size), self.logzero)
            )
            return (wlm_state, None, None), log_y

    if new_node is not None:
        succ, wid, wids = new_node
        # compute parent node probability
        sum_prob = (
            (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
            if wids is not None
            else 1.0
        )
        if sum_prob < self.zero:
            log_y = to_device(
                x, torch.full((1, self.subword_dict_size), self.logzero)
            )
            return (wlm_state, cumsum_probs, new_node), log_y

        # set <unk> probability as a default value
        unk_prob = (
            cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
        )
        y = to_device(
            x,
            torch.full(
                (1, self.subword_dict_size), float(unk_prob) * self.oov_penalty
            ),
        )

        # compute transition probabilities to child nodes
        for cid, nd in succ.items():
            y[:, cid] = (
                cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
            ) / sum_prob

        # apply word-level probabilities for <space> and <eos> labels
        if wid >= 0:
            wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
            y[:, self.space] = wlm_prob
            y[:, self.eos] = wlm_prob
        elif xi == self.space:
            y[:, self.space] = self.zero
            y[:, self.eos] = self.zero

        log_y = torch.log(torch.max(y, self.zero_tensor))  # clip to avoid log(0)
    else:  # if no path in the tree, transition probability is one
        log_y = to_device(x, torch.zeros(1, self.subword_dict_size))

    return (wlm_state, cumsum_probs, new_node), log_y
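# Worked sketch of the cumulative-sum lookahead above: the probability mass
# of a contiguous range of word ids under the word LM is read off as a
# difference of cumulative sums (toy distribution; the indices below are
# illustrative).
import torch
import torch.nn.functional as F

z_wlm = torch.randn(1, 6)  # toy word-LM logits (1, word_vocab)
cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
lo, hi = 1, 4
range_prob = cumsum_probs[:, hi] - cumsum_probs[:, lo]  # sum over ids lo+1..hi
single_prob = cumsum_probs[:, 3] - cumsum_probs[:, 2]  # probability of word id 3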
def batch_score(self, hyps, batch_states, cache, init_tensor=None):
    """Forward batch one step.

    Args:
        hyps (list): batch of hypotheses
        batch_states (tuple): batch of decoder states
            ([L x (B, dec_dim)], [L x (B, dec_dim)])
        cache (dict): states cache

    Returns:
        batch_y (torch.Tensor): decoder output (B, dec_dim)
        batch_states (tuple): batch of decoder states
            ([L x (B, dec_dim)], [L x (B, dec_dim)])
        lm_tokens (torch.Tensor): batch of token ids for LM (B)

    """
    final_batch = len(hyps)

    tokens = []
    process = []
    done = [None for _ in range(final_batch)]

    for i, hyp in enumerate(hyps):
        str_yseq = "".join([str(x) for x in hyp.yseq])

        if str_yseq in cache:
            done[i] = cache[str_yseq]
        else:
            tokens.append(hyp.yseq[-1])
            process.append((str_yseq, hyp.dec_state))

    if process:
        batch = len(process)

        tokens = to_device(self, torch.LongTensor(tokens).view(batch))

        dec_state = self.init_state(torch.zeros((batch, self.dunits)))
        dec_state = self.create_batch_states(dec_state, [p[1] for p in process])

        ey = self.embed(tokens)

        y, dec_state = self.rnn_forward(ey, dec_state)

    j = 0
    for i in range(final_batch):
        if done[i] is None:
            new_state = self.select_state(dec_state, j)

            done[i] = (y[j], new_state)
            cache[process[j][0]] = (y[j], new_state)

            j += 1

    batch_states = self.create_batch_states(batch_states, [d[1] for d in done])
    batch_y = torch.stack([d[0] for d in done])

    lm_tokens = to_device(
        self, torch.LongTensor([h.yseq[-1] for h in hyps]).view(final_batch)
    )

    return batch_y, batch_states, lm_tokens
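# Sketch of the cache/process bookkeeping above: cached hypotheses fill
# `done` directly, the rest are batched through one decoder step, and the
# results are written back in the original hypothesis order. Strings below
# are toy stand-ins for real tensors and states.
cache = {"12": ("y_cached", "state_cached")}
hyps = [[1, 2], [1, 3]]  # label sequences only, for brevity

done, process = [None] * len(hyps), []
for i, yseq in enumerate(hyps):
    key = "".join(str(t) for t in yseq)
    if key in cache:
        done[i] = cache[key]
    else:
        process.append((i, key))

for j, (i, key) in enumerate(process):  # j indexes the batched forward pass
    done[i] = ("y[%d]" % j, "state[%d]" % j)
    cache[key] = done[i]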