def __init__(self, num_heads, multi_head_output_size, input_node_size, name=None):
    """Construct the sub-modules of the core attention network.

    Args:
        num_heads: Number of attention heads used by the per-head
            key/query/value projections.
        multi_head_output_size: Output size of each head's projection.
        input_node_size: Size of the node features; also the output size
            of the final linear layer and of the feed-forward network.
        name: Optional name forwarded to the Sonnet module base class.
    """
    super(CoreNetwork, self).__init__(name=name)
    self.num_heads = num_heads
    self.multi_head_output_size = multi_head_output_size

    # Projects the (concatenated) attention output back to the node size.
    self.output_linear = snt.Linear(output_size=input_node_size)
    # Position-wise feed-forward network.
    self.FFN = snt.nets.MLP([32, input_node_size], activate_final=False)

    # Standardise a tensor to zero mean / unit std over all elements.
    # NOTE(review): no epsilon guard — a constant input yields a zero std
    # and a division by zero; confirm inputs can never be constant.
    def _standardize(x):
        return (x - tf.reduce_mean(x)) / tf.math.reduce_std(x)

    self.normalization = _standardize

    # NOTE(review): axis=1 here, while TransformerLayer normalises over
    # axis=-1 — confirm this difference is intentional.
    self.ln1 = snt.LayerNorm(axis=1, eps=1e-6, create_scale=True, create_offset=True)
    self.ln2 = snt.LayerNorm(axis=1, eps=1e-6, create_scale=True, create_offset=True)

    # One projection per role; built via a factory so the three calls stay
    # identical. Creation order (values, keys, queries) is kept as in the
    # original so Sonnet's automatic module naming is unchanged.
    def _head_projection():
        return MultiHeadLinear(output_size=multi_head_output_size, num_heads=num_heads)

    self.v_linear = _head_projection()  # values
    self.k_linear = _head_projection()  # keys
    self.q_linear = _head_projection()  # queries
    self.self_attention = SelfAttention()
def __init__(self, num_heads, name=None):
    """Construct the layer norms and attention module of the layer.

    Args:
        num_heads: Number of attention heads (stored for use by the call).
        name: Optional name forwarded to the Sonnet module base class.
    """
    super(TransformerLayer, self).__init__(name=name)
    self.num_heads = num_heads

    # All four normalisations share the same configuration and act on the
    # trailing (feature) axis; only their names differ.
    def _layer_norm(ln_name):
        return snt.LayerNorm(
            axis=-1,
            eps=1e-6,
            create_scale=True,
            create_offset=True,
            name=ln_name,
        )

    self.ln1 = _layer_norm('layer_norm1')
    self.ln2 = _layer_norm('layer_norm2')
    self.ln_keys = _layer_norm('layer_norm_keys')
    self.ln_queries = _layer_norm('layer_norm_queries')
    self.self_attention = SelfAttention()