def decode(
    self,
    emissions: torch.FloatTensor,
) -> List[List[Dict[str, torch.LongTensor]]]:
    """Viterbi-decode a batch of emissions with the native CpuViterbiPath.

    Args:
        emissions: score tensor unpacked as ``(B, T, N)`` -- presumably
            (batch, frames, vocabulary size); confirm against the caller.
            Because the native routine reads the raw memory buffer, the
            tensor is first normalized to a contiguous float32 CPU tensor
            (a no-op when it already is one).

    Returns:
        One entry per batch item. Each entry is a single-element list holding
        ``{"tokens": self.get_tokens(path), "score": 0}``, where ``path`` is
        that item's Viterbi path as a Python list. The score is always 0.
    """
    # CpuViterbiPath.compute receives raw data pointers, so the buffer must be
    # dense float32 in CPU memory. Keep the converted tensor bound to a local
    # name so its storage outlives the pointer handed to compute().
    emissions = emissions.float().cpu().contiguous()
    B, T, N = emissions.size()

    # NOTE(review): other decoders spell this attribute `asg_transitions`;
    # confirm that this class really defines `asgtransitions`.
    if self.asgtransitions is None:
        # No transition scores supplied: use an all-zero N x N matrix.
        transitions = torch.FloatTensor(N, N).zero_()
    else:
        transitions = torch.FloatTensor(self.asgtransitions).view(N, N)

    # Output buffer (left uninitialized here) that compute() fills in place.
    viterbi_path = torch.IntTensor(B, T)
    # Scratch memory sized by the native library for this problem shape.
    workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
    CpuViterbiPath.compute(
        B,
        T,
        N,
        get_data_ptr_as_bytes(emissions),
        get_data_ptr_as_bytes(transitions),
        get_data_ptr_as_bytes(viterbi_path),
        get_data_ptr_as_bytes(workspace),
    )
    return [
        [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
        for b in range(B)
    ]
def decode(self, emissions):
    """Viterbi-decode a batch of emissions and log the first hypothesis.

    Args:
        emissions: score tensor unpacked as ``(B, T, N)`` -- presumably
            (batch, frames, vocabulary size); confirm against the caller.
            Because the native routine reads the raw memory buffer, the
            tensor is first normalized to a contiguous float32 CPU tensor
            (a no-op when it already is one).

    Returns:
        One entry per batch item. Each entry is a single-element list holding
        ``{"tokens": self.get_tokens(path), "score": 0}``, where ``path`` is
        that item's Viterbi path as a Python list. The score is always 0.
    """
    # CpuViterbiPath.compute receives raw data pointers, so the buffer must be
    # dense float32 in CPU memory. Keep the converted tensor bound to a local
    # name so its storage outlives the pointer handed to compute().
    emissions = emissions.float().cpu().contiguous()
    B, T, N = emissions.size()

    if self.asg_transitions is None:
        # No transition scores supplied: use an all-zero N x N matrix.
        transitions = torch.FloatTensor(N, N).zero_()
    else:
        transitions = torch.FloatTensor(self.asg_transitions).view(N, N)

    # Output buffer (left uninitialized here) that compute() fills in place.
    viterbi_path = torch.IntTensor(B, T)
    # Scratch memory sized by the native library for this problem shape.
    workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
    CpuViterbiPath.compute(
        B,
        T,
        N,
        get_data_ptr_as_bytes(emissions),
        get_data_ptr_as_bytes(transitions),
        get_data_ptr_as_bytes(viterbi_path),
        get_data_ptr_as_bytes(workspace),
    )
    # Lazy %-style arguments: formatting only happens if INFO is enabled.
    logging.info("decoder input length: %s", T)

    returns = [
        [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
        for b in range(B)
    ]
    # Log the first batch item's hypothesis; skip on an empty batch, where
    # returns[0] would otherwise raise IndexError.
    if self.token_list is not None and returns:
        first_hypo = "".join(self.token_list[x] for x in returns[0][0]["tokens"])
        logging.info("best hypo: %s\n", first_hypo)
    return returns