def __init__(self, embed_dim, num_heads, branches, d_kv=None, name=None):
    super().__init__()
    MultiheadAttention.global_count += 1
    self.instance = 0
    assert embed_dim % num_heads == 0, 'embed_dim must be divisible by num_heads'
    self.embed_dim = embed_dim
    self.num_heads = num_heads

    # Per-head projection size: embed_dim // num_heads by default, or an
    # explicit key/value dimension when d_kv is given
    if d_kv is None:
        self.inner_dim = embed_dim
        self.head_dim = embed_dim // num_heads
    else:
        self.inner_dim = d_kv * num_heads
        self.head_dim = d_kv

    # Subgraph parallelism is enabled when a nonzero branch count is given
    if branches == 0:
        self.ENABLE_SUBGRAPH = False
        self.BRANCHES = 0
    else:
        self.ENABLE_SUBGRAPH = True
        self.BRANCHES = branches

    # Module name
    self.name = name
    if not self.name:
        self.name = f'multiheadattention{MultiheadAttention.global_count}'

    # Weights for fully-connected layers
    self.query_weights = [
        lbann.Weights(initializer=lbann.GlorotNormalInitializer(),
                      name=f'{self.name}_query_matrix'),
        lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                      name=f'{self.name}_query_bias'),
    ]
    self.key_weights = [
        lbann.Weights(initializer=lbann.GlorotNormalInitializer(),
                      name=f'{self.name}_key_matrix'),
        lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                      name=f'{self.name}_key_bias'),
    ]
    self.value_weights = [
        lbann.Weights(initializer=lbann.GlorotNormalInitializer(),
                      name=f'{self.name}_value_matrix'),
        lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                      name=f'{self.name}_value_bias'),
    ]

    # Channelwise FC in subgraph: one output projection per head/branch
    self.output_weights = []
    for head in range(branches):
        self.output_weights.append([
            lbann.Weights(initializer=lbann.GlorotNormalInitializer(),
                          name=f'{self.name}_head{head}_output_matrix'),
            lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                          name=f'{self.name}_head{head}_output_bias'),
        ])
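# Illustrative usage sketch (not from the original source): how the
# subgraph-parallel constructor above might be called. The helper name
# '_example_subgraph_attention' and the argument values are assumptions;
# the class is assumed to be in scope as 'MultiheadAttention'.
def _example_subgraph_attention():
    # One branch per head: each head gets its own channelwise output projection
    attn = MultiheadAttention(embed_dim=512, num_heads=8, branches=8)
    # An explicit per-head key/value size can be requested via d_kv
    attn_kv = MultiheadAttention(embed_dim=512, num_heads=8, branches=8, d_kv=64)
    # branches=0 falls back to the non-subgraph path (ENABLE_SUBGRAPH is False)
    baseline = MultiheadAttention(embed_dim=512, num_heads=8, branches=0)
    return attn, attn_kv, baseline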
def __init__(self, embed_dim, num_heads, name=None):
    super().__init__()
    MultiheadAttention.global_count += 1
    self.instance = 0
    assert embed_dim % num_heads == 0, 'embed_dim must be divisible by num_heads'
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads

    # Module name
    self.name = name
    if not self.name:
        self.name = f'multiheadattention{MultiheadAttention.global_count}'

    # Weights for fully-connected layers
    self.query_weights = [
        lbann.Weights(initializer=lbann.GlorotNormalInitializer(),
                      name=f'{self.name}_query_matrix'),
        lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                      name=f'{self.name}_query_bias'),
    ]
    self.key_weights = [
        lbann.Weights(initializer=lbann.GlorotNormalInitializer(),
                      name=f'{self.name}_key_matrix'),
        lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                      name=f'{self.name}_key_bias'),
    ]
    self.value_weights = [
        lbann.Weights(initializer=lbann.GlorotNormalInitializer(),
                      name=f'{self.name}_value_matrix'),
        lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                      name=f'{self.name}_value_bias'),
    ]
    self.output_weights = [
        lbann.Weights(initializer=lbann.GlorotNormalInitializer(),
                      name=f'{self.name}_output_matrix'),
        lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                      name=f'{self.name}_output_bias'),
    ]
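# Illustrative sketch (assumption, not original source): the baseline
# constructor above needs only the embedding size and head count; when no
# name is passed, one is generated from the global counter
# ('multiheadattention1', 'multiheadattention2', ...).
def _example_baseline_attention():
    attn = MultiheadAttention(embed_dim=512, num_heads=8)          # auto-generated name
    named = MultiheadAttention(512, 8, name='decoder_self_attn')   # explicit name
    return attn, named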
def __init__(
    self,
    embed_dim=512,
    num_heads=8,
    feedforward_dim=2048,
    dropout=0.1,
    name=None,
):
    TransformerDecoderLayer.global_count += 1
    self.instance = 0
    self.embed_dim = embed_dim
    self.feedforward_dim = feedforward_dim
    self.dropout_prob = dropout

    # Module name
    self.name = name
    if not self.name:
        self.name = f'transformerdecoderlayer{TransformerDecoderLayer.global_count}'

    # Layer modules
    self.attention1 = lbann.modules.transformer.MultiheadAttention(
        embed_dim, num_heads, name=f'{self.name}_attention1')
    self.attention2 = lbann.modules.transformer.MultiheadAttention(
        embed_dim, num_heads, name=f'{self.name}_attention2')

    # Weights for fully-connected layers
    self.fc1_weights = [
        lbann.Weights(initializer=lbann.HeNormalInitializer(),
                      name=f'{self.name}_fc1_matrix'),
        lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                      name=f'{self.name}_fc1_bias'),
    ]
    self.fc2_weights = [
        lbann.Weights(initializer=lbann.GlorotNormalInitializer(),
                      name=f'{self.name}_fc2_matrix'),
        lbann.Weights(initializer=lbann.ConstantInitializer(value=0),
                      name=f'{self.name}_fc2_bias'),
    ]
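# Illustrative sketch (assumption, not original source): constructing a small
# stack of decoder layers with the constructor above. 'num_layers', the helper
# name, and the 'decoder{i}' naming scheme are hypothetical; the class is
# assumed to be in scope as 'TransformerDecoderLayer'.
def _example_decoder_stack(num_layers=6):
    return [
        TransformerDecoderLayer(
            embed_dim=512,
            num_heads=8,
            feedforward_dim=2048,
            dropout=0.1,
            name=f'decoder{i}',
        )
        for i in range(num_layers)
    ]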