def forward(self, state, x):
        """Forward neural networks."""
        if state is None:
            h = [
                to_device(x, self.zero_state(x.size(0)))
                for n in range(self.n_layers)
            ]
            state = {"h": h}
            if self.typ == "lstm":
                c = [
                    to_device(x, self.zero_state(x.size(0)))
                    for n in range(self.n_layers)
                ]
                state = {"c": c, "h": h}

        h = [None] * self.n_layers
        emb = self.embed(x)
        if self.typ == "lstm":
            c = [None] * self.n_layers
            h[0], c[0] = self.rnn[0](self.dropout[0](emb),
                                     (state["h"][0], state["c"][0]))
            for n in range(1, self.n_layers):
                h[n], c[n] = self.rnn[n](self.dropout[n](h[n - 1]),
                                         (state["h"][n], state["c"][n]))
            state = {"c": c, "h": h}
        else:
            h[0] = self.rnn[0](self.dropout[0](emb), state["h"][0])
            for n in range(1, self.n_layers):
                h[n] = self.rnn[n](self.dropout[n](h[n - 1]), state["h"][n])
            state = {"h": h}
        y = self.lo(self.dropout[-1](h[-1]))
        return state, y
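A minimal usage sketch for the stateful forward above (hedged: `lm` stands in for an instance of this model and the token ids are made up; neither comes from the source):

import torch
import torch.nn.functional as F

state = None                           # first call builds the zero states
for token_id in [1, 5, 3]:             # hypothetical token ids
    x = torch.LongTensor([token_id])   # batch of size 1
    state, y = lm(state, x)            # y: (1, n_vocab) logits
    logp = F.log_softmax(y, dim=1)     # next-token log-probabilities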
Example #2
    def init_state(self, init_tensor):
        """Initialize decoder states.

        Args:
            init_tensor (torch.Tensor): batch of input features
                (B, emb_dim / dec_dim)

        Returns:
            (tuple): batch of decoder states
                ([L x (B, dec_dim)], [L x (B, dec_dim)])

        """
        dtype = init_tensor.dtype
        z_list = [
            to_device(init_tensor, torch.zeros(init_tensor.size(0), self.dunits)).to(
                dtype
            )
            for _ in range(self.dlayers)
        ]
        c_list = [
            to_device(init_tensor, torch.zeros(init_tensor.size(0), self.dunits)).to(
                dtype
            )
            for _ in range(self.dlayers)
        ]

        return (z_list, c_list)
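A short sketch of requesting these initial states before step-by-step decoding (assumption: `decoder` is an instance of the class above; only the size, dtype, and device of the input tensor are used):

import torch

B = 4                                          # hypothetical batch size
init_tensor = torch.zeros(B, decoder.dunits)   # placeholder input features
z_list, c_list = decoder.init_state(init_tensor)
assert len(z_list) == decoder.dlayers
assert z_list[0].shape == (B, decoder.dunits)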
    def forward(self, state, x):
        # update state with input label x
        if state is None:  # make initial states and log-prob vectors
            self.var_word_eos = to_device(x, self.var_word_eos)
            self.var_word_unk = to_device(x, self.var_word_unk)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            wlm_logprobs = F.log_softmax(z_wlm, dim=1)
            clm_state, z_clm = self.subwordlm(None, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
            new_node = self.lexroot
            clm_logprob = 0.0
            xi = self.space
        else:
            clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                # check if the node is a word end
                if node is not None and node[1] >= 0:
                    w = to_device(x, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and log-prob vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                wlm_logprobs = F.log_softmax(z_wlm, dim=1)
                new_node = self.lexroot  # move to the tree root
                clm_logprob = 0.0
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
                clm_logprob += log_y[0, xi]
            elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
                new_node = None
                clm_logprob += log_y[0, xi]
            else:  # if open_vocab flag is disabled, return 0 probabilities
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero))
                return (clm_state, wlm_state, wlm_logprobs, None, log_y,
                        0.0), log_y

            clm_state, z_clm = self.subwordlm(clm_state, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight

        # apply word-level probabilities for <space> and <eos> labels
        if xi != self.space:
            # if the new node is a word end
            if new_node is not None and new_node[1] >= 0:
                wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
            else:
                wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
            log_y[:, self.space] = wlm_logprob
            log_y[:, self.eos] = wlm_logprob
        else:
            log_y[:, self.space] = self.logzero
            log_y[:, self.eos] = self.logzero

        return (
            (clm_state, wlm_state, wlm_logprobs, new_node, log_y,
             float(clm_logprob)),
            log_y,
        )
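The forward above walks a lexical prefix tree whose nodes it reads as node[0] (successor dict keyed by subword id) and node[1] (word id, negative unless the node ends a word). A toy tree with made-up ids, only to illustrate the layout this code expects:

# word id 7 is spelled by subword ids 1 then 2
leaf = ({}, 7)               # word end: node[1] >= 0
inner = ({2: leaf}, -1)      # subword id 2 continues toward the word end
lexroot = ({1: inner}, -1)   # subword id 1 leaves the root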
Example #4
    def score(self, hyp, cache, init_tensor=None):
        """Forward one step.

        Args:
            hyp (dataclass): hypothesis
            cache (dict): states cache

        Returns:
            y (torch.Tensor): decoder outputs (1, dec_dim)
            state (tuple): decoder states
                ([L x (1, dec_dim)], [L x (1, dec_dim)])
            lm_tokens (torch.Tensor): token id for LM (1)

        """
        vy = to_device(self, torch.full((1, 1), hyp.yseq[-1], dtype=torch.long))

        str_yseq = "".join([str(x) for x in hyp.yseq])

        if str_yseq in cache:
            y, state = cache[str_yseq]
        else:
            ey = self.embed(vy)

            y, state = self.rnn_forward(ey[0], hyp.dec_state)
            cache[str_yseq] = (y, state)

        return y, state, vy[0]
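A hedged sketch of calling score with a stand-in hypothesis; the real beam search supplies a richer object, but only the two fields read above are needed:

import torch
from dataclasses import dataclass

@dataclass
class Hypothesis:   # stand-in; only yseq and dec_state are read by score()
    yseq: list
    dec_state: tuple

cache = {}
hyp = Hypothesis(
    yseq=[0],       # hypothetical start token id
    dec_state=decoder.init_state(torch.zeros(1, decoder.dunits)),
)
y, state, lm_token = decoder.score(hyp, cache)   # y: (1, dec_dim)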
    def final(self, state):
        wlm_state, cumsum_probs, node = state
        if node is not None and node[1] >= 0:  # check if the node is word end
            w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
        else:  # this node is not a word end, which means <unk>
            w = self.var_word_unk
        wlm_state, z_wlm = self.wordlm(wlm_state, w)
        return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
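At hypothesis termination, a beam search would fold this word-level <eos> log probability into the hypothesis score, roughly as follows (names here are assumptions):

eos_logprob = lookahead_lm.final(state)            # float
total_score = hyp_score + lm_weight * eos_logprob  # shallow-fusion style bonus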
Example #6
    def score(self, hyp, cache, init_tensor=None):
        """Forward one step.

        Args:
            hyp (dataclass): hypothesis
            cache (dict): states cache

        Returns:
            y (torch.Tensor): decoder outputs (1, dec_dim)
            new_state (list): decoder and attention states
                [L x (1, max_len, dec_dim)]
            lm_tokens (torch.Tensor): token id for LM (1)

        """
        tgt = to_device(self, torch.tensor(hyp.yseq).unsqueeze(0))
        lm_tokens = tgt[:, -1]

        str_yseq = "".join([str(x) for x in hyp.yseq])

        if str_yseq in cache:
            y, new_state = cache[str_yseq]
        else:
            tgt_mask = to_device(self, subsequent_mask(len(hyp.yseq)).unsqueeze(0))

            state = check_state(hyp.dec_state, (tgt.size(1) - 1), self.blank)

            tgt = self.embed(tgt)

            new_state = []
            for s, decoder in zip(state, self.decoders):
                tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s)
                new_state.append(tgt)

            y = self.after_norm(tgt[:, -1])

            cache[str_yseq] = (y, new_state)

        return y, new_state, lm_tokens
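For reference, a minimal equivalent of the subsequent_mask helper used above (a sketch that mirrors its effect; the real helper lives in ESPnet's transformer utilities): a lower-triangular mask so position i attends only to positions <= i.

import torch

def subsequent_mask(size: int) -> torch.Tensor:
    # True on and below the diagonal: token i may look at tokens 0..i only
    return torch.tril(torch.ones(size, size, dtype=torch.bool))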
    def forward(self, xs_pad, ilens, prev_states=None):
        """Encoder forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad,
                                           ilens,
                                           prev_state=prev_state)
            current_states.append(states)

        # make mask to remove bias value in padded part
        mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1))

        return xs_pad.masked_fill(mask, 0.0), ilens, current_states
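A minimal equivalent of the make_pad_mask helper used above (a sketch; the real helper supports more options): it marks padded positions with True so they can be zeroed by masked_fill.

import torch

def make_pad_mask(lengths) -> torch.Tensor:
    lengths = torch.as_tensor(lengths)
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)  # (Tmax,)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)     # (B, Tmax), True = pad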
Example #8
    def batch_score(self, hyps, batch_states, cache, init_tensor=None):
        """Forward batch one step.

        Args:
            hyps (list): batch of hypotheses
            batch_states (list): decoder states
                [L x (B, max_len, dec_dim)]
            cache (dict): states cache

        Returns:
            batch_y (torch.Tensor): decoder output (B, dec_dim)
            batch_states (list): decoder states
                [L x (B, max_len, dec_dim)]
            lm_tokens (torch.Tensor): batch of token ids for LM (B)

        """
        final_batch = len(hyps)

        tokens = []
        process = []
        done = [None for _ in range(final_batch)]

        for i, hyp in enumerate(hyps):
            str_yseq = "".join([str(x) for x in hyp.yseq])

            if str_yseq in cache:
                done[i] = (*cache[str_yseq], hyp.yseq)
            else:
                tokens.append(hyp.yseq)
                process.append((str_yseq, hyp.dec_state, hyp.yseq))

        if process:
            batch = len(tokens)

            tokens = pad_sequence(tokens, self.blank)
            b_tokens = to_device(self, torch.LongTensor(tokens).view(batch, -1))

            tgt_mask = to_device(
                self,
                subsequent_mask(b_tokens.size(-1)).unsqueeze(0).expand(batch, -1, -1),
            )

            dec_state = self.init_state()

            dec_state = self.create_batch_states(
                dec_state,
                [p[1] for p in process],
                tokens,
            )

            tgt = self.embed(b_tokens)

            next_state = []
            for s, decoder in zip(dec_state, self.decoders):
                tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s)
                next_state.append(tgt)

            tgt = self.after_norm(tgt[:, -1])

        j = 0
        for i in range(final_batch):
            if done[i] is None:
                new_state = self.select_state(next_state, j)

                done[i] = (tgt[j], new_state, process[j][2])
                cache[process[j][0]] = (tgt[j], new_state)

                j += 1

        batch_states = self.create_batch_states(
            batch_states, [d[1] for d in done], [d[2] for d in done]
        )
        batch_y = torch.stack([d[0] for d in done])

        lm_tokens = to_device(
            self, torch.LongTensor([h.yseq[-1] for h in hyps]).view(final_batch)
        )

        return batch_y, batch_states, lm_tokens
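The pad_sequence used above is a plain list padder, not torch.nn.utils.rnn.pad_sequence; assuming it right-pads every label sequence to the longest one with the blank id, it behaves like:

def pad_sequence(seqs, pad_id):
    # right-pad each list of token ids to the length of the longest sequence
    max_len = max(len(s) for s in seqs)
    return [list(s) + [pad_id] * (max_len - len(s)) for s in seqs]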
    def forward(self, state, x):
        # update state with input label x
        if state is None:  # make initial states and cumulative probability vector
            self.var_word_eos = to_device(x, self.var_word_eos)
            self.var_word_unk = to_device(x, self.var_word_unk)
            self.zero_tensor = to_device(x, self.zero_tensor)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
            new_node = self.lexroot
            xi = self.space
        else:
            wlm_state, cumsum_probs, node = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                # check if the node is a word end
                if node is not None and node[1] >= 0:
                    w = to_device(x, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and cumulative probability vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
                new_node = self.lexroot  # move to the tree root
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
            elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
                new_node = None
            else:  # if open_vocab flag is disabled, return 0 probabilities
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero))
                return (wlm_state, None, None), log_y

        if new_node is not None:
            succ, wid, wids = new_node
            # compute parent node probability
            sum_prob = ((cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
                        if wids is not None else 1.0)
            if sum_prob < self.zero:
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero))
                return (wlm_state, cumsum_probs, new_node), log_y
            # set <unk> probability as a default value
            unk_prob = (cumsum_probs[:, self.word_unk] -
                        cumsum_probs[:, self.word_unk - 1])
            y = to_device(
                x,
                torch.full((1, self.subword_dict_size),
                           float(unk_prob) * self.oov_penalty),
            )
            # compute transition probabilities to child nodes
            for cid, nd in succ.items():
                y[:, cid] = (cumsum_probs[:, nd[2][1]] -
                             cumsum_probs[:, nd[2][0]]) / sum_prob
            # apply word-level probabilities for <space> and <eos> labels
            if wid >= 0:
                wlm_prob = (cumsum_probs[:, wid] -
                            cumsum_probs[:, wid - 1]) / sum_prob
                y[:, self.space] = wlm_prob
                y[:, self.eos] = wlm_prob
            elif xi == self.space:
                y[:, self.space] = self.zero
                y[:, self.eos] = self.zero
            log_y = torch.log(torch.max(
                y, self.zero_tensor))  # clip to avoid log(0)
        else:  # if no path in the tree, transition probability is one
            log_y = to_device(x, torch.zeros(1, self.subword_dict_size))
        return (wlm_state, cumsum_probs, new_node), log_y
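The look-ahead trick above assumes word ids are ordered so that every tree node covers a contiguous id range (wids); the probability mass of all words reachable from a node is then a single subtraction of cumulative sums. A small numeric check:

import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])   # made-up word distribution
cumsum = torch.cumsum(probs, dim=1)            # [[0.1, 0.3, 0.6, 1.0]]
lo, hi = 0, 2                                  # node covering word ids 1 and 2
mass = cumsum[:, hi] - cumsum[:, lo]           # 0.5 == probs[0, 1] + probs[0, 2]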
Example #10
    def batch_score(self, hyps, batch_states, cache, init_tensor=None):
        """Forward batch one step.

        Args:
            hyps (list): batch of hypotheses
            batch_states (tuple): batch of decoder states
                ([L x (B, dec_dim)], [L x (B, dec_dim)])
            cache (dict): states cache

        Returns:
            batch_y (torch.Tensor): decoder output (B, dec_dim)
            batch_states (tuple): batch of decoder states
                ([L x (B, dec_dim)], [L x (B, dec_dim)])
            lm_tokens (torch.Tensor): batch of token ids for LM (B)

        """
        final_batch = len(hyps)

        tokens = []
        process = []
        done = [None for _ in range(final_batch)]

        for i, hyp in enumerate(hyps):
            str_yseq = "".join([str(x) for x in hyp.yseq])

            if str_yseq in cache:
                done[i] = cache[str_yseq]
            else:
                tokens.append(hyp.yseq[-1])
                process.append((str_yseq, hyp.dec_state))

        if process:
            batch = len(process)

            tokens = to_device(self, torch.LongTensor(tokens).view(batch))

            dec_state = self.init_state(torch.zeros((batch, self.dunits)))
            dec_state = self.create_batch_states(dec_state, [p[1] for p in process])

            ey = self.embed(tokens)

            y, dec_state = self.rnn_forward(ey, dec_state)

        j = 0
        for i in range(final_batch):
            if done[i] is None:
                new_state = self.select_state(dec_state, j)

                done[i] = (y[j], new_state)
                cache[process[j][0]] = (y[j], new_state)

                j += 1

        batch_states = self.create_batch_states(batch_states, [d[1] for d in done])
        batch_y = torch.stack([d[0] for d in done])

        lm_tokens = to_device(
            self, torch.LongTensor([h.yseq[-1] for h in hyps]).view(final_batch)
        )

        return batch_y, batch_states, lm_tokens
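A hedged usage sketch for the batched step above, reusing the stand-in Hypothesis dataclass from the single-hypothesis example; repeated prefixes are served from the cache on later calls:

import torch

hyps = [
    Hypothesis(yseq=[0], dec_state=decoder.init_state(torch.zeros(1, decoder.dunits)))
    for _ in range(3)
]

cache = {}
batch_states = decoder.init_state(torch.zeros(len(hyps), decoder.dunits))
batch_y, batch_states, lm_tokens = decoder.batch_score(hyps, batch_states, cache)
# batch_y: (B, dec_dim); lm_tokens: (B,) last token id of each hypothesis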