Example #1
    def forward(self, x, mask=None):
        """Apply Transformer encoder layer.

        Args:
            x (lbann.Layer): Sequence of input vectors.
            mask (lbann.Layer, optional): Attention mask.

        Returns:
            lbann.Layer: Sequence of output vectors.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Self-attention with residual connection
        y = self.attention(x, x, x, mask=mask)
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop1',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum1')
        z = lbann.InstanceNorm(z, name=f'{name}_norm1')
        x = z

        # Feedforward network with residual connection
        y = lbann.ChannelwiseFullyConnected(
            x,
            weights=self.fc1_weights,
            output_channel_dims=[self.feedforward_dim],
            name=f'{name}_fc1',
        )
        y = lbann.Relu(y, name=f'{name}_relu1')
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop2',
            )
        y = lbann.ChannelwiseFullyConnected(
            y,
            weights=self.fc2_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_fc2',
        )
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop3',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum2')
        z = lbann.InstanceNorm(z, name=f'{name}_norm2')
        return z
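A minimal usage sketch for stacking layers like this one. The EncoderLayer name and its constructor arguments are assumptions for illustration; only the forward() signature comes from the snippet above.

import lbann

# Hypothetical constructor; the snippet only defines forward()
encoder_layers = [EncoderLayer(embed_dim=512, feedforward_dim=2048, dropout=0.1)
                  for _ in range(6)]

# Stand-in for a (sequence_length, embed_dim) sequence of embedded tokens
x = lbann.Constant(value=0.0, num_neurons='64 512')
for layer in encoder_layers:
    x = layer.forward(x, mask=None)  # self-attention + feedforward, each with a residual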
Example #2
def BondEncoder(edge_feature_columns, EDGE_EMBEDDING_DIM):
    """Embeds the edge features into a vector
	Args:
		edge_feature_columns (list(Layers)): A list of layers with edge feaures with shape (NUM_EDGES)
		EDGE_EMBEDDING_DIM (int): The embedding dimensionality of the edge feature vector
	Returns:
		(Layer): A layer containing the embedded edge feature matrix of shape (NUM_EDGES, EDGE_EMBEDDING_DIM)
		"""
    # Courtesy of OGB
    bond_feature_dims = [5, 6, 2]
    _fan_in = bond_feature_dims[0]
    _fan_out = EDGE_EMBEDDING_DIM
    _embedding_weights = lbann.Weights(
        initializer=_xavier_uniform_init(_fan_in, _fan_out),
        name="bond_encoder_weights_{}".format(0))

    temp = lbann.Embedding(edge_feature_columns[0],
                           num_embeddings=bond_feature_dims[0],
                           embedding_dim=EDGE_EMBEDDING_DIM,
                           weights=_embedding_weights,
                           name="Bond_Embedding_0")

    for i in range(1, 3):
        _fan_in = bond_feature_dims[i]
        _fan_out = EDGE_EMBEDDING_DIM
        _embedding_weights = lbann.Weights(
            initializer=_xavier_uniform_init(_fan_in, _fan_out),
            name="bond_encoder_weights_{}".format(i))
        _temp2 = lbann.Embedding(edge_feature_columns[i],
                                 num_embeddings=bond_feature_dims[i],
                                 embedding_dim=EDGE_EMBEDDING_DIM,
                                 weights=_embedding_weights,
                                 name="Bond_Embedding_{}".format(i))
        temp = lbann.Sum(temp, _temp2)
    return temp
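The helper _xavier_uniform_init is referenced but not shown. A plausible sketch, assuming it wraps lbann.UniformInitializer with the usual Glorot/Xavier bound sqrt(6 / (fan_in + fan_out)):

import math
import lbann

def _xavier_uniform_init(fan_in, fan_out):
    # Glorot/Xavier uniform: sample from U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
    a = math.sqrt(6.0 / (fan_in + fan_out))
    return lbann.UniformInitializer(min=-a, max=a)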
Example #3
    def forward(self,
                node_features,
                neighbor_features,
                edge_features,
                edge_index):
        """Apply NNConv layer.
        Args:
            node_features (Layer): A 2D layer of node features of
                                   shape (num_nodes, input_channels)
            neighbor_features (Layer): A 3D layer of node features of
                                       shape (num_edges, 1, input_channels)
            edge_features (Layer): A 2D layer of edge features of
                                   shape (num_edges, num_edge_features)
            edge_index (Layer): A 1D layer of indices of shape
                                (num_edges * output_channels), used
                                for the scatter reduction
        Returns:
            (Layer): The output after NNConv. The output layer has the shape
                     (num_nodes, self.output_channels)
        """

        updated_node_fts, neighbor_vals = self.message(node_features,
                                                       neighbor_features,
                                                       edge_features)
        aggregated_fts = self.aggregate(neighbor_vals, edge_index)

        update = lbann.Sum(updated_node_fts,
                           aggregated_fts,
                           name=self.name+"_updated_node_features")

        return update
Example #4
    def forward(self, X, A):
        # Neighborhood term: transform node features with W2 and propagate
        # them along the adjacency matrix A
        messages = lbann.MatMul(X, self.W2, name=self.name + '_w2_mult')
        messages = lbann.MatMul(A, messages, name=self.name + '_adj_mult')

        # Self term: transform each node's own features with W1
        ident = lbann.MatMul(X, self.W1, name=self.name + '_w1_mult')

        # Combine the self and neighborhood terms
        out = lbann.Sum(ident, messages, name=self.name + '_sum_id')

        return out
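In dense form this layer computes X·W1 + A·X·W2. A tiny NumPy sketch of the same computation for reference (shapes here are illustrative):

import numpy as np

X = np.random.rand(5, 8)     # node feature matrix
A = np.random.rand(5, 5)     # adjacency matrix
W1 = np.random.rand(8, 16)   # self-term weights
W2 = np.random.rand(8, 16)   # neighborhood-term weights
out = X @ W1 + A @ (X @ W2)  # matches Sum(ident, messages) above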
Example #5
    def forward(self, X, A):
        """Apply Graph Conv Layer to X and use A for message passing

        Args:
            X (GraphVertexData): LBANN Data object, which is a collection of Layers. Each Layer is of
                                 the shape (1,input_channels) 

            A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)

        Returns: 
            
            GraphVertexData: The output after convolution. The output can passed into another Graph Conv layer
                          directly
        """

        # Accumulate Messages from Neighboring Nodes
        out = X.get_mat()
        out = lbann.MatMul(out,
                           self.weights1,
                           name=self.name + "_Graph_MATMUL")
        message = lbann.MatMul(A, out, name=self.name + "_Graph_Message")
        message = GraphVertexData.matrix_to_graph(message, X.shape[0],
                                                  self.output_channels)

        # Assume X is a GraphVertexData object

        for node_feature in range(X.shape[0]):
            X[node_feature] = lbann.MatMul(X[node_feature], self.weights2)

        for node_feature in range(X.shape[0]):
            if (self.bias):
                message[node_feature] = lbann.Sum(
                    message[node_feature],
                    self.bias,
                    name=self.name + '_message_bias_' + str(node_feature))
            X[node_feature] = lbann.Sum(X[node_feature], message[node_feature])

        if self.activation:
            for node_feature in range(X.shape[0]):
                X[node_feature] = self.activation(X[node_feature])

        X.update_num_features(self.output_channels)
        return X
Example #6
    def forward(self, _):
        # Weights exposed as a (width x height) matrix
        w = lbann.WeightsLayer(weights=self.weights,
                               dims='%d %d' % (self.width, self.height))
        # Split the matrix into `width` slices along the first dimension
        w_slice = lbann.Slice(w,
                              axis=0,
                              slice_points=' '.join(
                                  str(i) for i in range(self.width + 1)))
        # L2 norm of each slice, then sum the norms
        cols = []
        for _ in range(self.width):
            cols.append(lbann.Sqrt(lbann.L2Norm2(w_slice)))
        return lbann.Sum(cols)
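For reference, a NumPy sketch of the quantity the graph above computes (the sum of the L2 norms of the width slices taken along the first dimension of the weight matrix); the array shape is illustrative:

import numpy as np

w = np.random.rand(4, 3)               # stands in for the (width x height) weights
norms = np.sqrt((w ** 2).sum(axis=1))  # L2 norm of each slice along axis 0
total = norms.sum()                    # matches the final lbann.Sum over per-slice norms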
Example #7
def random_projection(indices, num_projections, projection_dim):

    # Expand input indices to get an index for each vector entry
    # Note: proj_indices(i) = index*projection_dim + i
    proj_indices = lbann.WeightedSum(
        indices,
        scaling_factors=utils.str_list(projection_dim),
    )
    iota = lbann.WeightsLayer(
        dims=utils.str_list(projection_dim),
        weights=lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=utils.str_list(range(projection_dim))),
            optimizer=lbann.NoOptimizer(),
        ),
    )
    proj_indices = lbann.Sum(
        lbann.Tessellate(
            lbann.Reshape(proj_indices,
                          dims=utils.str_list([num_projections, 1])),
            dims=utils.str_list([num_projections, projection_dim]),
        ),
        lbann.Tessellate(
            lbann.Reshape(iota, dims=utils.str_list([1, projection_dim])),
            dims=utils.str_list([num_projections, projection_dim]),
        ),
    )

    # Apply hash function and convert to Gaussian distribution
    proj = lbann.UniformHash(proj_indices)
    ones = lbann.Constant(
        value=1,
        num_neurons=utils.str_list([num_projections, projection_dim]),
    )
    eps = 0.001
    proj = lbann.ErfInv(
        lbann.WeightedSum(
            proj,
            ones,
            scaling_factors=utils.str_list([2 * (1 - eps), -(1 - eps)]),
        ))
    proj = lbann.InstanceNorm(proj)
    proj = lbann.WeightedSum(
        proj,
        scaling_factors=utils.str_list(1 / projection_dim),
    )
    return proj
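For intuition, a NumPy sketch of the per-entry transform applied after UniformHash, assuming the hash output u is roughly uniform on [0, 1): the affine step maps it into (-(1 - eps), 1 - eps), ErfInv makes it approximately Gaussian (up to scale), and the instance normalization plus the final 1/projection_dim factor rescale each projection:

import numpy as np
from scipy.special import erfinv

eps = 0.001
projection_dim = 8
u = np.random.rand(16, projection_dim)     # stands in for lbann.UniformHash output
g = erfinv(2 * (1 - eps) * u - (1 - eps))  # approximately Gaussian entries (up to scale)
g = (g - g.mean(1, keepdims=True)) / g.std(1, keepdims=True)  # per-projection normalization
proj = g / projection_dim                  # final 1 / projection_dim scaling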
Example #8
    def forward(self, inputs):
        raise NotImplementedError  # Requires log-gamma function
        if len(inputs) != 2:
            raise ValueError('expected two inputs: predictions and labels')
        pred = inputs[0]
        label = inputs[1]
        count = lbann.Reduction(label)
        alpha_sum = lbann.Reduction(pred)
        lgamma_alpha_sum = lbann.Reduction(lbann.LogGamma(pred))
        lgamma_alpha_level_count_sum = lbann.Reduction(
            lbann.LogGamma(lbann.Add([pred, label])))
        return lbann.WeightedSum([
            lbann.LogGamma(alpha_sum),
            lbann.LogGamma(lbann.Sum([count, alpha_sum])),
            lgamma_alpha_level_count_sum, lgamma_alpha_sum
        ],
                                 scaling_factors='-1.0 1.0 -1.0 1.0')
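The stub above appears to assemble the negative Dirichlet-multinomial log-likelihood (dropping the multinomial coefficient); the four LogGamma terms match the reductions computed earlier. A reference sketch of that quantity with SciPy, under the same assumption:

import numpy as np
from scipy.special import gammaln

def dirichlet_multinomial_nll(alpha, counts):
    # alpha: predicted concentration parameters, counts: observed label counts
    alpha_sum = alpha.sum()
    n = counts.sum()
    return (-gammaln(alpha_sum) + gammaln(n + alpha_sum)
            - gammaln(alpha + counts).sum() + gammaln(alpha).sum())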
Example #9
    def forward(self,
                node_feature_mat,
                source_indices,
                target_indices,
                activation=lbann.Relu):
        """Apply GIN  Layer. 
        
        Args:
            node_feature_mat (Layer): Node feature matrix with the shape of (num_nodes,input_channels) 
            source_indices (Layer): Source node indices of the edges with shape (num_nodes)
            target_indices (Layer): Target node indices of the edges with shape (num_nodes
            activation (Layer): Activation layer for the node features. If None, then no activation is 
                                applied. (default: lbann.Relu) 
        Returns: 
            (Layer) : The output after kernel ops. The output can passed into another Graph Conv layer
                          directly
        """
        eps = lbann.Constant(value=(1 + self.eps),
                             num_neurons=str_list(
                                 [self.num_nodes, self.input_channel_size]))

        eps_node_features = lbann.Multiply(node_feature_mat,
                                           eps,
                                           name=self.name + "_epl_mult")

        node_feature_mat = lbann.Sum(eps_node_features, node_feature_mat)

        # Transform with the sequence of linear layers
        for layer in self.nn:
            node_feature_mat = layer(node_feature_mat)

        neighborhoods = GraphExpand(node_feature_mat, target_indices)

        neighborhoods = lbann.Reshape(
            neighborhoods,
            dims=str_list([self.num_edges, self.output_channel_size]))

        aggregated_node_features = GraphReduce(
            neighborhoods, source_indices,
            [self.num_nodes, self.output_channel_size])
        # Apply activation
        if activation:
            aggregated_node_features = activation(aggregated_node_features)

        return aggregated_node_features
Example #10
    def forward(self, X, A, activation = lbann.Relu):
        """Apply GIN  Layer. 
        
        Args:
            X (GraphVertexData): LBANN Data object, which is a collection of Layers. Each Layer is of
                                 the shape (1,input_channels) 

            A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)

            activation (Layer): Activation layer for the node features. If None, then no activation is 
                                applied. (default: lbann.Relu) 
        Returns: 
            
            (GraphVertexData): The output after GCN. The output can passed into another Graph Conv layer
                          directly
        """
        in_channel = X.shape[1]

        # Accumulate Messages from Neighboring Nodes
        out = X.get_mat()
        out = lbann.MatMul(A, out, name=self.name + "_GIN_MATMUL")
        message = GraphVertexData.matrix_to_graph(out, X.shape[0], in_channel)

        # Aggregate messages into node features
        eps = lbann.Constant(value=(1 + self.eps),
                             num_neurons=str_list([1, in_channel]))
        for node_feature in range(X.shape[0]):
            eps_val = lbann.Multiply(eps, X[node_feature])
            X[node_feature] = lbann.Sum(message[node_feature], eps_val)
        
        # Transform with the sequence of linear layers
        for layer in self.nn:
            for node_feature in range(X.shape[0]):
                X[node_feature] = layer(X[node_feature])
        
        # Apply activation
        if activation:
            for node_feature in range(X.shape[0]):
                X[node_feature] = activation(X[node_feature])
        X.update_num_features(self.output_channels) 
        return X
Example #11
    def forward(self, node_features_mat, edge_features_tensor,
                node_features_tensor, adjacency_tensor):

        num_edges = self.num_nodes**2

        edge_ft_shape = str_list(
            [num_edges, self.input_channels, self.output_channels])
        node_ft_tensor_shape = str_list(
            [self.num_nodes, self.num_nodes, self.output_channels])
        node_ft_mat_shape = str_list([self.num_nodes, self.output_channels])

        transformed_edge_ft_tensor = None

        for layer in self.edge_nn:
            if transformed_edge_ft_tensor is not None:
                transformed_edge_ft_tensor = layer(transformed_edge_ft_tensor)
            else:
                transformed_edge_ft_tensor = layer(edge_features_tensor)

        transformed_edge_ft_tensor = lbann.Reshape(transformed_edge_ft_tensor,
                                                   dims=edge_ft_shape,
                                                   name=self.name +
                                                   "_edge_ft_reshape")

        new_node_features = lbann.MatMul(node_features_tensor,
                                         transformed_edge_ft_tensor)
        new_node_features = lbann.Reshape(new_node_features,
                                          dims=node_ft_tensor_shape)

        gathered_node_features = lbann.MatMul(adjacency_tensor,
                                              new_node_features)

        new_node_features = lbann.Reshape(gathered_node_features,
                                          dims=node_ft_mat_shape)
        updated_nodes = self.node_nn(node_features_mat)

        out = lbann.Sum(new_node_features, updated_nodes)

        return out
Example #12
    def forward(self, X, A):
        """Apply GCN

        Args:
            X (GraphVertexData): LBANN Data object, which is a collection of Layers. Each Layer is of
                                 the shape (1,input_channels) 
            A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)
                                applied. The adjacency matrix is assumed to be normalized in the 
                                pre-processing step. 
        Returns:     
            LBANN_Data_Mat: The output after GCN. The output can passed into another Graph Conv layer
                          directly
        """

        # Assume X is a lbann data object
        for i in range(X.shape[0]):
            X[i] = lbann.MatMul(X[i],
                                self.W,
                                name=self.name + '_message_' + str(i))
            if (self.bias):
                X[i] = lbann.Sum(X[i],
                                 self.bias,
                                 name=self.name + '_message_bias_' + str(i))

        # Pass Message to Node Features
        out = X.get_mat(self.output_channels)

        # A - adjacency matrix is assumed to be normalized such that
        # A = D^-0.5 A D^-0.5, following the convention in the GCN paper.
        out = lbann.MatMul(A, out, name=self.name + '_aggregate')

        if self.activation:
            out = self.activation(out)

        out = GraphVertexData.matrix_to_graph(out, X.shape[0],
                                              self.output_channels)

        return out
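Since the docstring assumes a pre-normalized adjacency matrix, here is a short NumPy sketch of that pre-processing step (symmetric normalization with self-loops, as in the GCN paper; adding self-loops is an assumption, not shown in the snippet):

import numpy as np

def normalize_adjacency(adj):
    # A_hat = A + I, then D^-1/2 A_hat D^-1/2
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]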
Example #13
def AtomEncoder(node_feature_columns, EMBEDDING_DIM):
    """Embeds the node features into a vector

	Args:
		edge_feature_columns (list(Layers)): A list of layers with node feaures with shape (NUM_NODES)
		EMBEDDING_DIM (int): The embedding dimensionality of the node feature vector
	Returns:
		(Layer): A layer containing the embedded node feature matrix of shape (NUM_NODES, EMBEDDING_DIM)
		"""
    # Courtesy of OGB
    atom_feature_dims = [119, 4, 12, 12, 10, 6, 6, 2, 2]

    _fan_in = atom_feature_dims[0]
    _fan_out = EMBEDDING_DIM

    _embedding_weights = lbann.Weights(
        initializer=_xavier_uniform_init(_fan_in, _fan_out),
        name="atom_encoder_weights_{}".format(0))

    temp = lbann.Embedding(node_feature_columns[0],
                           num_embeddings=atom_feature_dims[0],
                           embedding_dim=EMBEDDING_DIM,
                           weights=_embedding_weights,
                           name="Atom_Embedding_0")
    for i in range(1, 9):
        _fan_in = atom_feature_dims[i]
        _fan_out = EMBEDDING_DIM
        _embedding_weights = lbann.Weights(
            initializer=_xavier_uniform_init(_fan_in, _fan_out),
            name="atom_encoder_weights_{}".format(i))
        _temp2 = lbann.Embedding(node_feature_columns[i],
                                 num_embeddings=atom_feature_dims[i],
                                 embedding_dim=EMBEDDING_DIM,
                                 weights=_embedding_weights,
                                 name="Atom_Embedding_{}".format(i))
        temp = lbann.Sum(temp, _temp2)
    return temp
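One way the node_feature_columns list might be produced (an assumption for illustration, not part of the snippet): slice a (NUM_NODES, 9) integer feature matrix column-wise and reshape each column into a 1D index layer that lbann.Embedding can consume:

import lbann

def split_feature_columns(node_features, num_nodes, num_columns):
    # node_features: (num_nodes, num_columns) matrix of integer feature indices
    sliced = lbann.Slice(node_features,
                         axis=1,
                         slice_points=' '.join(str(i) for i in range(num_columns + 1)))
    columns = []
    for i in range(num_columns):
        col = lbann.Identity(sliced, name='node_feature_column_{}'.format(i))
        columns.append(lbann.Reshape(col, dims=str(num_nodes)))
    return columns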
Example #14
    def forward(self, node_feature_mat, source_indices, target_indices):
        """Apply Graph Conv Layer

        Args:
            node_feature_mat (Layer): Node feature matrix with the shape of (num_nodes,input_channels) 
            source_indices (Layer): Source node indices of the edges with shape (num_nodes)
            target_indices (Layer): Target node indices of the edges with shape (num_nodes)
        Returns:     
            (Layer) : The output after kernel ops. The output can passed into another Graph Conv layer
                          directly
        """

        new_self_features = self.id_nn(node_feature_mat)

        new_neighbor_features = self.mat_nn(node_feature_mat)
        # Place the new features on to neighborhoods
        neighborhoods = GraphExpand(new_neighbor_features, target_indices)
        # Accumulate Messages from Neighboring Nodes
        reduced_features = GraphReduce(
            neighborhoods, source_indices,
            [self.num_nodes, self.output_channel_size])

        out_features = lbann.Sum(new_self_features, reduced_features)
        return out_features
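GraphExpand and GraphReduce act as gather/scatter helpers here. A NumPy sketch of the aggregation pattern, assuming GraphExpand gathers one feature row per edge and GraphReduce scatter-adds the edge messages back onto nodes:

import numpy as np

def graph_expand(features, indices):
    # Gather one feature row per edge
    return features[indices]

def graph_reduce(edge_features, indices, out_shape):
    # Scatter-add edge messages back onto their nodes
    out = np.zeros(out_shape)
    np.add.at(out, indices, edge_features)
    return out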
Example #15
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if (ENABLE_SUBGRAPH):
            if (self.num_heads % BRANCHES != 0):
                raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(queries_fc,
                                    axis=1,
                                    slice_points=slice_points,
                                    name=f'{name}_queries_slice',
                                    parallel_strategy={
                                        'sub_branch_tag': 0,
                                        'enable_subgraph': ENABLE_SUBGRAPH
                                    })
        keys_slice = lbann.Slice(keys_fc,
                                 axis=1,
                                 slice_points=slice_points,
                                 name=f'{name}_keys_slice',
                                 parallel_strategy={
                                     'sub_branch_tag': 0,
                                     'enable_subgraph': ENABLE_SUBGRAPH
                                 })
        values_slice = lbann.Slice(values_fc,
                                   axis=1,
                                   slice_points=slice_points,
                                   name=f'{name}_values_slice',
                                   parallel_strategy={
                                       'sub_branch_tag': 0,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })

        # Compute scaled dot-product attention for each head
        attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs

            if (ENABLE_SUBGRAPH):
                if (head % int(self.num_heads / BRANCHES) == 0):
                    tag += 1

                q = lbann.Identity(queries_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                k = lbann.Identity(keys_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                v = lbann.Identity(values_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
            else:
                q = lbann.Identity(queries_slice)
                k = lbann.Identity(keys_slice)
                v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )

            if ENABLE_SUBGRAPH:
                if mask is not None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            else:
                if mask is not None:
                    y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim

            attentions.append(lbann.MatMul(y, v, name=head_name))

            #Strong scaling

        # Concatenate heads and apply fully-connected layer
        if (ENABLE_SUBGRAPH):
            attentions = lbann.Concatenation(attentions,
                                             axis=1,
                                             name=f'{name}_heads_concat',
                                             parallel_strategy={
                                                 'sub_branch_tag': 0,
                                                 'enable_subgraph':
                                                 ENABLE_SUBGRAPH
                                             })
        else:
            attentions = lbann.Concatenation(
                attentions,
                axis=1,
                name=f'{name}_heads_concat',
            )

        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
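The additive mask described in the docstring is typically built ahead of time. A small NumPy sketch of a causal (autoregressive) mask of the kind such a layer could consume; whether the surrounding model actually uses a causal mask is an assumption:

import numpy as np

def causal_attention_mask(seq_len, neg=-1e9):
    # Entry (i, j) is very negative when query i must not attend to key j (j > i)
    return np.triu(np.full((seq_len, seq_len), neg), k=1)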
Example #16
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if (ENABLE_SUBGRAPH):
            if (self.num_heads % BRANCHES != 0):
                raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = []
        keys_fc = []
        values_fc = []

        # Slice embedding vectors for each head
        slice_points = str_list(
            self.head_dim * i
            for i in range(int(self.num_heads / self.BRANCHES) + 1))

        #Queries strong scaling in CFC
        attentions = []
        for count, query in enumerate(queries):
            temp = lbann.ChannelwiseFullyConnected(
                query,
                weights=self.query_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_queries_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_queries_slice',
            )

            queries_fc.append(temp)

        #keys strong scaling in CFC

        attentions = []
        for count, key in enumerate(keys):
            temp = lbann.ChannelwiseFullyConnected(
                key,
                weights=self.key_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_keys_fc',
            )

            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):

            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_keys_slice',
            )

            keys_fc.append(temp)

        #Values strong scaling in CFC
        attentions = []

        for count, value in enumerate(values):
            temp = lbann.ChannelwiseFullyConnected(
                value,
                weights=self.value_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_values_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_values_slice',
            )
            values_fc.append(temp)

        queries_slice = []
        keys_slice = []
        values_slice = []

        for branch in range(self.BRANCHES):
            query_slice = queries_fc[branch]
            key_slice = keys_fc[branch]
            value_slice = values_fc[branch]

            for head in range(int(self.num_heads / self.BRANCHES)):
                queries_slice.append(lbann.Identity(query_slice))
                keys_slice.append(lbann.Identity(key_slice))
                values_slice.append(lbann.Identity(value_slice))

        # Compute scaled dot-product attention for each head
        attentions = []

        #variable to combine heads locally in sub-grids
        temp_attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs
            if (head % int(self.num_heads / BRANCHES) == 0):
                temp_attentions.append([])
                tag += 1

            q = lbann.Identity(queries_slice[head])
            k = lbann.Identity(keys_slice[head])
            v = lbann.Identity(values_slice[head])

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )

            if ENABLE_SUBGRAPH:
                if mask is not None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            else:
                if mask is not None:
                    y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            y = lbann.MatMul(y, v, name=head_name)
            # attentions.append(lbann.MatMul(y, v, name=head_name))

            temp_attentions[-1].append(y)

        for count, temp_attention in enumerate(temp_attentions):

            if (self.BRANCHES == self.num_heads):
                # No need to concat the heads at subgrid level
                # if number of subgrids is equal to number of heads
                attention_single_subgrid = temp_attentions[count][0]
            else:
                attention_single_subgrid = lbann.Concatenation(
                    temp_attention,
                    axis=1,
                    name=f'{name}_subgrid_heads_concat{count}',
                    parallel_strategy={
                        'sub_branch_tag': 0,
                        'enable_subgraph': False
                    })

            attention_single_subgrid = lbann.ChannelwiseFullyConnected(
                attention_single_subgrid,
                weights=self.output_weights[count],
                output_channel_dims=[self.embed_dim],
                name=f'{name}_cfc_{count}',
            )

            attentions.append(attention_single_subgrid)

        #Strong scaling

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        return attentions
Example #17
    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        """Apply Transformer decoder layer.

        Args:
            x (lbann.Layer): Sequence of input vectors.
            memory (lbann.Layer): Sequence of vectors produced by
                Transformer encoder stack.
            src_mask (lbann.Layer, optional): Attention mask for
                second attention module (attends to both `x` and
                `memory`).
            tgt_mask (lbann.Layer, optional): Attention mask for first
                attention module (attends only to `x`).

        Returns:
            lbann.Layer: Sequence of output vectors.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Self-attention with residual connection
        y = self.attention1(x, x, x, mask=tgt_mask)
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop1',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum1')
        z = lbann.InstanceNorm(z, name=f'{name}_norm1')
        x = z

        # Attention on encoder output with residual connection
        y = self.attention2(x, memory, memory, mask=src_mask)
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop2',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum2')
        z = lbann.InstanceNorm(z, name=f'{name}_norm2')
        x = z

        # Feedforward network with residual connection
        y = lbann.ChannelwiseFullyConnected(
            x,
            weights=self.fc1_weights,
            output_channel_dims=[self.feedforward_dim],
            name=f'{name}_fc1',
        )
        y = lbann.Relu(y, name=f'{name}_relu1')
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop3',
            )
        y = lbann.ChannelwiseFullyConnected(
            y,
            weights=self.fc2_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_fc2',
        )
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop4',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum3')
        z = lbann.InstanceNorm(z, name=f'{name}_norm3')
        return z