def forward_unlabeled(self, all_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
        """
        Calculate the scores with the forward algorithm, i.e. the normalization term (log partition).
        :param all_scores: (batch_size x max_seq_len x num_labels x num_labels) from (lstm scores + transition scores).
        :param word_seq_lens: (batch_size,)
        :return: The total score over all possible label sequences, summed over the batch.
        """
        batch_size = all_scores.size(0)
        seq_len = all_scores.size(1)
        dev_num = all_scores.get_device()
        curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
        alpha = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)

        alpha[:, 0, :] = all_scores[:, 0, self.start_idx, :]  ## score at the first position for each label = (transition from start -> label) + emission at position 0.

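        ## Recurrence: alpha[t, j] = logsumexp_i( alpha[t-1, i] + all_scores[t, i, j] ),
        ## i.e. a log-space sum over the previous label i of the total score of paths ending
        ## in i at position t-1, plus the (transition + emission) score for label j at position t.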
        for word_idx in range(1, seq_len):
            ## batch_size, self.label_size, self.label_size
            before_log_sum_exp = alpha[:, word_idx-1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + all_scores[:, word_idx, :, :]
            alpha[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

        ### batch_size x label_size
        last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size)-1).view(batch_size, self.label_size)
        last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
        last_alpha = log_sum_exp_pytorch(last_alpha.view(batch_size, self.label_size, 1)).view(batch_size)

        ## final score for the unlabeled network in this batch, with size: 1
        return torch.sum(last_alpha)
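
# The snippets above and below rely on a helper `log_sum_exp_pytorch` that is defined
# elsewhere in the source project. A minimal sketch of what it is assumed to do, inferred
# from how it is called (numerically stable logsumexp of a (batch_size, from_label, to_label)
# tensor over dim 1); the actual helper may differ in details.
import torch

def log_sum_exp_pytorch(vec: torch.Tensor) -> torch.Tensor:
    """
    :param vec: (batch_size, from_label_size, to_label_size)
    :return: (batch_size, to_label_size), logsumexp over dim 1
    """
    return torch.logsumexp(vec, dim=1)
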
    def backward(self, lstm_scores: torch.Tensor,
                 word_seq_lens: torch.Tensor) -> torch.Tensor:
        """
        Backward algorithm. A reference implementation that is ready to use, mainly for checking the forward pass.
        :param lstm_scores: shape: (batch_size, sent_len, label_size) NOTE: the scores from the LSTM only, not `all_scores` (which already add the transition scores).
        :param word_seq_lens: shape: (batch_size,)
        :return: The log partition score computed backwards, summed over the batch.
        """
        batch_size = lstm_scores.size(0)
        seq_len = lstm_scores.size(1)
        dev_num = lstm_scores.get_device()
        curr_dev = torch.device(
            f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
        beta = torch.zeros(batch_size,
                           seq_len,
                           self.label_size,
                           device=curr_dev)

        ## reverse the direction of the score computation: we look at the sequence from the back
        rev_score = self.transition.transpose(0, 1).view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
                    lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)

        ## The code below reverses the scores from [0 -> length] to [length -> 0]. (NOTE: padding positions must not be reversed.)
        perm_idx = torch.zeros(batch_size, seq_len, device=curr_dev)
        for batch_idx in range(batch_size):
            length = int(word_seq_lens[batch_idx])
            perm_idx[batch_idx][:length] = torch.arange(length - 1, -1, -1, device=curr_dev)
        perm_idx = perm_idx.long()
        for i, length in enumerate(word_seq_lens):
            rev_score[i, :length] = rev_score[i, :length][perm_idx[i, :length]]

        ## backward operation
        beta[:, 0, :] = rev_score[:, 0, self.end_idx, :]
        for word_idx in range(1, seq_len):
            before_log_sum_exp = beta[:, word_idx - 1, :].view(
                batch_size, self.label_size,
                1).expand(batch_size, self.label_size,
                          self.label_size) + rev_score[:, word_idx, :, :]
            beta[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

        ## The following code checks the backward beta implementation: it should yield the same log partition score as the forward pass
        last_beta = torch.gather(
            beta, 1,
            word_seq_lens.view(batch_size, 1, 1).expand(
                batch_size, 1, self.label_size) - 1).view(
                    batch_size, self.label_size)
        last_beta += self.transition.transpose(0, 1)[:, self.start_idx].view(
            1, self.label_size).expand(batch_size, self.label_size)
        last_beta = log_sum_exp_pytorch(
            last_beta.view(batch_size, self.label_size, 1)).view(batch_size)

        # This part is optional if you only use `last_beta`.
        # Otherwise, you need it to reverse beta back to the original order if you also want to use beta.
        for i, length in enumerate(word_seq_lens):
            beta[i, :length] = beta[i, :length][perm_idx[i, :length]]

        return torch.sum(last_beta)
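
    # Hedged sketch (not part of the original snippet): a small consistency check.
    # It assumes the sequential `forward_unlabeled` and `backward` above are methods of the
    # same CRF module and that `all_scores` = lstm scores + transition scores; in that case
    # both return the batch-summed log partition score and the two values should match.
    def check_forward_backward_consistency(self, all_scores: torch.Tensor,
                                           lstm_scores: torch.Tensor,
                                           word_seq_lens: torch.Tensor,
                                           tol: float = 1e-4) -> bool:
        forward_z = self.forward_unlabeled(all_scores, word_seq_lens)
        backward_z = self.backward(lstm_scores, word_seq_lens)
        return bool(torch.allclose(forward_z, backward_z, atol=tol))
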
    def forward_backward(self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
        """
        Note: this function is only needed if you want to compute the marginal probabilities.
        Forward-backward algorithm to compute the marginal probabilities (in log space).
        Basically, we follow the `backward` algorithm to obtain the backward scores.
        :param lstm_scores:   shape: (batch_size, sent_len, label_size) NOTE: the scores from the LSTM only, not `all_scores` (which already add the transition scores).
        :param word_seq_lens: shape: (batch_size,)
        :return: Marginal scores. If you want probabilities, apply `torch.exp` to convert them.
        """
        batch_size = lstm_scores.size(0)
        seq_len = lstm_scores.size(1)
        dev_num = lstm_scores.get_device()
        curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
        alpha = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
        beta = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)

        scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
                 lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)
        ## reverse the direction of the score computation: we look at the sequence from the back
        rev_score = self.transition.transpose(0, 1).view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
                    lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)

        perm_idx = torch.zeros(batch_size, seq_len, device=curr_dev)
        for batch_idx in range(batch_size):
            length = int(word_seq_lens[batch_idx])
            perm_idx[batch_idx][:length] = torch.arange(length - 1, -1, -1, device=curr_dev)
        perm_idx = perm_idx.long()
        for i, length in enumerate(word_seq_lens):
            rev_score[i, :length] = rev_score[i, :length][perm_idx[i, :length]]

        alpha[:, 0, :] = scores[:, 0, self.start_idx, :]  ## score at the first position for each label = (transition from start -> label) + emission at position 0.
        beta[:, 0, :] = rev_score[:, 0, self.end_idx, :]
        for word_idx in range(1, seq_len):
            before_log_sum_exp = alpha[:, word_idx - 1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + scores[ :, word_idx, :, :]
            alpha[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

            before_log_sum_exp = beta[:, word_idx - 1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + rev_score[:, word_idx, :, :]
            beta[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

        ### batch_size x label_size
        last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size) - 1).view( batch_size, self.label_size)
        last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
        last_alpha = log_sum_exp_pytorch(last_alpha.view(batch_size, self.label_size, 1)).view(batch_size, 1, 1).expand(batch_size, seq_len, self.label_size)

        ## Because we need to use the beta variable later, reverse it back to the original order
        for i, length in enumerate(word_seq_lens):
            beta[i, :length] = beta[i, :length][perm_idx[i, :length]]

        # `alpha + beta - last_alpha` is the standard way to obtain the marginal.
        # However, the emission score at each position is counted twice (once in alpha and once in beta), so we subtract one emission score.
        return alpha + beta - last_alpha - lstm_scores
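
    # Hedged usage sketch (not part of the original snippet): turning the marginal
    # log-scores returned by `forward_backward` into per-position label probabilities.
    # It assumes positions beyond `word_seq_lens` are padding and should simply be zeroed out.
    def marginal_probabilities(self, lstm_scores: torch.Tensor,
                               word_seq_lens: torch.Tensor) -> torch.Tensor:
        marginal = self.forward_backward(lstm_scores, word_seq_lens)  # log space, (batch_size, sent_len, label_size)
        prob = torch.exp(marginal)
        # Zero out padded positions so downstream code does not read meaningless values there.
        mask = torch.arange(lstm_scores.size(1), device=prob.device).unsqueeze(0) < word_seq_lens.unsqueeze(1)
        return prob * mask.unsqueeze(-1).to(prob.dtype)
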
    def forward_unlabeled(self, all_scores: torch.Tensor,
                          word_seq_lens: torch.Tensor) -> torch.Tensor:
        """
        Calculate the scores with the forward algorithm, i.e. the normalization term (log partition),
        using a parallel (up-sweep / down-sweep) scan over the sequence instead of a left-to-right loop.
        :param all_scores: (batch_size x max_seq_len x num_labels x num_labels) from (lstm scores + transition scores).
        :param word_seq_lens: (batch_size,)
        :return: The total score over all possible label sequences, summed over the batch.
        """
        batch_size = all_scores.size(0)
        seq_len = all_scores.size(1)
        dev_num = all_scores.get_device()
        curr_dev = torch.device(
            f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
        depth = math.ceil(math.log2(seq_len))

        padded_length = int(math.pow(2, depth))
        sweep_score = torch.zeros(batch_size,
                                  padded_length,
                                  depth + 1,
                                  self.label_size,
                                  self.label_size,
                                  device=curr_dev)

        sweep_score[:, :seq_len, 0, :, :] = all_scores

        step_size = 2
        f_start = 0
        b_start = 1
        for d in range(depth):
            ## left factor of each pair: scores indexed by (previous label, middle label)
            forward_score = sweep_score[:, f_start::step_size,
                                        d, :, :].unsqueeze(-1).expand(
                                            batch_size,
                                            padded_length // step_size,
                                            self.label_size, self.label_size,
                                            self.label_size)
            ## right factor of each pair: scores indexed by (middle label, next label)
            backward_score = sweep_score[:, b_start::step_size,
                                         d, :, :].unsqueeze(2).expand(
                                             batch_size,
                                             padded_length // step_size,
                                             self.label_size, self.label_size,
                                             self.label_size)
            sweep_score[:, b_start::step_size, d + 1, :, :] = torch.logsumexp(
                forward_score + backward_score, dim=-2)
            # sweep_score[:, b_start::step_size, d + 1, :, :] = torch.sum(forward_score + backward_score, dim=-2)
            f_start = b_start
            b_start = b_start + step_size
            step_size *= 2
        # print(f"depth is {depth}, step_size: {step_size}")

        ## down-sweep phase
        step_size = step_size // 2
        sweep_score[:, -1, -1, :, :] = 0  # -float("Inf")
        # sweep_score[:, -1, -1, :, :] = 0 #for sum
        f_start = padded_length // 2 - 1
        b_start = padded_length - 1
        #log sum exp
        first_mask = torch.full(
            [self.label_size, self.label_size, self.label_size],
            fill_value=-float("Inf"),
            device=curr_dev)
        # sum
        # first_mask = torch.full([self.label_size, self.label_size, self.label_size], fill_value=0, device=curr_dev)
        idxs = torch.arange(self.label_size, device=curr_dev)
        interleave_idxs = idxs.repeat_interleave(self.label_size)
        first_mask[
            interleave_idxs, interleave_idxs,
            idxs.repeat(self.label_size)] = 0  # 0 is the identity for logsumexp: it lets the score pass through unchanged
        # first_mask[interleave_idxs, interleave_idxs, idxs.repeat(self.label_size)] = 1 # sum is 1. to pass through
        first_mask = first_mask.unsqueeze(0).unsqueeze(0).expand(
            batch_size, 1, self.label_size, self.label_size, self.label_size)
        ## for log sum exp
        zero_mask = torch.zeros(self.label_size, device=curr_dev).unsqueeze(
            0).unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(
                batch_size, 1, self.label_size, self.label_size,
                self.label_size)
        ## for sum
        # zero_mask = torch.ones(self.label_size, device=curr_dev).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(
        #     batch_size, 1, self.label_size, self.label_size, self.label_size
        # )
        for d in range(depth - 1, -1, -1):
            length = padded_length // step_size
            ## save the current left-child scores before they are overwritten by the incoming prefix
            temporary = sweep_score[:, f_start::step_size, d, :, :].clone()
            temporary = temporary.unsqueeze(2).expand(batch_size, length,
                                                      self.label_size,
                                                      self.label_size,
                                                      self.label_size)

            sweep_score[:, f_start::step_size,
                        d, :, :] = sweep_score[:, b_start::step_size,
                                               d + 1, :, :].clone()

            ## incoming prefix from the parent node (level d + 1)
            forward_score = sweep_score[:, b_start::step_size,
                                        d + 1, :, :].unsqueeze(-1).expand(
                                            batch_size, length,
                                            self.label_size, self.label_size,
                                            self.label_size)
            curr_zero_mask = zero_mask.expand(batch_size, length - 1,
                                              self.label_size, self.label_size,
                                              self.label_size)
            mask = torch.cat([first_mask, curr_zero_mask], dim=1)
            # right child: combine the incoming prefix with the saved left-child scores
            sweep_score[:, b_start::step_size, d, :, :] = torch.logsumexp(
                forward_score + temporary + mask, dim=-2)
            # sweep_score[:, b_start::step_size, d, :, :] = torch.sum(forward_score + mask * temporary, dim=-2) #

            b_start = f_start
            step_size = step_size // 2
            f_start = f_start - step_size // 2

        curr_zero_mask = zero_mask.expand(batch_size, seq_len - 1,
                                          self.label_size, self.label_size,
                                          self.label_size)
        mask = torch.cat([first_mask, curr_zero_mask], dim=1)
        sweep_score[:, :seq_len, 0, :, :] = torch.logsumexp(
            sweep_score[:, :seq_len, 0, :, :].unsqueeze(-1).expand(
                batch_size, seq_len, self.label_size, self.label_size,
                self.label_size) + all_scores.unsqueeze(2).expand(
                    batch_size, seq_len, self.label_size, self.label_size,
                    self.label_size) + mask,
            dim=-2)
        # sweep_score[:, :seq_len, 0, :, :] = torch.sum(
        #     sweep_score[:, :seq_len, 0, :, :].unsqueeze(-1).expand(batch_size, seq_len, self.label_size, self.label_size, self.label_size) +
        #     all_scores.unsqueeze(2).expand(batch_size, seq_len, self.label_size, self.label_size, self.label_size) * mask,
        #     dim=-2
        # )

        ### batch_size x label_size
        last_alpha = torch.gather(
            sweep_score[:, :, 0, self.start_idx, :], 1,
            word_seq_lens.view(batch_size, 1, 1).expand(
                batch_size, 1, self.label_size) - 1).view(
                    batch_size, self.label_size)
        last_alpha += self.transition[:, self.end_idx].view(
            1, self.label_size).expand(batch_size, self.label_size)
        last_alpha = log_sum_exp_pytorch(
            last_alpha.view(batch_size, self.label_size, 1)).view(batch_size)

        ## final score for the unlabeled network in this batch, with size: 1
        return torch.sum(last_alpha)
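
# Hedged illustration (standalone, not part of the snippets above): the up-sweep and
# down-sweep combine adjacent score matrices with a logsumexp over the shared middle
# label, i.e. matrix multiplication in the (logsumexp, +) semiring. A tiny check of that
# primitive and of its identity element (0 on the diagonal, -inf elsewhere), the same
# -inf/0 pass-through idea used by the masks above:
import torch

def log_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: (L, L) scores for (label i -> label j); b: (L, L) scores for (label j -> label k).
    # result[i, k] = logsumexp_j( a[i, j] + b[j, k] )
    return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(0), dim=1)

if __name__ == "__main__":
    L = 4
    a, b = torch.randn(L, L), torch.randn(L, L)
    identity = torch.full((L, L), float("-inf"))
    identity.fill_diagonal_(0.0)
    assert torch.allclose(log_matmul(a, identity), a, atol=1e-5)
    assert torch.allclose(log_matmul(identity, b), b, atol=1e-5)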