import math

import torch
from torch.autograd import Variable

# N.B. logsumexp0 and logsumexp2 (log-sum-exp over dims 0 and 2, respectively)
# are assumed to be defined elsewhere in this module.


def just_fwd(pi, trans_logprobs, bwd_obs_logprobs, constraints=None):
    """
    pi - bsz x K
    bwd_obs_logprobs - L x T x bsz x K, obs probs ending at t
    trans_logprobs - T-1 x bsz x K x K, trans_logprobs[t] = p(q_{t+1} | q_t)
    returns alph, alph_star: length T+1 lists of bsz x K log potentials (1-indexed)
    """
    neginf = -1e38  # -float("inf")
    L, seqlen, bsz, K = bwd_obs_logprobs.size()
    # we'll be 1-indexed for alphas and betas
    alph = [None] * (seqlen + 1)
    alph_star = [None] * (seqlen + 1)
    alph_star[0] = pi
    mask = trans_logprobs.new(L, bsz, K)

    # store max possible length of a segment generated from t (used as a uniform length dist)
    bwd_maxlens = trans_logprobs.new(seqlen).fill_(L)
    bwd_maxlens[-L:].copy_(torch.arange(L, 0, -1))
    bwd_maxlens = bwd_maxlens.log_().view(seqlen, 1, 1)

    for t in range(1, seqlen + 1):
        steps_back = min(L, t)

        if constraints is not None and constraints[t] is not None:
            tmask = mask.narrow(0, 0, steps_back).zero_()
            # steps_back x bsz x K -> steps_back*bsz x K
            tmask.view(-1, K).index_fill_(0, constraints[t], neginf)

        # alph_t(j) = log \sum_l p(x_{t-l+1:t}) alph*_{t-l}(j) p(l_t)
        alph_terms = (torch.stack(alph_star[t - steps_back:t])    # steps_back x bsz x K
                      + bwd_obs_logprobs[-steps_back:, t - 1]     # steps_back x bsz x K (0-idx)
                      - bwd_maxlens[t - steps_back:t].expand(steps_back, bsz, K))

        if constraints is not None and constraints[t] is not None:
            alph_terms = alph_terms + tmask

        alph[t] = logsumexp0(alph_terms)  # bsz x K

        if t < seqlen:
            # alph*_t(k) = log \sum_j alph_t(j) p(q_{t+1}=k | q_t=j)
            # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j
            tps = trans_logprobs[t - 1]  # N.B. trans_logprobs[t] is p(q_{t+1}) and 0-indexed
            alph_t = alph[t]  # bsz x K, viz., p(x, j)
            alph_star_terms = (tps.transpose(0, 1)  # K x bsz x K
                               + alph_t.unsqueeze(2).expand(bsz, K, K).transpose(0, 1))
            alph_star[t] = logsumexp0(alph_star_terms)

    return alph, alph_star

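# A minimal usage sketch, not part of the original module: run just_fwd on random
# stand-in potentials and read off the forward log normalizer by summing out the
# state of the segment ending at T. It assumes the module-level logsumexp0 used by
# just_fwd behaves like torch.logsumexp(., dim=0); the helper name _demo_just_fwd
# and all shapes below are illustrative only.
def _demo_just_fwd(L=3, seqlen=5, bsz=2, K=4):
    pi = torch.randn(bsz, K).log_softmax(dim=1)                         # log p(q_1)
    trans_logprobs = torch.randn(seqlen - 1, bsz, K, K).log_softmax(dim=3)
    bwd_obs_logprobs = torch.randn(L, seqlen, bsz, K)                   # stand-in segment scores
    alph, _ = just_fwd(pi, trans_logprobs, bwd_obs_logprobs)
    # marginalize the state of the final segment: bsz log normalizers
    return torch.logsumexp(alph[seqlen], dim=1)
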
def _just_bwd(trans_logprobs, fwd_obs_logprobs, len_logprobs, constraints=None):
    """
    fwd_obs_logprobs - L x T x bsz x K, obs probs starting at t
    trans_logprobs - T-1 x bsz x K x K, trans_logprobs[t] = p(q_{t+1} | q_t)
    len_logprobs - list of L tensors, where len_logprobs[l] is (l+1) x K length logprobs
    returns beta, beta_star: length T+1 lists of bsz x K log potentials (1-indexed)
    """
    neginf = -1e38  # -float("inf")
    L, seqlen, bsz, K = fwd_obs_logprobs.size()

    # we'll be 1-indexed for alphas and betas
    beta = [None] * (seqlen + 1)
    beta_star = [None] * (seqlen + 1)
    beta[seqlen] = Variable(trans_logprobs.data.new(bsz, K).zero_())
    mask = trans_logprobs.data.new(L, bsz, K)

    for t in range(1, seqlen + 1):
        steps_fwd = min(L, t)
        len_terms = len_logprobs[min(L - 1, steps_fwd - 1)]  # steps_fwd x K

        if constraints is not None and constraints[seqlen - t + 1] is not None:
            tmask = mask.narrow(0, 0, steps_fwd).zero_()
            # steps_fwd x bsz x K -> steps_fwd*bsz x K
            tmask.view(-1, K).index_fill_(0, constraints[seqlen - t + 1], neginf)

        # beta*_t(k) = log \sum_l beta_{t+l}(k) p(x_{t+1:t+l}) p(l_t)
        beta_star_terms = (
            torch.stack(beta[seqlen - t + 1:seqlen - t + 1 + steps_fwd])  # steps_fwd x bsz x K
            + fwd_obs_logprobs[:steps_fwd, seqlen - t]                    # steps_fwd x bsz x K
            + len_terms.unsqueeze(1).expand(steps_fwd, bsz, K))

        if constraints is not None and constraints[seqlen - t + 1] is not None:
            beta_star_terms = beta_star_terms + Variable(tmask)

        beta_star[seqlen - t] = logsumexp0(beta_star_terms)

        if seqlen - t > 0:
            # beta_t(j) = log \sum_k beta*_t(k) p(q_{t+1}=k | q_t=j)
            betastar_nt = beta_star[seqlen - t]  # bsz x K
            # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j
            tps = trans_logprobs[seqlen - t - 1]  # N.B. trans_logprobs[t] is p(q_{t+1}) and 0-idxed
            beta_terms = betastar_nt.unsqueeze(1).expand(bsz, K, K) + tps  # bsz x K x K
            beta[seqlen - t] = logsumexp2(beta_terms)  # bsz x K

    return beta, beta_star

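# A minimal usage sketch, not part of the original module: run _just_bwd on random
# stand-in potentials. The len_logprobs construction mirrors the shape the recursion
# expects (len_logprobs[l] is (l+1) x K, log-normalized over lengths for each state);
# the helper name _demo_just_bwd and all shapes are illustrative. By the recursion
# above, beta_star[0][:, k] accumulates log p(x_{1:T} | q_1 = k), so adding pi and
# summing out k yields the backward log normalizer.
def _demo_just_bwd(L=3, seqlen=5, bsz=2, K=4):
    pi = torch.randn(bsz, K).log_softmax(dim=1)
    trans_logprobs = torch.randn(seqlen - 1, bsz, K, K).log_softmax(dim=3)
    fwd_obs_logprobs = torch.randn(L, seqlen, bsz, K)  # stand-in segment scores
    len_logprobs = [torch.randn(l, K).log_softmax(dim=0) for l in range(1, L + 1)]
    beta, beta_star = _just_bwd(trans_logprobs, fwd_obs_logprobs, len_logprobs)
    return torch.logsumexp(pi + beta_star[0], dim=1)  # bsz
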
def fwd_bwd(pi, trans_logprobs, bwd_obs_logprobs, fwd_obs_logprobs, just_alphas=False):
    """
    pi - bsz x K
    bwd_obs_logprobs - L x T x bsz x K, obs probs ending at t
    fwd_obs_logprobs - L x T x bsz x K, obs probs starting at t
    trans_logprobs - T-1 x bsz x K x K, trans_logprobs[t] = p(q_{t+1} | q_t)
    returns (T+1) x bsz x K tensors alph, alph_star, beta, beta_star
    (betas are None if just_alphas)
    """
    L, seqlen, bsz, K = fwd_obs_logprobs.size()
    # we'll be 1-indexed for alphas and betas
    alph = trans_logprobs.new(seqlen + 1, bsz, K).fill_(-float("inf"))
    alph_star = trans_logprobs.new(seqlen + 1, bsz, K).fill_(-float("inf"))
    alph_star[0].copy_(pi)

    # store max possible length of a segment generated from t (used as a uniform length dist)
    bwd_maxlens = trans_logprobs.new(seqlen).fill_(L)
    bwd_maxlens[-L:].copy_(torch.arange(L, 0, -1))
    bwd_maxlens = bwd_maxlens.log_().view(seqlen, 1, 1)

    if not just_alphas:
        beta = trans_logprobs.new(seqlen + 1, bsz, K).fill_(-float("inf"))
        beta_star = trans_logprobs.new(seqlen + 1, bsz, K).fill_(-float("inf"))
        beta[seqlen].fill_(0)
    else:
        beta, beta_star = None, None

    for t in range(1, seqlen + 1):
        steps_back = min(L, t)

        # alph_t(j) = log \sum_l p(x_{t-l+1:t}) alph*_{t-l}(j) p(l_t)
        alph_terms = (alph_star[t - steps_back:t]              # steps_back x bsz x K
                      + bwd_obs_logprobs[-steps_back:, t - 1])  # steps_back x bsz x K (0-idx)
        alph_terms.sub_(bwd_maxlens[t - steps_back:t].expand_as(alph_terms))  # steps_back x bsz x K
        alph[t] = logsumexp0(alph_terms)  # bsz x K

        if t < seqlen:
            # alph*_t(k) = log \sum_j alph_t(j) p(q_{t+1}=k | q_t=j)
            # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j
            tps = trans_logprobs[t - 1]  # N.B. trans_logprobs[t] is p(q_{t+1}) and 0-indexed
            alph_t = alph[t]  # bsz x K, viz., p(x, j)
            alph_star_terms = (tps.transpose(0, 1)  # K x bsz x K
                               + alph_t.unsqueeze(2).expand(bsz, K, K).transpose(0, 1))
            alph_star[t] = logsumexp0(alph_star_terms)

        if not just_alphas:
            # beta*_t(k) = log \sum_l beta_{t+l}(k) p(x_{t+1:t+l}) p(l_t)
            beta_star_terms = (
                beta[seqlen - t + 1:seqlen - t + 1 + steps_back]  # steps_back x bsz x K
                + fwd_obs_logprobs[:steps_back, seqlen - t]       # steps_back x bsz x K
            ).sub_(math.log(steps_back))                          # uniform length distribution
            beta_star[seqlen - t] = logsumexp0(beta_star_terms)

            if seqlen - t > 0:
                # beta_t(j) = log \sum_k beta*_t(k) p(q_{t+1}=k | q_t=j)
                betastar_nt = beta_star[seqlen - t]  # bsz x K
                # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j
                tps = trans_logprobs[seqlen - t - 1]  # N.B. trans_logprobs[t] is p(q_{t+1}) and 0-idxed
                beta_terms = betastar_nt.unsqueeze(1).expand(bsz, K, K) + tps  # bsz x K x K
                beta[seqlen - t] = logsumexp2(beta_terms)  # bsz x K

    return alph, alph_star, beta, beta_star
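# A minimal sanity-check sketch, not part of the original module: both directions in
# fwd_bwd use the same uniform length distribution, so the log normalizer computed
# from the alphas at step T should agree (up to float error) with the one computed
# from pi and beta_star at step 0, provided bwd_obs_logprobs and fwd_obs_logprobs
# encode the same segment scores in their two layouts. The conversion loop and the
# helper name _check_fwd_bwd below are illustrative only.
def _check_fwd_bwd(L=3, seqlen=5, bsz=2, K=4):
    pi = torch.randn(bsz, K).log_softmax(dim=1)
    trans_logprobs = torch.randn(seqlen - 1, bsz, K, K).log_softmax(dim=3)
    fwd_obs = torch.randn(L, seqlen, bsz, K)  # fwd_obs[l-1, s-1]: length-l segment starting at s
    # build the "ending at t" layout: a length-l segment starting at s ends at t = s + l - 1
    # and lives at bwd_obs[L - l, t - 1]; infeasible slots are never read by the recursion.
    bwd_obs = torch.full_like(fwd_obs, -1e38)
    for l in range(1, L + 1):
        for s in range(1, seqlen - l + 2):
            bwd_obs[L - l, s + l - 2] = fwd_obs[l - 1, s - 1]
    alph, alph_star, beta, beta_star = fwd_bwd(pi, trans_logprobs, bwd_obs, fwd_obs)
    logZ_fwd = torch.logsumexp(alph[seqlen], dim=1)       # bsz, from the forward pass
    logZ_bwd = torch.logsumexp(pi + beta_star[0], dim=1)  # bsz, from the backward pass
    return logZ_fwd, logZ_bwd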