def forward_unlabeled(self, all_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
    """
    Calculate the scores with the forward algorithm. Basically calculating the normalization term.
    :param all_scores: (batch_size x max_seq_len x num_labels x num_labels) from (LSTM scores + transition scores).
    :param word_seq_lens: (batch_size)
    :return: The score for all the possible structures, summed over the batch.
    """
    batch_size = all_scores.size(0)
    seq_len = all_scores.size(1)
    dev_num = all_scores.get_device()
    curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
    alpha = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)

    ## the first position of all labels = (the transition from start -> all labels) + current emission.
    alpha[:, 0, :] = all_scores[:, 0, self.start_idx, :]
    for word_idx in range(1, seq_len):
        ## batch_size, self.label_size, self.label_size
        before_log_sum_exp = alpha[:, word_idx - 1, :].view(batch_size, self.label_size, 1) \
                                 .expand(batch_size, self.label_size, self.label_size) \
                             + all_scores[:, word_idx, :, :]
        alpha[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

    ### batch_size x label_size
    last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1)
                              .expand(batch_size, 1, self.label_size) - 1).view(batch_size, self.label_size)
    last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
    last_alpha = log_sum_exp_pytorch(last_alpha.view(batch_size, self.label_size, 1)).view(batch_size)

    ## final score for the unlabeled network in this batch, with size: 1
    return torch.sum(last_alpha)
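
# An illustrative brute-force check for the forward algorithm above (a sketch,
# not part of the original module). It enumerates every label path for one
# short sequence and log-sum-exps the path scores, which should match
# `forward_unlabeled` on a batch of size 1 up to floating-point error. The
# tensor layout `all_scores[b, t, prev, curr]` and the `start_idx`/`end_idx`
# conventions are taken from the method above; the function name is made up.
import itertools
import torch

def brute_force_partition(all_scores: torch.Tensor, transition: torch.Tensor,
                          start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
    """Log-partition for batch element 0 by explicit path enumeration."""
    label_size = all_scores.size(-1)
    path_scores = []
    for path in itertools.product(range(label_size), repeat=seq_len):
        ## first position: transition from start plus emission is already in all_scores
        score = all_scores[0, 0, start_idx, path[0]]
        for t in range(1, seq_len):
            score = score + all_scores[0, t, path[t - 1], path[t]]
        ## close the path with the transition into the end label
        score = score + transition[path[-1], end_idx]
        path_scores.append(score)
    return torch.logsumexp(torch.stack(path_scores), dim=0)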
def backward(self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
    """
    Backward algorithm. A benchmark implementation which is ready to use.
    :param lstm_scores: shape: (batch_size, sent_len, label_size)
        NOTE: the scores from the LSTM, not `all_scores` (which adds the transition scores).
    :param word_seq_lens: shape: (batch_size,)
    :return: The backward variable's total score, summed over the batch.
    """
    batch_size = lstm_scores.size(0)
    seq_len = lstm_scores.size(1)
    dev_num = lstm_scores.get_device()
    curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
    beta = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)

    ## reverse the view of computing the score: we look from behind
    rev_score = self.transition.transpose(0, 1).view(1, 1, self.label_size, self.label_size) \
                    .expand(batch_size, seq_len, self.label_size, self.label_size) + \
                lstm_scores.view(batch_size, seq_len, 1, self.label_size) \
                    .expand(batch_size, seq_len, self.label_size, self.label_size)

    ## The code below reverses the scores from [0 -> length] to [length -> 0].
    ## (NOTE: we need to avoid reversing the padding.)
    perm_idx = torch.zeros(batch_size, seq_len, device=curr_dev, dtype=torch.long)
    for batch_idx in range(batch_size):
        length = int(word_seq_lens[batch_idx])
        perm_idx[batch_idx][:length] = torch.arange(length - 1, -1, -1)
    for i, length in enumerate(word_seq_lens):
        rev_score[i, :length] = rev_score[i, :length][perm_idx[i, :length]]

    ## backward operation
    beta[:, 0, :] = rev_score[:, 0, self.end_idx, :]
    for word_idx in range(1, seq_len):
        before_log_sum_exp = beta[:, word_idx - 1, :].view(batch_size, self.label_size, 1) \
                                 .expand(batch_size, self.label_size, self.label_size) \
                             + rev_score[:, word_idx, :, :]
        beta[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

    ## The following code is used to check the backward beta implementation
    last_beta = torch.gather(beta, 1, word_seq_lens.view(batch_size, 1, 1)
                             .expand(batch_size, 1, self.label_size) - 1).view(batch_size, self.label_size)
    last_beta += self.transition.transpose(0, 1)[:, self.start_idx].view(1, self.label_size) \
                     .expand(batch_size, self.label_size)
    last_beta = log_sum_exp_pytorch(last_beta.view(batch_size, self.label_size, 1)).view(batch_size)

    ## This part is optional if you only use `last_beta`;
    ## otherwise, you need it to reverse `beta` back if you also use `beta` itself.
    for i, length in enumerate(word_seq_lens):
        beta[i, :length] = beta[i, :length][perm_idx[i, :length]]
    return torch.sum(last_beta)
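
# Consistency probe (an illustrative sketch, not part of the module). `crf` is
# assumed to be an instance of the class above, exposing the `transition` and
# `label_size` attributes the methods already use. Both the summed forward
# score and the summed backward score are the batch log-partition, so the two
# methods should agree on the same inputs.
import torch

def check_forward_backward_agree(crf, lstm_scores: torch.Tensor,
                                 word_seq_lens: torch.Tensor, atol: float = 1e-4) -> None:
    ## build the (batch, seq_len, prev_label, curr_label) scores the same way
    ## `forward_backward` does: transition + emission, here via broadcasting
    all_scores = crf.transition.view(1, 1, crf.label_size, crf.label_size) \
                 + lstm_scores.unsqueeze(2)
    fwd = crf.forward_unlabeled(all_scores, word_seq_lens)
    bwd = crf.backward(lstm_scores, word_seq_lens)
    assert torch.allclose(fwd, bwd, atol=atol), (fwd.item(), bwd.item())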
def forward_backward(self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
    """
    Note: This function is not used unless you want to compute the marginal probability.
    Forward-backward algorithm to compute the marginal probability (in log space).
    Basically, we follow the `backward` algorithm to obtain the backward scores.
    :param lstm_scores: shape: (batch_size, sent_len, label_size)
        NOTE: the scores from the LSTM, not `all_scores` (which adds the transition scores).
    :param word_seq_lens: shape: (batch_size,)
    :return: Marginal score. If you want the probability, apply `torch.exp` to convert it.
    """
    batch_size = lstm_scores.size(0)
    seq_len = lstm_scores.size(1)
    dev_num = lstm_scores.get_device()
    curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
    alpha = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
    beta = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
    scores = self.transition.view(1, 1, self.label_size, self.label_size) \
                 .expand(batch_size, seq_len, self.label_size, self.label_size) + \
             lstm_scores.view(batch_size, seq_len, 1, self.label_size) \
                 .expand(batch_size, seq_len, self.label_size, self.label_size)
    ## reverse the view of computing the score: we look from behind
    rev_score = self.transition.transpose(0, 1).view(1, 1, self.label_size, self.label_size) \
                    .expand(batch_size, seq_len, self.label_size, self.label_size) + \
                lstm_scores.view(batch_size, seq_len, 1, self.label_size) \
                    .expand(batch_size, seq_len, self.label_size, self.label_size)

    perm_idx = torch.zeros(batch_size, seq_len, device=curr_dev, dtype=torch.long)
    for batch_idx in range(batch_size):
        length = int(word_seq_lens[batch_idx])
        perm_idx[batch_idx][:length] = torch.arange(length - 1, -1, -1)
    for i, length in enumerate(word_seq_lens):
        rev_score[i, :length] = rev_score[i, :length][perm_idx[i, :length]]

    ## the first position of all labels = (the transition from start -> all labels) + current emission.
    alpha[:, 0, :] = scores[:, 0, self.start_idx, :]
    beta[:, 0, :] = rev_score[:, 0, self.end_idx, :]
    for word_idx in range(1, seq_len):
        before_log_sum_exp = alpha[:, word_idx - 1, :].view(batch_size, self.label_size, 1) \
                                 .expand(batch_size, self.label_size, self.label_size) + scores[:, word_idx, :, :]
        alpha[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)
        before_log_sum_exp = beta[:, word_idx - 1, :].view(batch_size, self.label_size, 1) \
                                 .expand(batch_size, self.label_size, self.label_size) + rev_score[:, word_idx, :, :]
        beta[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

    ### batch_size x label_size
    last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1)
                              .expand(batch_size, 1, self.label_size) - 1).view(batch_size, self.label_size)
    last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
    last_alpha = log_sum_exp_pytorch(last_alpha.view(batch_size, self.label_size, 1)) \
                     .view(batch_size, 1, 1).expand(batch_size, seq_len, self.label_size)

    ## Because we need to use the beta variable later, we need to reverse it back
    for i, length in enumerate(word_seq_lens):
        beta[i, :length] = beta[i, :length][perm_idx[i, :length]]

    ## `alpha + beta - last_alpha` is the standard way to obtain the marginal.
    ## However, the emission score is counted twice at each position, so we subtract one copy.
    return alpha + beta - last_alpha - lstm_scores
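
# Sanity sketch for the marginals (illustrative, not part of the module; `crf`
# is assumed to be an instance of the class above). Exponentiating the returned
# log-marginals and summing over the label dimension should give ~1 at every
# non-padded position; padded positions carry no meaningful values and are
# skipped here.
import torch

def check_marginals_normalized(crf, lstm_scores: torch.Tensor,
                               word_seq_lens: torch.Tensor, atol: float = 1e-4) -> None:
    log_marginals = crf.forward_backward(lstm_scores, word_seq_lens)
    probs = torch.exp(log_marginals)              # (batch, seq_len, label_size)
    for b, length in enumerate(word_seq_lens):
        total = probs[b, :length].sum(dim=-1)     # should be all ones
        assert torch.allclose(total, torch.ones_like(total), atol=atol)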
def forward_unlabeled(self, all_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
    """
    Calculate the scores with the forward algorithm. Basically calculating the normalization term.
    This variant replaces the sequential forward recursion with a Blelloch-style parallel scan
    (an up-sweep followed by a down-sweep) over the position-wise score matrices.
    :param all_scores: (batch_size x max_seq_len x num_labels x num_labels) from (LSTM scores + transition scores).
    :param word_seq_lens: (batch_size)
    :return: The score for all the possible structures, summed over the batch.
    """
    batch_size = all_scores.size(0)
    seq_len = all_scores.size(1)
    dev_num = all_scores.get_device()
    curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")

    ## pad the sequence dimension to the next power of two so the sweeps form a complete binary tree
    depth = math.ceil(math.log2(seq_len))
    padded_length = int(math.pow(2, depth))
    sweep_score = torch.zeros(batch_size, padded_length, depth + 1,
                              self.label_size, self.label_size, device=curr_dev)
    sweep_score[:, :seq_len, 0, :, :] = all_scores

    ## up-sweep: combine left/right siblings with a logsumexp "matrix product", one tree level at a time
    step_size = 2
    f_start = 0
    b_start = 1
    for d in range(depth):
        ## left siblings, indexed as [i, k]
        forward_score = sweep_score[:, f_start::step_size, d, :, :].unsqueeze(-1).expand(
            batch_size, padded_length // step_size, self.label_size, self.label_size, self.label_size)
        ## right siblings, indexed as [k, j]
        backward_score = sweep_score[:, b_start::step_size, d, :, :].unsqueeze(2).expand(
            batch_size, padded_length // step_size, self.label_size, self.label_size, self.label_size)
        sweep_score[:, b_start::step_size, d + 1, :, :] = torch.logsumexp(forward_score + backward_score, dim=-2)
        # sweep_score[:, b_start::step_size, d + 1, :, :] = torch.sum(forward_score + backward_score, dim=-2)  ## (sum-semiring variant)
        f_start = b_start
        b_start = b_start + step_size
        step_size *= 2

    ## down-sweep
    step_size = step_size // 2
    ## root of the down-sweep: all zeros; `first_mask` below makes it act as the identity
    sweep_score[:, -1, -1, :, :] = 0

    f_start = padded_length // 2 - 1
    b_start = padded_length - 1
    ## mask for the first segment at each level: 0 on the diagonal-selecting entries, -inf elsewhere,
    ## so that logsumexp passes the stashed segment through unchanged
    first_mask = torch.full([self.label_size, self.label_size, self.label_size],
                            fill_value=-float("Inf"), device=curr_dev)
    # first_mask = torch.full([self.label_size, self.label_size, self.label_size], fill_value=0, device=curr_dev)  ## (sum-semiring variant)
    idxs = torch.arange(self.label_size, device=curr_dev)
    interleave_idxs = idxs.repeat_interleave(self.label_size)
    first_mask[interleave_idxs, interleave_idxs, idxs.repeat(self.label_size)] = 0  ## logsumexp identity is 0, to pass over
    # first_mask[interleave_idxs, interleave_idxs, idxs.repeat(self.label_size)] = 1  ## sum identity is 1, to pass through
    first_mask = first_mask.unsqueeze(0).unsqueeze(0).expand(
        batch_size, 1, self.label_size, self.label_size, self.label_size)
    ## zero mask for the remaining segments (neutral under logsumexp)
    zero_mask = torch.zeros(self.label_size, device=curr_dev).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) \
        .expand(batch_size, 1, self.label_size, self.label_size, self.label_size)
    # zero_mask = torch.ones(self.label_size, device=curr_dev).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) \
    #     .expand(batch_size, 1, self.label_size, self.label_size, self.label_size)  ## (sum-semiring variant)

    for d in range(depth - 1, -1, -1):
        length = padded_length // step_size
        ## stash the left children before overwriting them
        temporary = sweep_score[:, f_start::step_size, d, :, :].clone()
        temporary = temporary.unsqueeze(2).expand(
            batch_size, length, self.label_size, self.label_size, self.label_size)
        ## the parent's prefix comes down to the left child
        sweep_score[:, f_start::step_size, d, :, :] = sweep_score[:, b_start::step_size, d + 1, :, :].clone()
        forward_score = sweep_score[:, b_start::step_size, d + 1, :, :].unsqueeze(-1).expand(
            batch_size, length, self.label_size, self.label_size, self.label_size)
        curr_zero_mask = zero_mask.expand(batch_size, length - 1,
                                          self.label_size, self.label_size, self.label_size)
        mask = torch.cat([first_mask, curr_zero_mask], dim=1)
        ## the right child receives parent's prefix combined with the stashed left child
        sweep_score[:, b_start::step_size, d, :, :] = torch.logsumexp(forward_score + temporary + mask, dim=-2)
        # sweep_score[:, b_start::step_size, d, :, :] = torch.sum(forward_score + mask * temporary, dim=-2)  ## (sum-semiring variant)
        b_start = f_start
        step_size = step_size // 2
        f_start = f_start - step_size // 2

    ## level 0 now holds exclusive prefixes; one more combine with `all_scores` makes them inclusive
    curr_zero_mask = zero_mask.expand(batch_size, seq_len - 1,
                                      self.label_size, self.label_size, self.label_size)
    mask = torch.cat([first_mask, curr_zero_mask], dim=1)
    sweep_score[:, :seq_len, 0, :, :] = torch.logsumexp(
        sweep_score[:, :seq_len, 0, :, :].unsqueeze(-1).expand(
            batch_size, seq_len, self.label_size, self.label_size, self.label_size)
        + all_scores.unsqueeze(2).expand(
            batch_size, seq_len, self.label_size, self.label_size, self.label_size)
        + mask, dim=-2)
    # sweep_score[:, :seq_len, 0, :, :] = torch.sum(
    #     sweep_score[:, :seq_len, 0, :, :].unsqueeze(-1).expand(batch_size, seq_len, self.label_size, self.label_size, self.label_size) +
    #     all_scores.unsqueeze(2).expand(batch_size, seq_len, self.label_size, self.label_size, self.label_size) * mask,
    #     dim=-2)  ## (sum-semiring variant)

    ### batch_size x label_size
    last_alpha = torch.gather(sweep_score[:, :, 0, self.start_idx, :], 1,
                              word_seq_lens.view(batch_size, 1, 1)
                              .expand(batch_size, 1, self.label_size) - 1).view(batch_size, self.label_size)
    last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
    last_alpha = log_sum_exp_pytorch(last_alpha.view(batch_size, self.label_size, 1)).view(batch_size)

    ## final score for the unlabeled network in this batch, with size: 1
    return torch.sum(last_alpha)
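
# The up-/down-sweep above is a Blelloch-style parallel scan whose binary
# operation is matrix "multiplication" in the (logsumexp, +) semiring. Below is
# a minimal standalone illustration of that operation and of the sequential
# left-fold it parallelizes (a sketch; `mats` is assumed to be a (T, L, L)
# stack of per-position score matrices M[t][i, j] = score of label i -> j).
import torch

def log_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """(logsumexp, +) semiring product: out[i, j] = logsumexp_k a[i, k] + b[k, j]."""
    return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)

def sequential_log_fold(mats: torch.Tensor) -> torch.Tensor:
    """The left-fold over positions that the up-sweep computes at its root."""
    out = mats[0]
    for t in range(1, mats.size(0)):
        out = log_matmul(out, mats[t])
    return out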