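# The helpers used below (floatX/longX, m_eye, log_sum_exp, log_sum_exp_axis,
# log_batch_dot) live elsewhere in this repo. The definitions here are minimal
# sketches, written only so this file reads as self-contained; prefer the
# repo's own implementations.
import torch as T

use_cuda = T.cuda.is_available()
floatX = 'torch.cuda.FloatTensor' if use_cuda else 'torch.FloatTensor'
longX = 'torch.cuda.LongTensor' if use_cuda else 'torch.LongTensor'


def m_eye(n, k=0):
    # Square matrix with ones on the k-th upper diagonal, zeros elsewhere.
    return T.diag(T.ones(n - k), diagonal=k).type(floatX)


def log_sum_exp(*args):
    # Elementwise, numerically stable log(exp(a) + exp(b) + ...).
    stacked = T.stack(args, dim=0)
    m, _ = T.max(stacked, dim=0)
    return m + T.log(T.sum(T.exp(stacked - m[None]), dim=0))


def log_sum_exp_axis(a, uniform_mask=None, dim=0):
    # Numerically stable log-sum-exp along `dim`; positions where
    # uniform_mask == 0 are pushed to (near) log-zero before the reduction.
    eps_nan = -1e8
    if uniform_mask is not None:
        a = a + (1 - uniform_mask.type(floatX)) * eps_nan
    m, _ = T.max(a, dim=dim, keepdim=True)
    return m.squeeze(dim) + T.log(T.sum(T.exp(a - m), dim=dim))


def log_batch_dot(vec, mat):
    # Log-space batched vector-matrix product:
    # out[b, j] = logsumexp_i(vec[b, i] + mat[b, i, j]).
    return log_sum_exp_axis(vec[:, :, None] + mat, dim=1)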
def ctc_ent_loss_log(pred, pred_len, token, token_len, blank=0):
    '''
    :param pred: (Time, batch, voca_size+1)
    :param pred_len: (batch)
    :param token: (batch, U)
    :param token_len: (batch)
    :out alpha: (Time, batch, 2U+1) ∑p(π|x)
    :out beta: (Time, batch, 2U+1) ∑p(π|x)logp(π|x)
    :out H: -beta/alpha+log(alpha)
    '''
    Time, batch = pred.size(0), pred.size(1)
    U = token.size(1)
    eps_nan = -1e8  # stands in for log(0)
    eps = 1e-8

    # interleave blanks: l' = (blank, l1, blank, l2, ..., blank, lU, blank)
    token_with_blank = T.cat((T.zeros(batch, U, 1).type(longX), token[:, :, None]), dim=2).view(batch, -1)  # (batch, 2U)
    token_with_blank = T.cat((token_with_blank, T.zeros(batch, 1).type(longX)), dim=1)  # (batch, 2U+1)
    length = token_with_blank.size(1)

    # gather the per-step log-probabilities of the extended label sequence
    pred = pred[T.arange(0, Time).type(longX)[:, None, None],
                T.arange(0, batch).type(longX)[None, :, None],
                token_with_blank[None, :]]  # (Time, batch, 2U+1)

    # recurrence relation: transitions i->i and i->i+1 are always allowed;
    # i->i+2 only when skipping the blank between two distinct labels
    sec_diag = T.cat((T.zeros((batch, 2)).type(floatX),
                      T.ne(token_with_blank[:, :-2], token_with_blank[:, 2:]).type(floatX)),
                     dim=1) * T.ne(token_with_blank, blank).type(floatX)  # (batch, 2U+1)
    recurrence_relation = (m_eye(length) + m_eye(length, k=1)).repeat(batch, 1, 1) \
        + m_eye(length, k=2).repeat(batch, 1, 1) * sec_diag[:, None, :]  # (batch, 2U+1, 2U+1)
    recurrence_relation = eps_nan * (T.ones_like(recurrence_relation) - recurrence_relation)

    # alpha: initialization at t = 0
    alpha_t = T.cat((pred[0, :, :2], T.ones(batch, 2*U-1).type(floatX)*eps_nan), dim=1)  # (batch, 2U+1)
    beta_t = T.cat((pred[0, :, :2] + T.log(-pred[0, :, :2]+eps),
                    T.ones(batch, 2*U-1).type(floatX)*eps_nan), dim=1)  # (batch, 2U+1)
    alphas = alpha_t[None]  # (1, batch, 2U+1)
    betas = beta_t[None]  # (1, batch, 2U+1)

    # dynamic programming over time
    # (Time, batch, 2U+1)
    for t in T.arange(1, Time).type(longX):
        alpha_t = log_batch_dot(alpha_t, recurrence_relation) + pred[t]
        beta_t = log_sum_exp(log_batch_dot(beta_t, recurrence_relation) + pred[t],
                             T.log(-pred[t]+eps) + alpha_t)

        alphas = T.cat((alphas, alpha_t[None]), dim=0)
        betas = T.cat((betas, beta_t[None]), dim=0)

    def collect_label(probability):
        # sum over the two valid end states: the last label and the trailing blank
        labels_2 = probability[pred_len-1, T.arange(batch).type(longX), 2*token_len-1]
        labels_1 = probability[pred_len-1, T.arange(batch).type(longX), 2*token_len]
        labels_prob = log_sum_exp(labels_2, labels_1)
        return labels_prob

    alpha = collect_label(alphas)
    beta = collect_label(betas)

    H = T.exp(beta-alpha) + alpha
    costs = -alpha
    return H, costs
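# A minimal usage sketch with hypothetical shapes; the entropy weight 0.1 is
# illustrative, not a value prescribed by this repo. `pred` must hold
# log-probabilities (e.g. log_softmax output) with the blank at index 0;
# `costs` is the per-sample negative log-likelihood and `H` the entropy term
# that entropy-regularized CTC training maximizes alongside the likelihood.
def _demo_ctc_ent():
    Time, batch, voca_size, U = 8, 2, 5, 3
    pred = T.nn.functional.log_softmax(T.randn(Time, batch, voca_size + 1), dim=-1)
    pred_len = (T.ones(batch) * Time).type(longX)
    token = T.randint(1, voca_size + 1, (batch, U)).type(longX)
    token_len = (T.ones(batch) * U).type(longX)

    H, costs = ctc_ent_loss_log(pred, pred_len, token, token_len)
    return (costs - 0.1 * H).mean()  # illustrative entropy weight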
def seg_ctc_ent_loss_log(pred, pred_len, token, token_len, uniform_mask, blank=0):
    '''
    alpha(t, b, i) is the probability of ending with output token i up to time t;
    beta(t, b, j, 2) is the probability of outputting only token j from time t to now.
    :param pred: (Time, batch, voca_size+1)
    :param pred_len: (batch,)
    :param token: (batch, U=token_len)
    :param token_len: (batch)
    :param blank: 0
    :out alpha: (Time, batch, 2U+1) ∑p(π|x)
    :out beta: (Time, batch, 2U+1) ∑p(π|x)logp(π|x)
    :out H: -beta/alpha+log(alpha)
    :return: cost
    '''
    eps_nan = -1e8  # stands in for log(0)
    eps = 1e-6
    Time, batch = pred.size(0), pred.size(1)
    U = token.size(1)

    # pair each label with a preceding blank
    token_with_blank = T.cat((T.zeros(batch, U, 1).type(longX), token[:, :, None]), dim=2)  # (batch, U, 2)

    pred_blank = pred[:, :, 0]  # (Time, batch)
    pred = pred[T.arange(0, Time).type(longX)[:, None, None, None],
                T.arange(0, batch).type(longX)[None, :, None, None],
                token_with_blank[None, :]]  # (Time, batch, U, 2)

    # consecutive equal labels must be separated by an explicit blank;
    # those positions are patched separately inside the loop below
    token_equals = T.nonzero(T.eq(token[:, :-1], token[:, 1:]))  # (n_equal, 2)
    if len(token_equals.size()) == 2:
        te_b = token_equals[:, 0]
        te_u = token_equals[:, 1]
        have_equal = True
    else:
        have_equal = False

    betas = T.cat((pred[0, None],
                   T.ones(Time-1, batch, U, 2).type(floatX) * eps_nan), dim=0)
    betas_ent = T.cat((pred[0, None] + T.log(-pred[0, None] + eps),
                       T.ones(Time-1, batch, U, 2).type(floatX) * eps_nan), dim=0)  # (Time, batch, U, 2)

    alphas = T.cat((pred[0, :, 0, 1, None],
                    T.ones(batch, U-1).type(floatX) * eps_nan), dim=1)[None]
    alphas_ent = T.cat((pred[0, :, 0, 1, None] + T.log(-pred[0, :, 0, 1, None] + eps),
                        T.ones(batch, U-1).type(floatX) * eps_nan), dim=1)[None]  # (1, batch, U)

    batch_range = T.arange(0, batch).type(longX)
    # probability of having emitted the last token by now: (1, batch)
    labels = alphas[-1][batch_range, token_len-1][None].clone() + \
        (1 - uniform_mask[0].type(floatX)) * eps_nan
    labels_ent = alphas_ent[-1][batch_range, token_len-1][None].clone() + \
        (1 - uniform_mask[0].type(floatX)) * eps_nan

    for t in T.arange(1, Time).type(longX):
        # extend every open segment by one frame
        betas[:t] = T.cat((betas[:t, :, :, 0, None],
                           log_sum_exp(betas[:t, :, :, 0, None], betas[:t, :, :, 1, None])),
                          dim=-1) + pred[t, None]
        betas[t] = pred[t]
        betas_ent[:t] = log_sum_exp(
            T.cat((betas_ent[:t, :, :, 0, None],
                   log_sum_exp(betas_ent[:t, :, :, 0, None], betas_ent[:t, :, :, 1, None])),
                  dim=-1) + pred[t, None],
            betas[:t].clone() + T.log(-pred[t, None] + eps))
        betas_ent[t] = pred[t] + T.log(-pred[t] + eps)

        alphas_t = T.cat(
            (betas[0, :, 0, 1][:, None].clone() + (1 - uniform_mask[-t, :, None].type(floatX)) * eps_nan,
             log_sum_exp_axis(alphas[:, :, :-1] + betas[1:t+1, :, 1:, -1].clone(),
                              uniform_mask[-t:, :, None].expand(t.item(), batch, U-1), dim=0)),
            dim=1)
        alphas_t_ent = T.cat(
            (betas_ent[0, :, 0, 1][:, None].clone() + (1 - uniform_mask[-t, :, None].type(floatX)) * eps_nan,
             log_sum_exp_axis(log_sum_exp(alphas_ent[:, :, :-1] + betas[1:t+1, :, 1:, -1].clone(),
                                          alphas[:, :, :-1] + betas_ent[1:t+1, :, 1:, -1].clone()),
                              uniform_mask[-t:, :, None].expand(t.item(), batch, U-1), dim=0)),
            dim=1)

        if have_equal:
            # repeated labels: force a blank between the two equal tokens
            alphas_t[te_b, 1+te_u] = log_sum_exp_axis(
                alphas[:-1][:, te_b, te_u]
                + pred_blank[1:t][:, te_b]
                + betas[2:t+1, :, :, -1][:, te_b, 1+te_u].clone(),
                uniform_mask[-t+1:][:, te_b], dim=0).clone() if t >= 2 else eps_nan
            alphas_t_ent[te_b, 1+te_u] = log_sum_exp(
                log_sum_exp_axis(
                    pred_blank[1:t][:, te_b]
                    + log_sum_exp(alphas_ent[:-1][:, te_b, te_u] + betas[2:t+1, :, :, -1][:, te_b, 1+te_u].clone(),
                                  alphas[:-1][:, te_b, te_u] + betas_ent[2:t+1, :, :, -1][:, te_b, 1+te_u].clone()),
                    uniform_mask[-t+1:][:, te_b], dim=0).clone(),
                log_sum_exp_axis(
                    alphas[:-1][:, te_b, te_u]
                    + pred_blank[1:t][:, te_b]
                    + T.log(-pred_blank[1:t][:, te_b] + eps)
                    + betas[2:t+1, :, :, -1][:, te_b, 1+te_u].clone(),
                    uniform_mask[-t+1:][:, te_b], dim=0).clone()) if t >= 2 else eps_nan

        alphas = T.cat((alphas, alphas_t[None, :]), dim=0)
        alphas_ent = T.cat((alphas_ent, alphas_t_ent[None, :]), dim=0)

        labels_t = log_sum_exp(
            labels[-1] + pred_blank[t] + (1 - uniform_mask[t].type(floatX)) * eps_nan,
            alphas[-1][batch_range, token_len-1])
        labels_t_ent = log_sum_exp(
            labels_ent[-1] + pred_blank[t] + (1 - uniform_mask[t].type(floatX)) * eps_nan,
            labels[-1] + pred_blank[t] + T.log(-pred_blank[t] + eps) + (1 - uniform_mask[t].type(floatX)) * eps_nan,
            alphas_ent[-1][batch_range, token_len-1])
        labels = T.cat((labels, labels_t[None]), dim=0).clone()
        labels_ent = T.cat((labels_ent, labels_t_ent[None]), dim=0).clone()

    lt = labels[pred_len-1, batch_range]
    lt_ent = labels_ent[pred_len-1, batch_range]

    H = T.exp(lt_ent - lt) + lt
    costs = -lt
    return H, costs  # (batch)
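# A minimal usage sketch for the segmental variant, with hypothetical shapes.
# The exact construction of `uniform_mask` follows training code not shown
# here; the all-ones (Time, batch) mask below is an assumption that leaves
# every time step unmasked.
def _demo_seg_ctc_ent():
    Time, batch, voca_size, U = 8, 2, 5, 3
    pred = T.nn.functional.log_softmax(T.randn(Time, batch, voca_size + 1), dim=-1)
    pred_len = (T.ones(batch) * Time).type(longX)
    token = T.randint(1, voca_size + 1, (batch, U)).type(longX)
    token_len = (T.ones(batch) * U).type(longX)
    uniform_mask = T.ones(Time, batch).type(floatX)

    H, costs = seg_ctc_ent_loss_log(pred, pred_len, token, token_len, uniform_mask)
    return (costs - 0.1 * H).mean()  # illustrative entropy weight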