def testSimpleAttention(self):
  x = np.random.rand(5, 7, 1, 11)
  y = np.random.rand(5, 9, 1, 11)
  with self.test_session() as session:
    a = common_layers.simple_attention(
        tf.constant(x, dtype=tf.float32), tf.constant(y, dtype=tf.float32))
    session.run(tf.global_variables_initializer())
    res = session.run(a)
  self.assertEqual(res.shape, (5, 7, 1, 11))
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, train,
              bias=None):
  """Complete attention layer with preprocessing."""
  separabilities = [hparams.separability, hparams.separability]
  if hparams.separability < 0:
    separabilities = [hparams.separability - 1, hparams.separability]
  # Add a timing signal to the shifted targets and preprocess them with a
  # left-padded (causal) separable convolution block.
  targets_timed = common_layers.subseparable_conv_block(
      common_layers.add_timing_signal(targets_shifted),
      hparams.hidden_size, [((1, 1), (5, 1)), ((4, 1), (5, 1))],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separabilities=separabilities,
      name="targets_time")
  if hparams.attention_type == "transformer":
    targets_timed = tf.squeeze(targets_timed, 2)
    target_shape = tf.shape(targets_timed)
    targets_segment = tf.zeros([target_shape[0], target_shape[1]])
    # Lower-triangular bias enforces causality for decoder self-attention;
    # the encoder-decoder attention bias is all zeros.
    target_attention_bias = common_attention.attention_bias(
        targets_segment, targets_segment, lower_triangular=True)
    inputs_attention_bias = tf.zeros([
        tf.shape(inputs_encoded)[0], hparams.num_heads,
        tf.shape(targets_segment)[1],
        tf.shape(inputs_encoded)[1]
    ])
    attention_dropout = hparams.attention_dropout * tf.to_float(train)
    qv = common_attention.multihead_attention(
        targets_timed,
        None,
        target_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        attention_dropout,
        name="self_attention",
        summaries=False)
    qv = common_attention.multihead_attention(
        qv,
        inputs_encoded,
        inputs_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        attention_dropout,
        name="encdec_attention",
        summaries=False)
    return tf.expand_dims(qv, 2)
  elif hparams.attention_type == "simple":
    targets_with_attention = common_layers.simple_attention(
        targets_timed, inputs_encoded, bias=bias, summaries=False)
    return norm_fn(targets_shifted + targets_with_attention, name="attn_norm")
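# Hedged usage sketch (an assumption, not part of the original file): shows one
# way attention() above might be invoked. `hparams` is assumed to be a
# tensor2tensor-style hyperparameter object exposing hidden_size, num_heads,
# attention_dropout, attention_type, and separability, and
# common_layers.layer_norm is used here as a stand-in normalizer. Tensors
# follow the [batch, length, 1, hidden_size] layout used by the functions
# above; the helper name is hypothetical.
def _example_attention_call(targets_shifted, inputs_encoded, hparams, train):
  def norm_fn(x, name=None):
    # Assumed normalizer; attention() calls it both directly and as
    # normalizer_fn inside the convolution block, always with a `name` kwarg.
    with tf.variable_scope(name, default_name="norm"):
      return common_layers.layer_norm(x, hparams.hidden_size)
  return attention(targets_shifted, inputs_encoded, norm_fn, hparams, train)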