def forward(self, x, mask=None):
    """Apply Transformer encoder layer.

    Args:
        x (lbann.Layer): Sequence of input vectors.
        mask (lbann.Layer, optional): Attention mask.

    Returns:
        lbann.Layer: Sequence of output vectors.

    """
    self.instance += 1
    name = f'{self.name}_instance{self.instance}'

    # Self-attention with residual connection
    y = self.attention(x, x, x, mask=mask)
    if self.dropout_prob > 0:
        y = lbann.Dropout(
            y,
            keep_prob=1 - self.dropout_prob,
            name=f'{name}_drop1',
        )
    z = lbann.Sum(x, y, name=f'{name}_sum1')
    z = lbann.InstanceNorm(z, name=f'{name}_norm1')
    x = z

    # Feedforward network with residual connection
    y = lbann.ChannelwiseFullyConnected(
        x,
        weights=self.fc1_weights,
        output_channel_dims=[self.feedforward_dim],
        name=f'{name}_fc1',
    )
    y = lbann.Relu(y, name=f'{name}_relu1')
    if self.dropout_prob > 0:
        y = lbann.Dropout(
            y,
            keep_prob=1 - self.dropout_prob,
            name=f'{name}_drop2',
        )
    y = lbann.ChannelwiseFullyConnected(
        y,
        weights=self.fc2_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_fc2',
    )
    if self.dropout_prob > 0:
        y = lbann.Dropout(
            y,
            keep_prob=1 - self.dropout_prob,
            name=f'{name}_drop3',
        )
    z = lbann.Sum(x, y, name=f'{name}_sum2')
    z = lbann.InstanceNorm(z, name=f'{name}_norm2')
    return z
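# Usage sketch (illustrative, not from the original source; the
# TransformerEncoderLayer constructor arguments shown are assumptions):
#
#     encoder = [TransformerEncoderLayer(embed_dim=512,
#                                        num_heads=8,
#                                        feedforward_dim=2048,
#                                        dropout=0.1,
#                                        name=f'encoder{i}')
#                for i in range(num_encoder_layers)]
#     x = embedded_sequence
#     for layer in encoder:
#         x = layer.forward(x, mask=None)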
def BondEncoder(edge_feature_columns, EDGE_EMBEDDING_DIM):
    """Embeds the edge features into a vector.

    Args:
        edge_feature_columns (list(Layers)): A list of layers with edge
            features with shape (NUM_EDGES)
        EDGE_EMBEDDING_DIM (int): The embedding dimensionality of the edge
            feature vector

    Returns:
        (Layer): A layer containing the embedded edge feature matrix of
            shape (NUM_EDGES, EDGE_EMBEDDING_DIM)
    """
    # Courtesy of OGB
    bond_feature_dims = [5, 6, 2]

    _fan_in = bond_feature_dims[0]
    _fan_out = EDGE_EMBEDDING_DIM

    _embedding_weights = lbann.Weights(
        initializer=_xavier_uniform_init(_fan_in, _fan_out),
        name="bond_encoder_weights_{}".format(0))

    temp = lbann.Embedding(edge_feature_columns[0],
                           num_embeddings=bond_feature_dims[0],
                           embedding_dim=EDGE_EMBEDDING_DIM,
                           weights=_embedding_weights,
                           name="Bond_Embedding_0")

    # Embed each remaining edge feature column and accumulate the embeddings
    for i in range(1, 3):
        _fan_in = bond_feature_dims[i]
        _fan_out = EDGE_EMBEDDING_DIM
        _embedding_weights = lbann.Weights(
            initializer=_xavier_uniform_init(_fan_in, _fan_out),
            name="bond_encoder_weights_{}".format(i))
        _temp2 = lbann.Embedding(edge_feature_columns[i],
                                 num_embeddings=bond_feature_dims[i],
                                 embedding_dim=EDGE_EMBEDDING_DIM,
                                 weights=_embedding_weights,
                                 name="Bond_Embedding_{}".format(i))
        temp = lbann.Sum(temp, _temp2)
    return temp
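# Usage sketch (illustrative; splitting the edge-feature input into three
# integer index columns is assumed to happen elsewhere, e.g. with
# lbann.Slice followed by lbann.Identity):
#
#     edge_cols = [lbann.Identity(edge_feature_slice) for _ in range(3)]
#     edge_embeddings = BondEncoder(edge_cols, EDGE_EMBEDDING_DIM=100)
#
# The result has shape (NUM_EDGES, EDGE_EMBEDDING_DIM) and is typically fed
# to the edge network of a message-passing layer such as NNConv.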
def forward(self, node_features, neighbor_features, edge_features,
            edge_index):
    """Apply NNConv layer.

    Args:
        node_features (Layer): A 2D layer of node features of
            shape (num_nodes, input_channels)
        neighbor_features (Layer): A 3D layer of node features of
            shape (num_edges, 1, input_channels)
        edge_features (Layer): A 2D layer of edge features of
            shape (num_edges, edge_features)
        edge_index (Layer): A 1D layer of shape
            (num_edges * output_channels) with the indices used for
            reduction
    Returns:
        (Layer): The output after NNConv. The output layer has the shape
            (num_nodes, self.output_channels)
    """
    updated_node_fts, neighbor_vals = self.message(node_features,
                                                   neighbor_features,
                                                   edge_features)
    aggregated_fts = self.aggregate(neighbor_vals, edge_index)

    update = lbann.Sum(updated_node_fts,
                       aggregated_fts,
                       name=self.name + "_updated_node_features")

    return update
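# The update above follows the NNConv pattern: `message` applies the edge
# network to produce per-edge messages, `aggregate` scatter-reduces them onto
# nodes via edge_index, and the result is summed with the transformed self
# features. A call sketch (constructor arguments are assumptions):
#
#     nn_conv = NNConv(sequential_nn=edge_mlp, num_nodes=N, num_edges=E,
#                      input_channels=100, output_channels=100)
#     node_fts = nn_conv.forward(node_fts, neighbor_fts, edge_fts, edge_index)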
def forward(self, X, A):
    # Propagate messages along the adjacency matrix: A @ (X W2)
    messages = lbann.MatMul(X, self.W2, name=self.name + '_w2_mult')
    messages = lbann.MatMul(A, messages, name=self.name + '_adj_mult')

    # Self (identity) transform: X W1
    ident = lbann.MatMul(X, self.W1, name=self.name + '_w1_mult')

    # Combine self features and aggregated neighbor messages
    out = lbann.Sum(ident, messages, name=self.name + '_sum_id')

    return out
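# The layer above computes out = X·W1 + A·X·W2: a learned self transform plus
# messages aggregated through the dense adjacency matrix. A minimal call
# sketch (the module name and constructor are assumptions):
#
#     conv = DenseGraphConv(input_channels=64, output_channels=64, name='gc1')
#     node_fts = conv.forward(node_feature_matrix, adjacency_matrix)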
def forward(self, X, A):
    """Apply Graph Conv Layer to X and use A for message passing.

    Args:
        X (GraphVertexData): LBANN Data object, which is a collection of
            Layers. Each Layer is of the shape (1, input_channels)
        A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)

    Returns:
        GraphVertexData: The output after convolution. The output can be
            passed into another Graph Conv layer directly
    """
    # Accumulate messages from neighboring nodes
    out = X.get_mat()
    out = lbann.MatMul(out, self.weights1, name=self.name + "_Graph_MATMUL")
    message = lbann.MatMul(A, out, name=self.name + "_Graph_Message")
    message = GraphVertexData.matrix_to_graph(message, X.shape[0],
                                              self.output_channels)

    # Assume X is a GraphVertexData object
    for node_feature in range(X.shape[0]):
        X[node_feature] = lbann.MatMul(X[node_feature], self.weights2)

    for node_feature in range(X.shape[0]):
        if (self.bias):
            message[node_feature] = lbann.Sum(
                message[node_feature],
                self.bias,
                name=self.name + '_message_bias_' + str(node_feature))
        X[node_feature] = lbann.Sum(X[node_feature], message[node_feature])

    if self.activation:
        for node_feature in range(X.shape[0]):
            X[node_feature] = self.activation(X[node_feature])

    X.update_num_features(self.output_channels)
    return X
def forward(self, _):
    # Materialize the weight matrix as a (width x height) tensor
    w = lbann.WeightsLayer(weights=self.weights,
                           dims=f'{self.width} {self.height}')

    # Slice the matrix into rows along axis 0
    slice = lbann.Slice(
        w,
        axis=0,
        slice_points=' '.join(str(i) for i in range(self.width + 1)))

    # Each child of the Slice layer consumes the next row; take the
    # Euclidean norm of each row and sum them into a single scalar
    cols = []
    for _ in range(self.width):
        cols.append(lbann.Sqrt(lbann.L2Norm2(slice)))
    return lbann.Sum(cols)
def random_projection(indices, num_projections, projection_dim):

    # Expand input indices to get an index for each vector entry
    # Note: proj_indices(i) = index*projection_dim + i
    proj_indices = lbann.WeightedSum(
        indices,
        scaling_factors=utils.str_list(projection_dim),
    )
    iota = lbann.WeightsLayer(
        dims=utils.str_list(projection_dim),
        weights=lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=utils.str_list(range(projection_dim))),
            optimizer=lbann.NoOptimizer(),
        ),
    )
    proj_indices = lbann.Sum(
        lbann.Tessellate(
            lbann.Reshape(proj_indices,
                          dims=utils.str_list([num_projections, 1])),
            dims=utils.str_list([num_projections, projection_dim]),
        ),
        lbann.Tessellate(
            lbann.Reshape(iota, dims=utils.str_list([1, projection_dim])),
            dims=utils.str_list([num_projections, projection_dim]),
        ),
    )

    # Apply hash function and convert to Gaussian distribution
    proj = lbann.UniformHash(proj_indices)
    ones = lbann.Constant(
        value=1,
        num_neurons=utils.str_list([num_projections, projection_dim]),
    )
    eps = 0.001
    proj = lbann.ErfInv(
        lbann.WeightedSum(
            proj,
            ones,
            scaling_factors=utils.str_list([2 * (1 - eps), -(1 - eps)]),
        ))
    proj = lbann.InstanceNorm(proj)
    proj = lbann.WeightedSum(
        proj,
        scaling_factors=utils.str_list(1 / projection_dim),
    )
    return proj
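# Note on the hash-to-Gaussian step above (a sketch of the reasoning, not from
# the original source): UniformHash yields u ~ Uniform(0, 1) per entry, the
# WeightedSum maps it to 2*(1-eps)*u - (1-eps), which lies in
# (-(1-eps), 1-eps), and ErfInv of a Uniform(-1, 1) variable is Gaussian up to
# a factor of sqrt(2). The trailing InstanceNorm and the final rescaling by
# 1/projection_dim standardize and scale the projection entries.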
def forward(self, inputs):
    raise NotImplementedError  # Requires log-gamma function
    if len(inputs) != 2:
        raise ValueError('expected two inputs: predictions and labels')
    pred = inputs[0]
    label = inputs[1]
    count = lbann.Reduction(label)
    alpha_sum = lbann.Reduction(pred)
    lgamma_alpha_sum = lbann.Reduction(lbann.LogGamma(pred))
    lgamma_alpha_level_count_sum = lbann.Reduction(
        lbann.LogGamma(lbann.Add([pred, label])))
    return lbann.WeightedSum(
        [
            lbann.LogGamma(alpha_sum),
            lbann.LogGamma(lbann.Sum([count, alpha_sum])),
            lgamma_alpha_level_count_sum,
            lgamma_alpha_sum,
        ],
        scaling_factors='-1.0 1.0 -1.0 1.0')
def forward(self,
            node_feature_mat,
            source_indices,
            target_indices,
            activation=lbann.Relu):
    """Apply GIN Layer.

    Args:
        node_feature_mat (Layer): Node feature matrix with the shape of
            (num_nodes, input_channels)
        source_indices (Layer): Source node indices of the edges with
            shape (num_nodes)
        target_indices (Layer): Target node indices of the edges with
            shape (num_nodes)
        activation (Layer): Activation layer for the node features. If
            None, then no activation is applied. (default: lbann.Relu)
    Returns:
        (Layer): The output after kernel ops. The output can be passed
            into another Graph Conv layer directly
    """
    eps = lbann.Constant(value=(1 + self.eps),
                         num_neurons=str_list(
                             [self.num_nodes, self.input_channel_size]))
    eps_node_features = lbann.Multiply(node_feature_mat,
                                       eps,
                                       name=self.name + "_epl_mult")
    node_feature_mat = lbann.Sum(eps_node_features, node_feature_mat)

    # Transform with the sequence of linear layers
    for layer in self.nn:
        node_feature_mat = layer(node_feature_mat)

    neighborhoods = GraphExpand(node_feature_mat, target_indices)
    neighborhoods = lbann.Reshape(
        neighborhoods,
        dims=str_list([self.num_edges, self.output_channel_size]))
    aggregated_node_features = GraphReduce(
        neighborhoods, source_indices,
        [self.num_nodes, self.output_channel_size])

    # Apply activation
    if activation:
        aggregated_node_features = activation(aggregated_node_features)

    return aggregated_node_features
def forward(self, X, A, activation=lbann.Relu):
    """Apply GIN Layer.

    Args:
        X (GraphVertexData): LBANN Data object, which is a collection of
            Layers. Each Layer is of the shape (1, input_channels)
        A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)
        activation (Layer): Activation layer for the node features. If
            None, then no activation is applied. (default: lbann.Relu)
    Returns:
        (GraphVertexData): The output after GIN. The output can be passed
            into another Graph Conv layer directly
    """
    in_channel = X.shape[1]

    # Accumulate messages from neighboring nodes
    out = X.get_mat()
    out = lbann.MatMul(A, out, name=self.name + "_GIN_MATMUL")
    message = GraphVertexData.matrix_to_graph(out, X.shape[0], in_channel)

    # Aggregate messages into node features
    eps = lbann.Constant(value=(1 + self.eps),
                         num_neurons=str_list([1, in_channel]))
    for node_feature in range(X.shape[0]):
        eps_val = lbann.Multiply(eps, X[node_feature])
        X[node_feature] = lbann.Sum(message[node_feature], eps_val)

    # Transform with the sequence of linear layers
    for layer in self.nn:
        for node_feature in range(X.shape[0]):
            X[node_feature] = layer(X[node_feature])

    # Apply activation
    if activation:
        for node_feature in range(X.shape[0]):
            X[node_feature] = activation(X[node_feature])

    X.update_num_features(self.output_channels)
    return X
def forward(self, node_features_mat, edge_features_tensor,
            node_features_tensor, adjacency_tensor):

    num_edges = self.num_nodes**2

    edge_ft_shape = str_list(
        [num_edges, self.input_channels, self.output_channels])
    node_ft_tensor_shape = str_list(
        [self.num_nodes, self.num_nodes, self.output_channels])
    node_ft_mat_shape = str_list([self.num_nodes, self.output_channels])

    # Transform the edge features with the edge network to obtain a
    # per-edge (input_channels x output_channels) weight matrix
    transformed_edge_ft_tensor = edge_features_tensor
    for layer in self.edge_nn:
        transformed_edge_ft_tensor = layer(transformed_edge_ft_tensor)

    transformed_edge_ft_tensor = lbann.Reshape(transformed_edge_ft_tensor,
                                               dims=edge_ft_shape,
                                               name=self.name +
                                               "_edge_ft_reshape")

    # Apply the per-edge weight matrices to the neighbor features
    new_node_features = lbann.MatMul(node_features_tensor,
                                     transformed_edge_ft_tensor)
    new_node_features = lbann.Reshape(new_node_features,
                                      dims=node_ft_tensor_shape)

    # Aggregate messages over neighborhoods given by the adjacency tensor
    gathered_node_features = lbann.MatMul(adjacency_tensor, new_node_features)
    new_node_features = lbann.Reshape(gathered_node_features,
                                      dims=node_ft_mat_shape)

    # Update the self node features and combine with aggregated messages
    updated_nodes = self.node_nn(node_features_mat)
    out = lbann.Sum(new_node_features, updated_nodes)
    return out
def forward(self, X, A):
    """Apply GCN.

    Args:
        X (GraphVertexData): LBANN Data object, which is a collection of
            Layers. Each Layer is of the shape (1, input_channels)
        A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes).
            The adjacency matrix is assumed to be normalized in the
            pre-processing step.
    Returns:
        (GraphVertexData): The output after GCN. The output can be passed
            into another Graph Conv layer directly
    """
    # Assume X is a GraphVertexData object
    for i in range(X.shape[0]):
        X[i] = lbann.MatMul(X[i],
                            self.W,
                            name=self.name + '_message_' + str(i))
        if (self.bias):
            X[i] = lbann.Sum(X[i],
                             self.bias,
                             name=self.name + '_message_bias_' + str(i))

    # Pass messages to node features
    out = X.get_mat(self.output_channels)

    # The adjacency matrix is assumed to be normalized such that
    # A = D^-0.5 A D^-0.5, following the convention in the GCN paper.
    out = lbann.MatMul(A, out, name=self.name + '_aggregate')

    if self.activation:
        out = self.activation(out)

    out = GraphVertexData.matrix_to_graph(out, X.shape[0],
                                          self.output_channels)
    return out
def AtomEncoder(node_feature_columns, EMBEDDING_DIM):
    """Embeds the node features into a vector.

    Args:
        node_feature_columns (list(Layers)): A list of layers with node
            features with shape (NUM_NODES)
        EMBEDDING_DIM (int): The embedding dimensionality of the node
            feature vector

    Returns:
        (Layer): A layer containing the embedded node feature matrix of
            shape (NUM_NODES, EMBEDDING_DIM)
    """
    # Courtesy of OGB
    atom_feature_dims = [119, 4, 12, 12, 10, 6, 6, 2, 2]

    _fan_in = atom_feature_dims[0]
    _fan_out = EMBEDDING_DIM

    _embedding_weights = lbann.Weights(
        initializer=_xavier_uniform_init(_fan_in, _fan_out),
        name="atom_encoder_weights_{}".format(0))

    temp = lbann.Embedding(node_feature_columns[0],
                           num_embeddings=atom_feature_dims[0],
                           embedding_dim=EMBEDDING_DIM,
                           weights=_embedding_weights,
                           name="Atom_Embedding_0")

    # Embed each remaining node feature column and accumulate the embeddings
    for i in range(1, 9):
        _fan_in = atom_feature_dims[i]
        _fan_out = EMBEDDING_DIM
        _embedding_weights = lbann.Weights(
            initializer=_xavier_uniform_init(_fan_in, _fan_out),
            name="atom_encoder_weights_{}".format(i))
        _temp2 = lbann.Embedding(node_feature_columns[i],
                                 num_embeddings=atom_feature_dims[i],
                                 embedding_dim=EMBEDDING_DIM,
                                 weights=_embedding_weights,
                                 name="Atom_Embedding_{}".format(i))
        temp = lbann.Sum(temp, _temp2)
    return temp
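# Usage sketch (illustrative; splitting the node-feature input into nine
# integer index columns is assumed to happen in the data ingestion code):
#
#     node_cols = [lbann.Identity(node_feature_slice) for _ in range(9)]
#     node_embeddings = AtomEncoder(node_cols, EMBEDDING_DIM=100)
#
# The result has shape (NUM_NODES, EMBEDDING_DIM) and serves as the initial
# node representation for the graph convolution layers.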
def forward(self, node_feature_mat, source_indices, target_indices):
    """Apply Graph Conv Layer.

    Args:
        node_feature_mat (Layer): Node feature matrix with the shape of
            (num_nodes, input_channels)
        source_indices (Layer): Source node indices of the edges with
            shape (num_nodes)
        target_indices (Layer): Target node indices of the edges with
            shape (num_nodes)
    Returns:
        (Layer): The output after kernel ops. The output can be passed
            into another Graph Conv layer directly
    """
    new_self_features = self.id_nn(node_feature_mat)

    new_neighbor_features = self.mat_nn(node_feature_mat)
    # Place the new features onto the neighborhoods
    neighborhoods = GraphExpand(new_neighbor_features, target_indices)
    # Accumulate messages from neighboring nodes
    reduced_features = GraphReduce(
        neighborhoods, source_indices,
        [self.num_nodes, self.output_channel_size])

    out_features = lbann.Sum(new_self_features, reduced_features)
    return out_features
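# Usage sketch (illustrative; the constructor signature is an assumption):
#
#     conv = GraphConv(input_channels=100, output_channels=100,
#                      num_nodes=NUM_NODES, num_edges=NUM_EDGES,
#                      name='graph_conv_1')
#     node_fts = conv.forward(node_feature_mat, source_indices, target_indices)
#
# Several such layers can be chained, since the output has the same layout
# as the input node feature matrix.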
def forward(self, queries, keys, values, mask=None):
    """Apply multi-head attention.

    The input and output tensors are interpreted as sequences of
    vectors, where the first tensor dimension is the sequence
    dimension.

    Args:
        queries (lbann.Layer): Sequence of query vectors.
        keys (lbann.Layer): Sequence of key vectors.
        values (lbann.Layer): Sequence of value vectors.
        mask (lbann.Layer, optional): Additive attention mask. If
            the (i,j) entry is very negative (e.g. -1e9), then the ith
            query does not attend to the jth key/value pair.

    Returns:
        lbann.Layer: Sequence of output vectors. The sequence length
            is the same as `queries`.

    """
    ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
    BRANCHES = self.BRANCHES
    if (ENABLE_SUBGRAPH):
        if (self.num_heads % BRANCHES != 0):
            raise ValueError('Num heads should be divisible by BRANCHES')
    self.instance += 1
    name = f'{self.name}_instance{self.instance}'

    # Apply fully-connected layers to input sequences
    queries_fc = lbann.ChannelwiseFullyConnected(
        queries,
        weights=self.query_weights,
        output_channel_dims=[self.inner_dim],
        name=f'{name}_queries_fc',
    )
    keys_fc = lbann.ChannelwiseFullyConnected(
        keys,
        weights=self.key_weights,
        output_channel_dims=[self.inner_dim],
        name=f'{name}_keys_fc',
    )
    values_fc = lbann.ChannelwiseFullyConnected(
        values,
        weights=self.value_weights,
        output_channel_dims=[self.inner_dim],
        name=f'{name}_values_fc',
    )

    # Slice embedding vectors for each head
    slice_points = str_list(self.head_dim * i
                            for i in range(self.num_heads + 1))
    queries_slice = lbann.Slice(queries_fc,
                                axis=1,
                                slice_points=slice_points,
                                name=f'{name}_queries_slice',
                                parallel_strategy={
                                    'sub_branch_tag': 0,
                                    'enable_subgraph': ENABLE_SUBGRAPH
                                })
    keys_slice = lbann.Slice(keys_fc,
                             axis=1,
                             slice_points=slice_points,
                             name=f'{name}_keys_slice',
                             parallel_strategy={
                                 'sub_branch_tag': 0,
                                 'enable_subgraph': ENABLE_SUBGRAPH
                             })
    values_slice = lbann.Slice(values_fc,
                               axis=1,
                               slice_points=slice_points,
                               name=f'{name}_values_slice',
                               parallel_strategy={
                                   'sub_branch_tag': 0,
                                   'enable_subgraph': ENABLE_SUBGRAPH
                               })

    # Compute scaled dot-product attention for each head
    attentions = []
    tag = 0
    for head in range(self.num_heads):
        head_name = f'{name}_myattention_head{head}'

        # Attention inputs
        if (ENABLE_SUBGRAPH):
            if (head % int(self.num_heads / BRANCHES) == 0):
                tag += 1
            q = lbann.Identity(queries_slice,
                               parallel_strategy={
                                   'sub_branch_tag': tag,
                                   'enable_subgraph': ENABLE_SUBGRAPH
                               })
            k = lbann.Identity(keys_slice,
                               parallel_strategy={
                                   'sub_branch_tag': tag,
                                   'enable_subgraph': ENABLE_SUBGRAPH
                               })
            v = lbann.Identity(values_slice,
                               parallel_strategy={
                                   'sub_branch_tag': tag,
                                   'enable_subgraph': ENABLE_SUBGRAPH
                               })
        else:
            q = lbann.Identity(queries_slice)
            k = lbann.Identity(keys_slice)
            v = lbann.Identity(values_slice)

        # Multiply queries and keys
        # Note: num_queries x num_keys
        y = lbann.MatMul(
            q,
            k,
            transpose_b=True,
            name=f'{head_name}_matmul',
        )
        y = lbann.WeightedSum(
            y,
            scaling_factors=str(1 / math.sqrt(self.head_dim)),
            name=f'{head_name}_scale',
        )

        if (ENABLE_SUBGRAPH):
            if mask != None:
                y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
        else:
            if mask:
                y = lbann.Sum([y, mask], name=f'{head_name}_mask')
        y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

        # Attention output
        # Note: num_queries x head_dim
        attentions.append(lbann.MatMul(y, v, name=head_name))

    # Concatenate heads and apply fully-connected layer
    if (ENABLE_SUBGRAPH):
        # Strong scaling
        attentions = lbann.Concatenation(attentions,
                                         axis=1,
                                         name=f'{name}_heads_concat',
                                         parallel_strategy={
                                             'sub_branch_tag': 0,
                                             'enable_subgraph': ENABLE_SUBGRAPH
                                         })
    else:
        attentions = lbann.Concatenation(
            attentions,
            axis=1,
            name=f'{name}_heads_concat',
        )

    outputs_fc = lbann.ChannelwiseFullyConnected(
        attentions,
        weights=self.output_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}',
    )
    return outputs_fc
def forward(self, queries, keys, values, mask=None):
    """Apply multi-head attention.

    The input and output tensors are interpreted as sequences of
    vectors, where the first tensor dimension is the sequence
    dimension.

    Args:
        queries (lbann.Layer): Sequence of query vectors.
        keys (lbann.Layer): Sequence of key vectors.
        values (lbann.Layer): Sequence of value vectors.
        mask (lbann.Layer, optional): Additive attention mask. If
            the (i,j) entry is very negative (e.g. -1e9), then the ith
            query does not attend to the jth key/value pair.

    Returns:
        lbann.Layer: Sequence of output vectors. The sequence length
            is the same as `queries`.

    """
    ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
    BRANCHES = self.BRANCHES
    if (ENABLE_SUBGRAPH):
        if (self.num_heads % BRANCHES != 0):
            raise ValueError('Num heads should be divisible by BRANCHES')
    self.instance += 1
    name = f'{self.name}_instance{self.instance}'

    # Apply fully-connected layers to input sequences
    queries_fc = []
    keys_fc = []
    values_fc = []

    # Slice embedding vectors for each head
    slice_points = str_list(
        self.head_dim * i
        for i in range(int(self.num_heads / self.BRANCHES) + 1))

    # Queries strong scaling in CFC
    attentions = []
    for count, query in enumerate(queries):
        temp = lbann.ChannelwiseFullyConnected(
            query,
            weights=self.query_weights[count],
            output_channel_dims=[self.inner_dim],
            name=f'{name}_subgrid{count}_queries_fc',
        )
        attentions.append(temp)

    grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)
    attentions = []
    for head in range(self.BRANCHES):
        attentions.append(lbann.Identity(grid_sum_slice))

    for head in range(self.BRANCHES):
        temp = lbann.Slice(
            attentions[head],
            axis=1,
            slice_points=slice_points,
            name=f'{name}_subgrid{head}_queries_slice',
        )
        queries_fc.append(temp)

    # Keys strong scaling in CFC
    attentions = []
    for count, key in enumerate(keys):
        temp = lbann.ChannelwiseFullyConnected(
            key,
            weights=self.key_weights[count],
            output_channel_dims=[self.inner_dim],
            name=f'{name}_subgrid{count}_keys_fc',
        )
        attentions.append(temp)

    grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)
    attentions = []
    for head in range(self.BRANCHES):
        attentions.append(lbann.Identity(grid_sum_slice))

    for head in range(self.BRANCHES):
        temp = lbann.Slice(
            attentions[head],
            axis=1,
            slice_points=slice_points,
            name=f'{name}_subgrid{head}_keys_slice',
        )
        keys_fc.append(temp)

    # Values strong scaling in CFC
    attentions = []
    for count, value in enumerate(values):
        temp = lbann.ChannelwiseFullyConnected(
            value,
            weights=self.value_weights[count],
            output_channel_dims=[self.inner_dim],
            name=f'{name}_subgrid{count}_values_fc',
        )
        attentions.append(temp)

    grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)
    attentions = []
    for head in range(self.BRANCHES):
        attentions.append(lbann.Identity(grid_sum_slice))

    for head in range(self.BRANCHES):
        temp = lbann.Slice(
            attentions[head],
            axis=1,
            slice_points=slice_points,
            name=f'{name}_subgrid{head}_values_slice',
        )
        values_fc.append(temp)

    queries_slice = []
    keys_slice = []
    values_slice = []

    for branch in range(self.BRANCHES):
        query_slice = queries_fc[branch]
        key_slice = keys_fc[branch]
        value_slice = values_fc[branch]
        for head in range(int(self.num_heads / self.BRANCHES)):
            queries_slice.append(lbann.Identity(query_slice))
            keys_slice.append(lbann.Identity(key_slice))
            values_slice.append(lbann.Identity(value_slice))

    # Compute scaled dot-product attention for each head
    attentions = []

    # Variable to combine heads locally in sub-grids
    temp_attentions = []
    tag = 0
    for head in range(self.num_heads):
        head_name = f'{name}_myattention_head{head}'

        # Attention inputs
        if (head % int(self.num_heads / BRANCHES) == 0):
            temp_attentions.append([])
            tag += 1

        q = lbann.Identity(queries_slice[head])
        k = lbann.Identity(keys_slice[head])
        v = lbann.Identity(values_slice[head])

        # Multiply queries and keys
        # Note: num_queries x num_keys
        y = lbann.MatMul(
            q,
            k,
            transpose_b=True,
            name=f'{head_name}_matmul',
        )
        y = lbann.WeightedSum(
            y,
            scaling_factors=str(1 / math.sqrt(self.head_dim)),
            name=f'{head_name}_scale',
        )

        if (ENABLE_SUBGRAPH):
            if mask != None:
                y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
        else:
            if mask:
                y = lbann.Sum([y, mask], name=f'{head_name}_mask')
        y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

        # Attention output
        # Note: num_queries x head_dim
        y = lbann.MatMul(y, v, name=head_name)
        temp_attentions[-1].append(y)

    for count, temp_attention in enumerate(temp_attentions):

        if (self.BRANCHES == self.num_heads):
            # No need to concatenate the heads at the subgrid level
            # if the number of subgrids is equal to the number of heads
            attention_single_subgrid = temp_attentions[count][0]
        else:
            attention_single_subgrid = lbann.Concatenation(
                temp_attention,
                axis=1,
                name=f'{name}_subgrid_heads_concat{count}',
                parallel_strategy={
                    'sub_branch_tag': 0,
                    'enable_subgraph': False
                })

        attention_single_subgrid = lbann.ChannelwiseFullyConnected(
            attention_single_subgrid,
            weights=self.output_weights[count],
            output_channel_dims=[self.embed_dim],
            name=f'{name}_cfc_{count}',
        )

        attentions.append(attention_single_subgrid)

    # Strong scaling
    grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

    attentions = []
    for head in range(self.BRANCHES):
        attentions.append(lbann.Identity(grid_sum_slice))

    return attentions
def forward(self, x, memory, src_mask=None, tgt_mask=None):
    """Apply Transformer decoder layer.

    Args:
        x (lbann.Layer): Sequence of input vectors.
        memory (lbann.Layer): Sequence of vectors produced by
            Transformer encoder stack.
        src_mask (lbann.Layer, optional): Attention mask for
            second attention module (attends to both `x` and
            `memory`).
        tgt_mask (lbann.Layer, optional): Attention mask for first
            attention module (attends only to `x`).

    Returns:
        lbann.Layer: Sequence of output vectors.

    """
    self.instance += 1
    name = f'{self.name}_instance{self.instance}'

    # Self-attention with residual connection
    y = self.attention1(x, x, x, mask=tgt_mask)
    if self.dropout_prob > 0:
        y = lbann.Dropout(
            y,
            keep_prob=1 - self.dropout_prob,
            name=f'{name}_drop1',
        )
    z = lbann.Sum(x, y, name=f'{name}_sum1')
    z = lbann.InstanceNorm(z, name=f'{name}_norm1')
    x = z

    # Attention on encoder output with residual connection
    y = self.attention2(x, memory, memory, mask=src_mask)
    if self.dropout_prob > 0:
        y = lbann.Dropout(
            y,
            keep_prob=1 - self.dropout_prob,
            name=f'{name}_drop2',
        )
    z = lbann.Sum(x, y, name=f'{name}_sum2')
    z = lbann.InstanceNorm(z, name=f'{name}_norm2')
    x = z

    # Feedforward network with residual connection
    y = lbann.ChannelwiseFullyConnected(
        x,
        weights=self.fc1_weights,
        output_channel_dims=[self.feedforward_dim],
        name=f'{name}_fc1',
    )
    y = lbann.Relu(y, name=f'{name}_relu1')
    if self.dropout_prob > 0:
        y = lbann.Dropout(
            y,
            keep_prob=1 - self.dropout_prob,
            name=f'{name}_drop3',
        )
    y = lbann.ChannelwiseFullyConnected(
        y,
        weights=self.fc2_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_fc2',
    )
    if self.dropout_prob > 0:
        y = lbann.Dropout(
            y,
            keep_prob=1 - self.dropout_prob,
            name=f'{name}_drop4',
        )
    z = lbann.Sum(x, y, name=f'{name}_sum3')
    z = lbann.InstanceNorm(z, name=f'{name}_norm3')
    return z
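# Usage sketch (illustrative, not from the original source): decoder layers
# consume both the target sequence and the encoder output ("memory"), with a
# causal mask applied to the self-attention step.
#
#     x = embedded_target_sequence
#     for layer in decoder:  # list of TransformerDecoderLayer instances
#         x = layer.forward(x, memory, src_mask=None, tgt_mask=causal_mask)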