# Example #1
# 0
def just_fwd(pi, trans_logprobs, bwd_obs_logprobs, constraints=None):
    """
    Run the forward (alpha) recursion in log space.

    pi               - bsz x K, used directly as alph_star[0]
    bwd_obs_logprobs - L x T x bsz x K, obs probs ending at t
    trans_logprobs   - T-1 x bsz x K x K, trans_logprobs[t] = p(q_{t+1} | q_t)
    constraints      - optional sequence indexed by the 1-based step t; a
                       non-None entry holds row indices into the flattened
                       (steps_back*bsz) x K mask, and those rows are pushed
                       to a very large negative value

    Returns (alph, alph_star): two lists of length T+1, 1-indexed by t.
    alph[0] is never filled in and alph_star[T] is never computed by the
    loop, so those slots keep their initial values.
    """
    # Large finite stand-in for -inf, used to mask constrained entries.
    neginf = -1e38  # -float("inf")
    L, seqlen, bsz, K = bwd_obs_logprobs.size()
    # we'll be 1-indexed for alphas and betas
    alph = [None] * (seqlen + 1)
    alph_star = [None] * (seqlen + 1)
    alph_star[0] = pi
    # Scratch buffer reused for the constraint mask at every step.
    mask = trans_logprobs.new(L, bsz, K)

    # store max possible length generated from t: L everywhere except the
    # tail of the sequence, where fewer steps remain.
    bwd_maxlens = trans_logprobs.new(seqlen).fill_(L)
    # Clamp the tail so a sequence shorter than L does not trigger a size
    # mismatch in copy_; for L <= seqlen this is the original behaviour.
    tail = min(L, seqlen)
    if tail > 0:
        bwd_maxlens[-tail:].copy_(torch.arange(tail, 0, -1))
    bwd_maxlens = bwd_maxlens.log_().view(seqlen, 1, 1)

    # range (not xrange) so the function runs on Python 3 as well.
    for t in range(1, seqlen + 1):
        steps_back = min(L, t)
        constrained = constraints is not None and constraints[t] is not None

        if constrained:
            tmask = mask.narrow(0, 0, steps_back).zero_()
            # steps_back x bsz x K -> steps_back*bsz x K
            tmask.view(-1, K).index_fill_(0, constraints[t], neginf)

        # alph_t(j) = log \sum_l p(x_{t-l+1:t}) alph*_{t-l} p(l_t)
        alph_terms = (
            torch.stack(alph_star[t - steps_back:t])  # steps_back x bsz x K
            + bwd_obs_logprobs[-steps_back:,
                               t - 1]  # steps_back x bsz x K (0-idx)
            - bwd_maxlens[t - steps_back:t].expand(steps_back, bsz, K))

        if constrained:
            alph_terms = alph_terms + tmask

        alph[t] = logsumexp0(alph_terms)  # bsz x K

        if t < seqlen:
            # alph*_t(k) = log \sum_j alph_t(j) p(q_{t+1}=k | q_t = j)
            # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j
            tps = trans_logprobs[
                t - 1]  # N.B. trans_logprobs[t] is p(q_{t+1}) and 0-indexed
            alph_t = alph[t]  # bsz x K, viz, p(x, j)
            # Move the source-state axis to dim 0 on both operands so that
            # logsumexp0 sums it out.
            alph_star_terms = (
                tps.transpose(0, 1)  # K x bsz x K
                + alph_t.unsqueeze(2).expand(bsz, K, K).transpose(0, 1))
            alph_star[t] = logsumexp0(alph_star_terms)

    return alph, alph_star
# Example #2
# 0
# File: infc.py  Project: opcheese/S2S_Temp
def _just_bwd(trans_logprobs,
              fwd_obs_logprobs,
              len_logprobs,
              constraints=None):
    """
    Backward (beta) recursion in log space, walking segment start
    positions from the end of the sequence back to the beginning.

    fwd_obs_logprobs - L x T x bsz x K, obs probs starting at t
    trans_logprobs   - T-1 x bsz x K x K, trans_logprobs[t] = p(q_{t+1} | q_t)
    len_logprobs     - indexed by (number of usable lengths - 1); each entry
                       is presumably steps_fwd x K -- TODO confirm with caller
    constraints      - optional sequence indexed 1..T; a non-None entry
                       holds row indices into the flattened
                       (steps_fwd*bsz) x K mask to push to a very large
                       negative value

    Returns (beta, beta_star): two lists of length T+1. beta[T] is all
    zeros; beta[0] and beta_star[T] are never computed by the loop.
    """
    big_neg = -1e38  # large finite stand-in for -inf
    max_len, n_steps, bsz, K = fwd_obs_logprobs.size()

    # 1-indexed storage for betas and beta-stars
    beta = [None for _ in range(n_steps + 1)]
    beta_star = [None for _ in range(n_steps + 1)]
    beta[n_steps] = Variable(trans_logprobs.data.new(bsz, K).zero_())
    # Scratch buffer reused for the constraint mask at every position.
    scratch = trans_logprobs.data.new(max_len, bsz, K)

    # pos is the 0-based start position; n_steps - pos steps remain after it.
    for pos in range(n_steps - 1, -1, -1):
        width = min(max_len, n_steps - pos)

        # width never exceeds max_len, so width - 1 is already a valid
        # index into the length table.
        length_lp = len_logprobs[width - 1]
        active = None if constraints is None else constraints[pos + 1]

        # beta*_t(k) = log \sum_l beta_{t+l}(k) p(x_{t+1:t+l}) p(l_t)
        terms = torch.stack(beta[pos + 1:pos + 1 + width])  # width x bsz x K
        terms = terms + fwd_obs_logprobs[:width, pos]
        terms = terms + length_lp.unsqueeze(1).expand(width, bsz, K)

        if active is not None:
            seg_mask = scratch.narrow(0, 0, width).zero_()
            # width x bsz x K -> width*bsz x K
            seg_mask.view(-1, K).index_fill_(0, active, big_neg)
            terms = terms + Variable(seg_mask)

        beta_star[pos] = logsumexp0(terms)

        if pos > 0:
            # beta_t(j) = log \sum_k beta*_t(k) p(q_{t+1} = k | q_t=j)
            # trans_logprobs is 0-indexed, hence pos - 1.
            widened = beta_star[pos].unsqueeze(1).expand(bsz, K, K)
            beta[pos] = logsumexp2(widened + trans_logprobs[pos - 1])

    return beta, beta_star
