import numpy as np
import tensorflow as tf

from tensorflow.python.framework import ops

# `gen_attention_mask_ops` is the generated Python wrapper for the compiled
# AttentionMask kernel; the exact import path depends on how the op library
# is built and packaged.
import gen_attention_mask_ops


def attention_mask(attention_states_sequence_len, input, name=None):
  r"""AttentionMask

  Fills positions of `input` that lie past each sequence's valid length with
  the most negative representable float32, so that a subsequent softmax
  assigns them effectively zero probability.

  Args:
    attention_states_sequence_len: A `Tensor` of type `int64`. Valid length
      of each sequence in the batch.
    input: A `Tensor` of type `float32`. The attention logits to mask.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `float32`.
  """
  return gen_attention_mask_ops._attention_mask(
      attention_states_sequence_len=attention_states_sequence_len,
      input=input,
      fill_value=-np.finfo(np.float32).max,
      name=name)
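# The compiled kernel is not shown here. As a rough sketch of the semantics
# the wrapper above appears to rely on (an assumption, not the kernel's
# actual implementation): positions at or beyond each sequence's valid
# length are replaced with `fill_value`. `attention_mask_reference` and its
# argument names are illustrative.
def attention_mask_reference(sequence_len, logits, fill_value):
  """Plain-TensorFlow sketch of the assumed AttentionMask forward pass."""
  # True for positions within each sequence's valid length, False for padding.
  valid = tf.sequence_mask(sequence_len, maxlen=tf.shape(logits)[1])
  # Keep valid logits; overwrite padded positions with fill_value.
  return tf.where(valid, logits, tf.fill(tf.shape(logits), fill_value))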
# Assumes the op is registered under the name "AttentionMask".
@ops.RegisterGradient("AttentionMask")
def _AttentionMaskGrad(op, *grad):
  # No gradient for the sequence-length input; the gradient w.r.t. `input`
  # re-applies the mask with fill_value=0.0, zeroing the incoming gradient
  # at the positions filled in the forward pass.
  attention_mask_grad = gen_attention_mask_ops._attention_mask(
      attention_states_sequence_len=op.inputs[0],
      input=grad[0],
      fill_value=0.0)
  return [None, attention_mask_grad]
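# Quick eager-mode check (TF 2 style) that the reference sketch above zeroes
# gradients at padded positions, mirroring the fill_value=0.0 re-masking in
# _AttentionMaskGrad. The variable names are illustrative.
def _check_padded_gradients_are_zero():
  logits = tf.Variable(np.random.randn(2, 4).astype(np.float32))
  lengths = tf.constant([2, 3], dtype=tf.int64)  # valid lengths per row
  with tf.GradientTape() as tape:
    masked = attention_mask_reference(
        lengths, logits, -np.finfo(np.float32).max)
    loss = tf.reduce_sum(tf.nn.softmax(masked, axis=-1)[:, 0])
  grad = tape.gradient(loss, logits)
  # Columns past each row's valid length receive exactly zero gradient,
  # because tf.where routes no gradient to the fill branch.
  print(grad.numpy())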