Пример #1
0
 def collect_label(probability):
     """Read out the two terminal lattice states for every batch element.

     Indexes ``probability`` as [time, batch, state] and merges, per sample,
     the entries at state ``2 * token_len - 1`` and state ``2 * token_len``
     taken at time step ``pred_len - 1``.

     Relies on ``pred_len``, ``token_len``, ``batch``, ``longX`` and
     ``log_sum_exp`` from the enclosing scope.
     NOTE(review): presumably the states are the last label and the trailing
     blank of a blank-interleaved label sequence -- confirm against caller.
     """
     final_step = pred_len - 1
     sample_idx = T.arange(batch).type(longX)
     at_last_odd_state = probability[final_step, sample_idx, 2 * token_len - 1]
     at_last_even_state = probability[final_step, sample_idx, 2 * token_len]
     return log_sum_exp(at_last_odd_state, at_last_even_state)
Пример #2
0
def ctc_ent_loss_log(pred, pred_len, token, token_len, blank=0):
    '''
    CTC loss together with an entropy term, via a log-domain forward pass.

    :param pred: (Time, batch, voca_size+1); presumably log-probabilities,
                 since T.log(-pred + eps) is taken below -- TODO confirm
    :param pred_len: (batch) time step pred_len-1 is read out per sample
    :param token: (batch, U) label ids
    :param token_len: (batch) states 2*token_len-1 and 2*token_len of the
                      extended label sequence are read out per sample
    :param blank: class index used for the blank slots of the extended
                  label sequence

    :out alpha: (Time, batch, 2U+1) log of ∑p(π|x)
    :out beta: (Time, batch, 2U+1)  log of -∑p(π|x)logp(π|x)
    :out H: exp(beta-alpha)+alpha, i.e. -(∑p log p)/(∑p) + log(∑p)
    :return: (H, costs), each indexed per batch element; costs = -alpha
    '''
    Time, batch = pred.size(0), pred.size(1)
    U = token.size(1)
    eps_nan = -1e8  # large negative stand-in for log(0)
    eps = 1e-8      # keeps the argument of T.log strictly positive

    # Extended label sequence: blank, l1, blank, l2, ..., lU, blank.
    # The blank slots carry `blank` itself (not a hard-coded 0) so that the
    # gather on `pred` and the T.ne(..., blank) test below agree for any
    # blank index.
    blank_col = T.zeros(batch, U, 1).type(longX) + blank
    token_with_blank = T.cat((blank_col, token[:, :, None]), dim=2).view(batch, -1)  # (batch, 2U)
    trailing_blank = T.zeros(batch, 1).type(longX) + blank
    token_with_blank = T.cat((token_with_blank, trailing_blank), dim=1)  # (batch, 2U+1)
    length = token_with_blank.size(1)

    # Keep only the scores of the symbols on the extended sequence.
    pred = pred[T.arange(0, Time).type(longX)[:, None, None], T.arange(0, batch).type(longX)[None, :, None], token_with_blank[None, :]]  # (T, batch, 2U+1)

    # recurrence relation
    # sec_diag is 1 where a state may be entered from two states back:
    # the state is not blank and differs from the symbol two positions earlier.
    sec_diag = T.cat((T.zeros((batch, 2)).type(floatX), T.ne(token_with_blank[:, :-2], token_with_blank[:, 2:]).type(floatX)), dim=1) * T.ne(token_with_blank, blank).type(floatX)  # (batch, 2U+1)
    # NOTE(review): assumes m_eye(n, k) is a 0/1 matrix with ones on the k-th
    # diagonal -- confirm against its definition.
    recurrence_relation = (m_eye(length) + m_eye(length, k=1)).repeat(batch, 1, 1) + m_eye(length, k=2).repeat(batch, 1, 1) * sec_diag[:, None, :]  # (batch, 2U+1, 2U+1)
    # Map 1 -> 0 and 0 -> eps_nan so the mask can be added in the log domain.
    recurrence_relation = eps_nan * (T.ones_like(recurrence_relation) - recurrence_relation)

    # alpha
    # At t=0 only the first two states are reachable; the rest start at eps_nan.
    alpha_t = T.cat((pred[0, :, :2], T.ones(batch, 2*U-1).type(floatX)*eps_nan), dim=1)  # (batch, 2U+1)
    # pred + log(-pred) is the log of -p*log(p) when pred = log(p).
    beta_t = T.cat((pred[0, :, :2] + T.log(-pred[0, :, :2]+eps),
                    T.ones(batch, 2*U-1).type(floatX)*eps_nan), dim=1)  # (batch, 2U+1)

    # dynamic programming
    # Per-step results are gathered in lists and stacked once at the end;
    # growing a tensor with T.cat on every step copies O(Time^2) elements.
    alpha_steps = [alpha_t]
    beta_steps = [beta_t]
    for t in range(1, Time):
        alpha_t = log_batch_dot(alpha_t, recurrence_relation) + pred[t]
        beta_t = log_sum_exp(log_batch_dot(beta_t, recurrence_relation) + pred[t], T.log(-pred[t]+eps) + alpha_t)

        alpha_steps.append(alpha_t)
        beta_steps.append(beta_t)

    alphas = T.stack(alpha_steps, dim=0)  # (T, batch, 2U+1)
    betas = T.stack(beta_steps, dim=0)    # (T, batch, 2U+1)

    def collect_label(probability):
        # Merge the two terminal states (2*token_len-1 and 2*token_len)
        # at each sample's time step pred_len-1.
        labels_2 = probability[pred_len-1, T.arange(batch).type(longX), 2*token_len-1]
        labels_1 = probability[pred_len-1, T.arange(batch).type(longX), 2*token_len]
        labels_prob = log_sum_exp(labels_2, labels_1)
        return labels_prob

    alpha = collect_label(alphas)
    beta = collect_label(betas)

    H = T.exp(beta-alpha) + alpha
    costs = -alpha
    return H, costs