# Example #3
# 0
def fwd_bwd(pi,
            trans_logprobs,
            bwd_obs_logprobs,
            fwd_obs_logprobs,
            just_alphas=False):
    """
    Run the forward and (optionally) backward recursions in log space in a
    single pass over t.

    pi               - bsz x K, copied into alph_star[0]
    bwd_obs_logprobs - L x T x bsz x K, obs probs ending at t
    fwd_obs_logprobs - L x T x bsz x K, obs probs starting at t
    trans_logprobs   - T-1 x bsz x K x K, trans_logprobs[t] = p(q_{t+1} | q_t)
    just_alphas      - when True, skip the backward recursion

    Returns (alph, alph_star, beta, beta_star). Each computed table is a
    (T+1) x bsz x K tensor, 1-indexed by t and initialised to -inf. beta and
    beta_star are None when just_alphas is True.
    """
    L, seqlen, bsz, K = fwd_obs_logprobs.size()
    # we'll be 1-indexed for alphas and betas
    alph = trans_logprobs.new(seqlen + 1, bsz, K).fill_(-float("inf"))
    alph_star = trans_logprobs.new(seqlen + 1, bsz, K).fill_(-float("inf"))
    alph_star[0].copy_(pi)

    # store max possible length generated from t: L everywhere except the
    # tail of the sequence, where fewer steps remain.
    bwd_maxlens = trans_logprobs.new(seqlen).fill_(L)
    # Clamp the tail so a sequence shorter than L does not trigger a size
    # mismatch in copy_; for L <= seqlen this is the original behaviour.
    tail = min(L, seqlen)
    if tail > 0:
        bwd_maxlens[-tail:].copy_(torch.arange(tail, 0, -1))
    bwd_maxlens = bwd_maxlens.log_().view(seqlen, 1, 1)

    if not just_alphas:
        beta = trans_logprobs.new(seqlen + 1, bsz, K).fill_(-float("inf"))
        beta_star = trans_logprobs.new(seqlen + 1, bsz, K).fill_(-float("inf"))
        beta[seqlen].fill_(0)
    else:
        beta, beta_star = None, None

    # range (not xrange) so the function runs on Python 3 as well.
    for t in range(1, seqlen + 1):
        steps_back = min(L, t)
        # alph_t(j) = log \sum_l p(x_{t-l+1:t}) alph*_{t-l} p(l_t)
        alph_terms = (
            alph_star[t - steps_back:t]  # steps_back x bsz x K
            + bwd_obs_logprobs[-steps_back:, t - 1]
        )  # steps_back x bsz x K (0-idx)
        # The sum above is a fresh tensor, so the in-place subtraction
        # does not touch alph_star.
        alph_terms.sub_(bwd_maxlens[t - steps_back:t].expand_as(
            alph_terms))  # steps_back x bsz x K

        alph[t] = logsumexp0(alph_terms)  # bsz x K
        if t < seqlen:
            # alph*_t(k) = log \sum_j alph_t(j) p(q_{t+1}=k | q_t = j)
            # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j
            tps = trans_logprobs[
                t - 1]  # N.B. trans_logprobs[t] is p(q_{t+1}) and 0-indexed
            alph_t = alph[t]  # bsz x K, viz, p(x, j)
            # Move the source-state axis to dim 0 on both operands so that
            # logsumexp0 sums it out.
            alph_star_terms = (
                tps.transpose(0, 1)  # K x bsz x K
                + alph_t.unsqueeze(2).expand(bsz, K, K).transpose(0, 1))
            alph_star[t] = logsumexp0(alph_star_terms)

        if not just_alphas:
            # The backward pass shares the loop: step t handles start
            # position seqlen - t, from which min(L, t) lengths are usable.
            # beta*_t(k) = log \sum_l beta_{t+l}(k) p(x_{t+1:t+l}) p(l_t)
            beta_star_terms = (
                beta[seqlen - t + 1:seqlen - t + 1 +
                     steps_back]  # steps_back x bsz x K
                + fwd_obs_logprobs[:steps_back,
                                   seqlen - t]  # steps_back x bsz x K
            ).sub_(math.log(steps_back))  # steps_back x bsz x K
            beta_star[seqlen - t] = logsumexp0(beta_star_terms)
            if seqlen - t > 0:
                # beta_t(j) = log \sum_k beta*_t(k) p(q_{t+1} = k | q_t=j)
                betastar_nt = beta_star[seqlen - t]  # bsz x K
                # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j
                tps = trans_logprobs[
                    seqlen - t -
                    1]  # N.B. trans_logprobs[t] is p(q_{t+1}) and 0-idxed
                beta_terms = betastar_nt.unsqueeze(1).expand(
                    bsz, K, K) + tps  # bsz x K x K
                beta[seqlen - t] = logsumexp2(beta_terms)  # bsz x K

    return alph, alph_star, beta, beta_star