示例#1
0
def extrac_subject(inputs):
    """Return the subject representation selected by ``subject_ids``.

    ``inputs`` is a pair ``(output, subject_ids)``. The first column of
    ``subject_ids`` is used as the subject start position and the remaining
    column(s) as the end position; the ids are cast to int32 before use.
    The start and end vectors gathered from ``output`` are concatenated on
    axis 2, and index 0 of axis 1 is returned.
    """
    seq_repr, span_ids = inputs
    span_ids = K.cast(span_ids, 'int32')
    head_vec = batch_gather(seq_repr, span_ids[:, :1])
    tail_vec = batch_gather(seq_repr, span_ids[:, 1:])
    return K.concatenate([head_vec, tail_vec], 2)[:, 0]
def extrac_subject(inputs):
    """Pull the subject's start/end vectors out of ``output``.

    Args:
        inputs: ``(output, subject_ids)``. The first column of
            ``subject_ids`` is treated as the subject start position and the
            remaining column(s) as the end position. Example shapes from a
            debug trace: ``subject_ids`` (batch, 2); each gathered vector
            (batch, 1, 768).

    Returns:
        The start and end vectors concatenated on axis 2 (e.g. 1536 wide for
        768-dim vectors), with index 0 taken on axis 1.
    """
    seq_out, ids = inputs
    ids = K.cast(ids, 'int32')  # gather indices as int32
    bounds = (ids[:, :1], ids[:, 1:])
    start_vec, end_vec = (batch_gather(seq_out, b) for b in bounds)
    subject = K.concatenate([start_vec, end_vec], axis=2)
    return subject[:, 0]
示例#3
0
 def compute_seq2seq_loss(self, inputs, mask=None):
     """Masked top-k ("sparse") cross-entropy loss for seq2seq prediction.

     Args:
         inputs: 5-item sequence ``(y_true, y_mask, _, y_pred, _)``; the 3rd
             and 5th items are ignored. ``y_true`` holds token ids,
             ``y_mask`` holds segment_ids marking the part to predict, and
             ``y_pred`` holds per-token scores. ``y_pred`` is presumably
             (batch, seq_len, vocab) unnormalized logits — TODO confirm.
         mask: unused; kept for interface compatibility.

     Returns:
         Scalar loss: per-position ``logsumexp(top-k scores) - score(true
         token)``, averaged over masked-in positions. ``k_sparse`` is read
         from the enclosing scope; it is not defined in this method.
     """
     y_true, y_mask, _, y_pred, _ = inputs
     y_true = y_true[:, 1:]  # target token ids
     # A step counts only when both the current and the next position lie in
     # the target segment (segment_ids indicate exactly what to predict).
     y_mask = y_mask[:, :-1] * y_mask[:, 1:]
     # Ensure the mask is float so it multiplies cleanly with the float loss.
     y_mask = K.cast(y_mask, K.floatx())
     y_pred = y_pred[:, :-1]  # predictions, offset by one step to align with y_true
     # Positive term: score assigned to the true next token.
     pos_loss = batch_gather(y_pred, y_true[..., None])[..., 0]
     # Negative term: log-sum-exp over only the k_sparse highest scores.
     y_pred = tf.nn.top_k(y_pred, k=k_sparse)[0]
     neg_loss = K.logsumexp(y_pred, axis=-1)
     # Total loss, averaged over masked-in positions. The denominator is
     # floored at epsilon so a batch with no target positions gives 0
     # instead of NaN (0 / 0); for a 0/1 mask with any 1s this is unchanged.
     loss = neg_loss - pos_loss
     loss = K.sum(loss * y_mask) / K.maximum(K.sum(y_mask), K.epsilon())
     return loss
 def call(self, input):
     """Return gradient-times-input for the entries of ``output`` picked by ``label``.

     Args:
         input: 3-item sequence ``(input, output, label)``. The first element
             is the tensor to differentiate with respect to. ``label`` gives
             the indices passed to ``batch_gather`` to select entries of
             ``output``.

     Returns:
         Gradient of the selected entries with respect to the first element,
         multiplied element-wise by that element.
     """
     x, output, label = input
     # Gather indices must be integers; cast so float-typed label tensors
     # work.
     label = K.cast(label, 'int32')
     selected = batch_gather(output, label)
     return K.gradients(selected, [x])[0] * x
 def call(self, input):
     """Gradient-times-input attribution for entries chosen by ``label``.

     ``input`` unpacks to ``(x, output, label)``. ``label`` is cast to int32
     and used with ``batch_gather`` to pick entries of ``output``. The result
     is the gradient of those entries with respect to ``x``, multiplied
     element-wise by ``x``.
     """
     x, scores, label = input
     picked = batch_gather(scores, K.cast(label, 'int32'))
     grad_x = K.gradients(picked, [x])[0]
     return grad_x * x
示例#6
0
def extract_subject(inputs):
    """Gather the subject's vector representation from ``output``.

    Args:
        inputs: pair ``(output, subject_ids)``. ``subject_ids[:, :1]`` holds
            the subject start position and ``subject_ids[:, 1:]`` the end
            position. The shape is presumably (batch, 2) — confirm against
            the caller.

    Returns:
        The start and end vectors concatenated along axis 2, with index 0
        taken on axis 1. If ``output`` is (batch, seq_len, hidden), this
        should be (batch, 2 * hidden) — TODO confirm.
    """
    output, subject_ids = inputs
    # Ids used as gather indices must be integers; cast in case they arrive
    # as a float tensor (e.g. from a default-dtype Input).
    subject_ids = K.cast(subject_ids, 'int32')
    start = batch_gather(output, subject_ids[:, :1])
    end = batch_gather(output, subject_ids[:, 1:])
    subject = K.concatenate([start, end], 2)
    return subject[:, 0]