Code example #1
    def decode(self, emissions):
        batch_size, time_length, num_classes = emissions.size()

        if self.asg_transitions is None:
            transitions = torch.FloatTensor(
                num_classes,
                num_classes,
            ).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(
                num_classes,
                num_classes,
            )

        viterbi_path = torch.IntTensor(batch_size, time_length)
        workspace = torch.ByteTensor(
            CpuViterbiPath.get_workspace_size(
                batch_size,
                time_length,
                num_classes,
            ))
        CpuViterbiPath.compute(
            batch_size,
            time_length,
            num_classes,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [[{
            "tokens": self.get_tokens(viterbi_path[b].tolist()),
            "score": 0
        }] for b in range(batch_size)]
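
All three variants delegate to self.get_tokens, whose implementation is not shown here. As a rough illustration only, here is a minimal sketch of the normalization such a helper typically performs on CTC-style outputs (collapse repeated frames, then drop the blank index); the name get_tokens_sketch and the blank=0 default are assumptions, not something the snippets specify:

    import itertools as it

    import torch

    def get_tokens_sketch(idxs, blank=0):
        # Hedged sketch: collapse frame-level repeats, then strip the blank id.
        # blank=0 is an assumed convention; a real decoder would read it
        # from its token dictionary.
        idxs = (g[0] for g in it.groupby(idxs))  # collapse runs: 8 8 11 11 -> 8 11
        idxs = (i for i in idxs if i != blank)   # drop blank frames
        return torch.LongTensor(list(idxs))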
Code example #2
    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []  # note: defined but never used in this variant

        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)

        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(
            B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [[{
            "tokens": self.get_tokens(viterbi_path[b].tolist()),
            "score": 0
        }] for b in range(B)]
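
For context, a hypothetical call site; model, decoder, and batch are illustrative names that do not appear in the snippets above, and the wav2letter CpuViterbiPath bindings are assumed to be installed:

    # Hypothetical usage; `model`, `decoder`, and `batch` are illustrative names.
    emissions = model(batch)             # frame-level scores, shape [B, T, N]
    hypos = decoder.decode(emissions)    # list of length B, one 1-best hypothesis list each
    best_tokens = hypos[0][0]["tokens"]  # LongTensor of token ids for the first utterance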
Code example #3
    def decode(self, emissions):                                        # decode: map frame-level emissions to the most likely token sequence via Viterbi
        batch_size, time_length, num_classes = emissions.size()         # emissions: [batch_size, time_length, num_classes]; num_classes = size of the token vocabulary

        if self.asg_transitions is None:                                # default None; asg_transitions holds learned transition scores between each pair of tokens (ASG = Auto Segmentation criterion)
            transitions = torch.FloatTensor(                                # no ASG transitions: fall back to an all-zero [num_classes, num_classes] matrix (e.g. [108, 108] in this run)
                num_classes,
                num_classes,
            ).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view( # reshape the flat transition scores into [num_classes, num_classes]
                num_classes,
                num_classes,
            )

        viterbi_path = torch.IntTensor(batch_size, time_length)         # [batch_size, time_length] (e.g. [1, 95] in this run); filled in place with the best class index per frame
        workspace = torch.ByteTensor(
            CpuViterbiPath.get_workspace_size(                          # get_workspace_size returns how many bytes of scratch memory the C++ Viterbi needs; the ByteTensor above allocates it as one contiguous block
                batch_size,
                time_length,
                num_classes,
            ))
        CpuViterbiPath.compute(                                         # run the Viterbi algorithm in C++ over the raw tensor memory; the best path is written into viterbi_path in place
            batch_size,
            time_length,
            num_classes,
            get_data_ptr_as_bytes(emissions),                           # pass raw data pointers of the tensors into the C++ kernel
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [[{                                                      # for each utterance in the batch, return a single 1-best hypothesis:
            "tokens": self.get_tokens(viterbi_path[b].tolist()),            # get_tokens normalizes the raw Viterbi path into token ids, e.g. tensor([ 8, 11, 14, 11, 10,  5,  8, 48, 10, 32,  6, 37,  7, 11, 10,  5, 32, 12, 26, 22,  6, 18, 27,  8, 13,  5]) of size [26]
            "score": 0                                                      # score is hard-coded to 0; only the path itself is returned
        }] for b in range(batch_size)]
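
The C++ kernel itself is opaque from Python. As a rough reference for what CpuViterbiPath.compute computes, here is a minimal pure-PyTorch sketch of the Viterbi recursion over emissions [B, T, N] and transitions [N, N]; it assumes transitions[i, j] scores a move from class j to class i, which is a convention choice the snippets above do not confirm:

    import torch

    def viterbi_sketch(emissions, transitions):
        # emissions:   [B, T, N] per-frame scores
        # transitions: [N, N]; transitions[i, j] = score of moving from
        #              class j to class i (assumed convention)
        B, T, N = emissions.size()
        score = emissions[:, 0, :].clone()                # best score ending in each class at frame 0
        backptr = torch.zeros(B, T, N, dtype=torch.long)  # backptr[b, t, i] = best predecessor class
        for t in range(1, T):
            # candidate[b, i, j] = best score ending in j at t-1, plus the j -> i transition
            candidate = score.unsqueeze(1) + transitions.unsqueeze(0)  # [B, N, N]
            best, backptr[:, t] = candidate.max(dim=2)
            score = best + emissions[:, t, :]
        path = torch.zeros(B, T, dtype=torch.long)
        path[:, T - 1] = score.argmax(dim=1)              # best final class per utterance
        for t in range(T - 2, -1, -1):                    # follow back-pointers from the end
            path[:, t] = backptr[torch.arange(B), t + 1, path[:, t + 1]]
        return path                                       # [B, T], analogous to viterbi_path above

The real implementation presumably runs the same recursion in C++, keeping its score and back-pointer buffers in the preallocated workspace, which would explain why get_workspace_size depends on all of batch_size, time_length, and num_classes.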