def attention_mask_window(attention_states_sequence_len, index, input, name=None):
  r"""AttentionMaskWindow

  Wraps the generated `_attention_mask_window` op, fixing `fill_value` to
  the most negative finite float32 so that masked positions are driven to
  ~zero probability by a subsequent softmax. (Presumably entries outside a
  window around `index` are masked — exact semantics live in the kernel.)

  Args:
    attention_states_sequence_len: A `Tensor` of type `int64`. Valid
      sequence lengths of the attention states.
    index: Window index attribute forwarded to the underlying op.
    input: A `Tensor` of type `float32`. Scores to be masked.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `float32`.
  """
  # fill_value is deliberately not exposed: the forward pass always uses the
  # most negative representable float32 (gradient re-masks with 0.0 instead).
  return gen_attention_mask_ops._attention_mask_window(
      attention_states_sequence_len=attention_states_sequence_len,
      index=index,
      input=input,
      fill_value=-np.finfo(np.float32).max,
      name=name)
def _AttentionMaskWindowGrad(op, *grad):
  """Gradient for the AttentionMaskWindow op.

  Applies the same windowed mask to the incoming gradient, but with a
  fill value of 0.0 so masked positions contribute no gradient. The
  sequence-length input is integer-valued and non-differentiable, hence
  the leading `None`.
  """
  masked_grad = gen_attention_mask_ops._attention_mask_window(
      attention_states_sequence_len=op.inputs[0],
      index=op.get_attr("index"),
      input=grad[0],
      fill_value=0.0)
  return [None, masked_grad]