Пример #3
0
def seg_ctc_ent_loss_log(pred,
                         pred_len,
                         token,
                         token_len,
                         uniform_mask,
                         blank=0):
    '''
    Segment-wise CTC loss with an entropy term, computed in the log domain.

    Author's description of the two tables:
    alpha(t, b, i) is the probability of ending with output token i by time t
    beta(t, b, j, 2) is the probability of emitting only token j, from time t
    up to the current step

    :param pred: (Time, batch, voca_size+1); presumably log-probabilities,
                 since T.log(-pred + eps) is taken below -- TODO confirm
    :param pred_len: (batch,) time step pred_len-1 is read out per sample
    :param token: (batch, U) label ids
    :param token_len: (batch) label position token_len-1 is read out per sample
    :param uniform_mask: indexed as [time, batch] below; entries equal to 0
                         receive an eps_nan penalty.
                         NOTE(review): exact meaning and the reason for the
                         negative time indices are not visible here -- confirm
    :param blank: accepted but never read in this function; class index 0 is
                  hard-coded as the blank

    :out alphas: (t, batch, U), grown by one time step per iteration
    :out betas: (Time, batch, U, 2); last axis is (blank, token)
    :out H: exp(lt_ent - lt) + lt
    :return: (H, costs), each (batch); costs = -lt
    '''
    # eps_nan is a large negative stand-in for log(0);
    # eps keeps the argument of T.log strictly positive.
    eps_nan = -1e8
    eps = 1e-6

    Time, batch = pred.size(0), pred.size(1)
    U = token.size(1)

    # Pair every label with class 0 (the blank): [..., 0] = 0, [..., 1] = label.
    token_with_blank = T.cat(
        (T.zeros(batch, U, 1).type(longX), token[:, :, None]),
        dim=2)  # (batch, U, 2)
    pred_blank = pred[:, :, 0]  # (Time, batch)
    # Gather, for every label position, the blank score and the label score.
    pred = pred[T.arange(0, Time).type(longX)[:, None, None, None],
                T.arange(0, batch).type(longX)[None, :, None, None],
                token_with_blank[None, :]]
    # (Time, batch, U, 2)

    # Positions where a label equals its successor; these need a separating
    # blank frame and are handled by the `have_equal` branch in the loop.
    token_equals = T.nonzero(T.eq(token[:, :-1], token[:, 1:]))  # rows of (batch index, label position) when 2-D
    # NOTE(review): the 2-D check presumably guards against an empty result
    # having a different rank in some torch versions -- confirm.
    if len(token_equals.size()) == 2:
        te_b = token_equals[:, 0]
        te_u = token_equals[:, 1]
        have_equal = True
    else:
        have_equal = False

    # Row 0 holds the segment that starts at time 0; later rows start at
    # eps_nan and are overwritten inside the loop.
    betas = T.cat(
        (pred[0, None], T.ones(Time - 1, batch, U, 2).type(floatX) * eps_nan),
        dim=0)
    # Entropy counterpart: pred + log(-pred) is the log of -p*log(p) when
    # pred = log(p).
    betas_ent = T.cat((pred[0, None] + T.log(-pred[0, None] + eps),
                       T.ones(Time - 1, batch, U, 2).type(floatX) * eps_nan),
                      dim=0)
    # (Time, batch, U, 2)
    # At time 0 only the first label can have been emitted.
    alphas = T.cat(
        (pred[0, :, 0, 1, None], T.ones(batch, U - 1).type(floatX) * eps_nan),
        dim=1)[None]
    alphas_ent = T.cat(
        (pred[0, :, 0, 1, None] + T.log(-pred[0, :, 0, 1, None] + eps),
         T.ones(batch, U - 1).type(floatX) * eps_nan),
        dim=1)[None]
    # (1, batch, U)

    batch_range = T.arange(0, batch).type(longX)
    # Adds eps_nan wherever uniform_mask[0] is 0.
    labels = alphas[-1][batch_range, token_len-1][None].clone() + \
            (1-uniform_mask[0].type(floatX))*eps_nan# score of having emitted the last label so far
    labels_ent = alphas_ent[-1][batch_range, token_len-1][None].clone() + \
            (1-uniform_mask[0].type(floatX))*eps_nan# entropy counterpart of the line above
    # (1, batch)

    # t is a 0-dim tensor here (hence t.item() below), not a Python int.
    for t in T.arange(1, Time).type(longX):
        # Extend every segment started before t by frame t (in place):
        # channel 0 continues from channel 0 only, channel 1 continues from
        # the log_sum_exp of channels 0 and 1.
        betas[:t] = T.cat((betas[:t, :, :, 0, None], \
                log_sum_exp(betas[:t, :, :, 0, None], betas[:t, :, :, 1, None])), dim=-1) \
                + pred[t, None]
        # A new segment starts at frame t.
        betas[t] = pred[t]
        # Same extension for the entropy table, plus the term contributed by
        # frame t itself (uses the already-updated betas).
        betas_ent[:t] = log_sum_exp(T.cat((betas_ent[:t, :, :, 0, None], \
                          log_sum_exp(betas_ent[:t, :, :, 0, None], \
                            betas_ent[:t, :, :, 1, None])), dim=-1) \
                            + pred[t, None], \
                          betas[:t].clone() + T.log(-pred[t, None] + eps))
        betas_ent[t] = pred[t] + T.log(-pred[t] + eps)

        # Column 0: the first label's segment running from time 0.
        # Columns 1..U-1: alpha of the previous label at time s combined with
        # the current label's segment starting at s+1, reduced over s by
        # log_sum_exp_axis (which also receives the mask).
        alphas_t = T.cat((betas[0, :, 0, 1][:, None].clone() + \
                (1-uniform_mask[-t, :, None].type(floatX)) * eps_nan, \
                log_sum_exp_axis(alphas[:, :, :-1] + betas[1:t+1, :, 1:, -1].clone(), \
                    uniform_mask[-t:, :, None].expand(t.item(), batch, U-1), dim=0)), dim=1)
        alphas_t_ent = T.cat((betas_ent[0, :, 0, 1][:, None].clone() + \
                (1-uniform_mask[-t, :, None].type(floatX)) * eps_nan, \
                log_sum_exp_axis(log_sum_exp(alphas_ent[:, :, :-1] + betas[1:t+1, :, 1:, -1].clone(), \
                    alphas[:, :, :-1] + betas_ent[1:t+1, :, 1:, -1].clone()), \
                    uniform_mask[-t:, :, None].expand(t.item(), batch, U-1), dim=0)), dim=1)
        # Repeated adjacent labels: overwrite those entries with a version
        # that inserts one blank frame (pred_blank) between the previous
        # alpha and the next segment; not possible before t = 2.
        if have_equal:
            alphas_t[te_b, 1+te_u] = log_sum_exp_axis(alphas[:-1][:, te_b, te_u] \
                                           + pred_blank[1:t][:, te_b] \
                                           + betas[2:t+1, :, :, -1][:, te_b, 1+te_u].clone(),
                                           uniform_mask[-t+1:][:, te_b],
                                           dim=0).clone() if t >= 2 else eps_nan
            alphas_t_ent[te_b, 1+te_u] = log_sum_exp(
                                            log_sum_exp_axis(pred_blank[1:t][:, te_b] + \
                                                log_sum_exp(alphas_ent[:-1][:, te_b, te_u] + betas[2:t+1, :, :, -1][:, te_b, 1+te_u].clone(), \
                                                 alphas[:-1][:, te_b, te_u] + betas_ent[2:t+1, :, :, -1][:, te_b, 1+te_u].clone() \
                                                ),
                                                uniform_mask[-t+1:][:, te_b],
                                                dim=0).clone(), \
                                            log_sum_exp_axis(alphas[:-1][:, te_b, te_u] \
                                                + pred_blank[1:t][:, te_b] \
                                                + T.log(-pred_blank[1:t][:, te_b] + eps) \
                                                + betas[2:t+1, :, :, -1][:, te_b, 1+te_u].clone(),
                                                uniform_mask[-t+1:][:, te_b],
                                                dim=0).clone()) if t >= 2 else eps_nan

        # Append this time step to the alpha tables.
        alphas = T.cat((alphas, alphas_t[None, :]), dim=0)
        alphas_ent = T.cat((alphas_ent, alphas_t_ent[None, :]), dim=0)
        # Final-label score at t: either stay finished and emit a blank
        # (penalised by eps_nan where uniform_mask[t] is 0), or finish the
        # last label at t.
        labels_t = log_sum_exp(
            labels[-1] + pred_blank[t] +
            (1 - uniform_mask[t].type(floatX)) * eps_nan,
            alphas[-1][batch_range, token_len - 1])
        # Entropy counterpart; log_sum_exp is called with three terms here.
        labels_t_ent = log_sum_exp(labels_ent[-1] + pred_blank[t] + (1-uniform_mask[t].type(floatX))*eps_nan, \
                labels[-1] + pred_blank[t] + T.log(-pred_blank[t] + eps) + (1-uniform_mask[t].type(floatX))*eps_nan, \
                alphas_ent[-1][batch_range, token_len-1])

        labels = T.cat((labels, labels_t[None]), dim=0).clone()
        labels_ent = T.cat((labels_ent, labels_t_ent[None]), dim=0).clone()

    # Read out each sample at its own time step pred_len - 1.
    lt = labels[pred_len - 1, batch_range]
    lt_ent = labels_ent[pred_len - 1, batch_range]

    H = T.exp(lt_ent - lt) + lt
    costs = -lt
    return H, costs  # (batch)