Example #1
def Cumsum(x, dims, axis=0):
    global _cumsum_cache

    if len(dims) != 2:
        raise RuntimeError("dims > 2 not tested/supported for cumsum")
    if (axis < 0) or (axis > 1):
        raise RuntimeError("Unsupported cumsum axis: {}".format(axis))
    shape = (dims[axis], dims[axis])
    if shape not in _cumsum_cache:
        tril_ones = np.tril(np.full(shape, 1, dtype=int), k=0)
        tril_ones = lbann.Weights(
            initializer=lbann.ValueInitializer(values=str_list(
                np.nditer(tril_ones, order="C")), ),
            optimizer=lbann.NoOptimizer(),
        )
        tril_ones = lbann.WeightsLayer(dims=str_list(shape), weights=tril_ones)
        _cumsum_cache[shape] = tril_ones

    # Apply cumsum
    tril_ones = _cumsum_cache[shape]
    if axis == 0:
        x = lbann.MatMul(tril_ones, x)
        return x
    if axis == 1:
        x = lbann.MatMul(x, tril_ones, transpose_b=True)
        return x
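The trick here is that multiplying by a lower-triangular matrix of ones computes a running sum along the chosen axis. A minimal NumPy sketch of the same arithmetic (NumPy stands in for lbann.MatMul; shapes mirror the 2-D case handled above):

import numpy as np

x = np.arange(6, dtype=float).reshape(3, 2)    # dims = (3, 2)

# axis=0: cumsum down the rows is tril_ones @ x
tril_rows = np.tril(np.ones((3, 3)))
assert np.allclose(tril_rows @ x, np.cumsum(x, axis=0))

# axis=1: cumsum across the columns is x @ tril_ones^T (hence transpose_b=True above)
tril_cols = np.tril(np.ones((2, 2)))
assert np.allclose(x @ tril_cols.T, np.cumsum(x, axis=1))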
Example #2
    def forward(self, X, A):
        messages = lbann.MatMul(X, self.W2, name=self.name + '_w2_mult')
        messages = lbann.MatMul(A, messages, name=self.name + '_adj_mult')

        ident = lbann.MatMul(X, self.W1, name=self.name + '_w1_mult')

        out = lbann.Sum(ident, messages, name=self.name + '_sum_id')

        return out
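In dense matrix terms this forward pass computes an identity path plus neighbor messages aggregated through the adjacency matrix, i.e. out = X @ W1 + A @ (X @ W2). A minimal NumPy sketch of the same arithmetic (shapes and names are illustrative, not part of the LBANN API):

import numpy as np

num_nodes, in_ch, out_ch = 4, 3, 5
X = np.random.rand(num_nodes, in_ch)                                # node features
A = np.random.randint(0, 2, (num_nodes, num_nodes)).astype(float)   # adjacency matrix
W1 = np.random.rand(in_ch, out_ch)
W2 = np.random.rand(in_ch, out_ch)

out = X @ W1 + A @ (X @ W2)    # identity path + aggregated neighbor messages
print(out.shape)               # (num_nodes, out_ch)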
Example #3
    def forward(self, motif_size, motif_log_embeddings):
        """Predict whether a motif is real.

        @todo Numerically accurate computation of both log(D) and
        log(1-D).

        """

        # D = 1 - exp(-sum_j(prod_i(d_ij)))
        # log(1-D) = -sum_j(exp(sum_i(log(d_ij))))
        x = lbann.MatMul(
            lbann.Constant(value=1, num_neurons=str_list([1, motif_size])),
            motif_log_embeddings,
        )
        x = lbann.Exp(x)
        x = lbann.Reduction(x, mode='sum')
        x = lbann.Negative(x)
        log_not_prob = x

        # Convert log-probability to linear space
        # Note: D=-expm1(x) is accurate when D~0. When D~1, prefer
        # 1-D=exp(x).
        prob = lbann.Negative(lbann.Expm1(log_not_prob))

        return prob, log_not_prob
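The identity in the comments can be checked numerically: log(1-D) = -sum_j(exp(sum_i(log(d_ij)))) and D = -expm1(log(1-D)), which is accurate when D is near 0. A small NumPy sketch with assumed d_ij values:

import numpy as np

d = np.array([[0.9, 0.8],    # illustrative d_ij values, one row per j
              [0.7, 0.6]])
log_not_prob = -np.sum(np.exp(np.sum(np.log(d), axis=1)))    # log(1-D)
prob = -np.expm1(log_not_prob)                               # D
assert np.isclose(prob, 1 - np.exp(-np.sum(np.prod(d, axis=1))))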
Example #4
def mean_squared_error(
    data_dim,
    sequence_length,
    source_sequence,
    target_sequence,
    scale_decay=0.8,
):

    # Compute inner product between source and target vectors
    # Note: Inner products are computed for each (x,y) pair and a
    # weighted sum is computed. The scaling factors sum to 1 and decay
    # exponentially as x and y get further apart in the sequence.
    prods = lbann.MatMul(
        source_sequence,
        target_sequence,
        transpose_b=True,
    )
    scale_dims = (sequence_length, sequence_length)
    scales = np.zeros(scale_dims)
    for i in range(sequence_length):
        for j in range(sequence_length):
            if i != j:
                scales[i, j] = ((1 - scale_decay) / (2 * scale_decay) *
                                scale_decay**np.abs(j - i))
    scales = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(np.nditer(scales))),
        optimizer=lbann.NoOptimizer(),
    )
    scales = lbann.WeightsLayer(dims=utils.str_list(scale_dims),
                                weights=scales)
    prods = lbann.MatMul(
        lbann.Reshape(prods, dims='1 -1'),
        lbann.Reshape(scales, dims='1 -1'),
        transpose_b=True,
    )
    prods = lbann.Reshape(prods, dims='1')

    # MSE(x,y) = ( norm(x)^2 + norm(y)^2 - 2*prod(x,y) ) / dim(x)
    scale = 1 / (data_dim * sequence_length)
    return lbann.WeightedSum(lbann.L2Norm2(source_sequence),
                             lbann.L2Norm2(target_sequence),
                             prods,
                             scaling_factors=utils.str_list(
                                 [scale, scale, -2 * scale]))
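The closing comment uses the standard expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, which is why the weighted sum combines the two squared norms and the inner products with factors (scale, scale, -2*scale). A quick NumPy check of that identity (illustrative vectors only):

import numpy as np

x = np.random.rand(8)
y = np.random.rand(8)
mse = np.mean((x - y) ** 2)
expansion = (np.dot(x, x) + np.dot(y, y) - 2 * np.dot(x, y)) / x.size
assert np.isclose(mse, expansion)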
Example #5
    def forward(self, X, A):
        """Call the GatedGraphConv
        Args:
            X (GraphVertexData): LBANN Data object, which is a collection of Layers. Each Layer is of
                                 the shape (1,input_channels) 
            A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)
        Returns: 
            GraphVertexData: The output after the Gated Graph kernel.
                        The output can be passed into another Graph Conv layer directly

        """

        input_features = X.size(1)
        num_nodes = X.size(0)

        if (input_features < self.output_channels):
            for i in range(num_nodes):
                num_zeros = self.output_channels - input_features
                zeros = lbann.Constant(value=0,
                                       num_neurons=str_list([1, num_zeros]),
                                       name=self.name + '_zero_' + str(i))
                X[i] = lbann.Concatenation(X[i], zeros, axis=1)
        elif (input_features > self.output_channels):
            raise ValueError(
                'The feature size of the nodes {} cannot be greater than the output dimension {}'
                .format(input_features, self.output_channels))

        X.update_num_features(self.output_channels)

        for layer in range(self.num_layers):
            ##
            X_mat = X.get_mat()
            messages = lbann.MatMul(X_mat, self.weights[layer])
            aggregate = lbann.MatMul(A, messages)

            M = GraphVertexData.matrix_to_graph(aggregate, num_nodes,
                                                self.output_channels)

            for i in range(num_nodes):
                X[i] = lbann.Reshape(X[i], dims=str(self.output_channels))
                X[i] = lbann.Reshape(self.rnn(M[i], X[i])[1],
                                     dims=str_list([1, self.output_channels]))

        return X
Example #6
    def forward(self, X, A):
        """Apply Graph Conv Layer to X and use A for message passing

        Args:
            X (GraphVertexData): LBANN Data object, which is a collection of Layers. Each Layer is of
                                 the shape (1,input_channels) 

            A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)

        Returns: 
            
            GraphVertexData: The output after convolution. The output can be passed into another Graph Conv layer
                          directly
        """

        # Accumulate Messages from Neighboring Nodes
        out = X.get_mat()
        out = lbann.MatMul(out,
                           self.weights1,
                           name=self.name + "_Graph_MATMUL")
        message = lbann.MatMul(A, out, name=self.name + "_Graph_Message")
        message = GraphVertexData.matrix_to_graph(message, X.shape[0],
                                                  self.output_channels)

        # Assume X is a GraphVertexData object

        for node_feature in range(X.shape[0]):
            X[node_feature] = lbann.MatMul(X[node_feature], self.weights2)

        for node_feature in range(X.shape[0]):
            if (self.bias):
                message[node_feature] = lbann.Sum(
                    message[node_feature],
                    self.bias,
                    name=self.name + '_message_bias_' + str(node_feature))
            X[node_feature] = lbann.Sum(X[node_feature], message[node_feature])

        if self.activation:
            for node_feature in range(X.shape[0]):
                X[node_feature] = self.activation(X[node_feature])

        X.update_num_features(self.output_channels)
        return X
Example #7
def negative_samples_loss(embeddings, negative_samples_embeddings):
    scores = lbann.MatMul(
        embeddings,
        negative_samples_embeddings,
        transpose_b=True,
    )
    scores = lbann.WeightedSum(scores, scaling_factors='-1')
    scores = lbann.LogSigmoid(scores)
    loss = lbann.Reduction(scores, mode='average')
    loss = lbann.WeightedSum(loss, scaling_factors='-1')
    return loss
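Written out, this objective is -mean(log(sigmoid(-s))) over every entry of the score matrix, penalizing high similarity between embeddings and their negative samples. A NumPy sketch of the same computation (illustrative shapes):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

embeddings = np.random.rand(4, 16)
negative_samples = np.random.rand(10, 16)
scores = embeddings @ negative_samples.T     # (4, 10)
loss = -np.mean(np.log(sigmoid(-scores)))    # negate, log-sigmoid, average, negate
print(loss)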
Example #8
def positive_samples_loss(
        sequence_length,
        encoder_embeddings,
        decoder_embeddings,
        scale_decay=0.8,
):

    # Compute similarity scores between encoder and decoder embeddings
    scores = lbann.MatMul(
        encoder_embeddings,
        decoder_embeddings,
        transpose_b=True,
    )
    scores = lbann.LogSigmoid(scores)

    # Scale similarity scores and add together
    # Note: The scaling factor decays exponentially as embeddings get
    # further apart in the sequence.
    # Note: The sum of all the scaling factors is approximately -1.
    scale_dims = (sequence_length,sequence_length)
    scales = np.zeros(scale_dims)
    for i in range(sequence_length):
        for j in range(sequence_length):
            if i != j:
                scales[i,j] = (
                    -(1-scale_decay)/(2*scale_decay*sequence_length)
                    * scale_decay**np.abs(j-i)
                )
    scales = lbann.Weights(
        initializer=lbann.ValueInitializer(values=utils.str_list(np.nditer(scales))),
        optimizer=lbann.NoOptimizer(),
    )
    scales = lbann.WeightsLayer(dims=utils.str_list(scale_dims), weights=scales)
    loss = lbann.MatMul(
        lbann.Reshape(scores, dims='1 -1'),
        lbann.Reshape(scales, dims='1 -1'),
        transpose_b=True,
    )
    loss = lbann.Reshape(loss, dims='1')
    return loss
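The note about the scaling factors can be checked directly: the factors decay geometrically with |i - j| and are normalized by 2 * scale_decay * sequence_length, so their sum approaches -1 as the sequence grows. A small NumPy check using the same formula as above:

import numpy as np

sequence_length, scale_decay = 128, 0.8
scales = np.zeros((sequence_length, sequence_length))
for i in range(sequence_length):
    for j in range(sequence_length):
        if i != j:
            scales[i, j] = (-(1 - scale_decay) / (2 * scale_decay * sequence_length)
                            * scale_decay ** abs(j - i))
print(scales.sum())    # approaches -1 for long sequences (boundary terms make it slightly smaller)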
Example #9
    def forward(self, node_features_mat, edge_features_tensor,
                node_features_tensor, adjacency_tensor):

        num_edges = self.num_nodes**2

        edge_ft_shape = str_list(
            [num_edges, self.input_channels, self.output_channels])
        node_ft_tensor_shape = str_list(
            [self.num_nodes, self.num_nodes, self.output_channels])
        node_ft_mat_shape = str_list([self.num_nodes, self.output_channels])

        transformed_edge_ft_tensor = None

        for layer in self.edge_nn:
            if transformed_edge_ft_tensor is not None:
                transformed_edge_ft_tensor = layer(transformed_edge_ft_tensor)
            else:
                transformed_edge_ft_tensor = layer(edge_features_tensor)

        transformed_edge_ft_tensor = lbann.Reshape(transformed_edge_ft_tensor,
                                                   dims=edge_ft_shape,
                                                   name=self.name +
                                                   "_edge_ft_reshape")

        new_node_features = lbann.MatMul(node_features_tensor,
                                         transformed_edge_ft_tensor)
        new_node_features = lbann.Reshape(new_node_features,
                                          dims=node_ft_tensor_shape)

        gathered_node_features = lbann.MatMul(adjacency_tensor,
                                              new_node_features)

        new_node_features = lbann.Reshape(gathered_node_features,
                                          dims=node_ft_mat_shape)
        updated_nodes = self.node_nn(node_features_mat)

        out = lbann.Sum(new_node_features, updated_nodes)

        return out
Example #10
    def forward(self, X, A):
        """Apply GCN

        Args:
            X (GraphVertexData): LBANN Data object, which is a collection of Layers. Each Layer is of
                                 the shape (1,input_channels) 
            A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes).
                                The adjacency matrix is assumed to be normalized in the
                                pre-processing step.
        Returns:
            GraphVertexData: The output after GCN. The output can be passed into another Graph Conv layer
                          directly
        """

        # Assume X is a lbann data object
        for i in range(X.shape[0]):
            X[i] = lbann.MatMul(X[i],
                                self.W,
                                name=self.name + '_message_' + str(i))
            if (self.bias):
                X[i] = lbann.Sum(X[i],
                                 self.bias,
                                 name=self.name + '_message_bias_' + str(i))

        # Pass Message to Node Features
        out = X.get_mat(self.output_channels)

        # A - adjacency matrix is assumed to be normalized such that
        # A = D^-0.5 A D^-0.5, following the convention in the GCN paper.
        out = lbann.MatMul(A, out, name=self.name + '_aggregate')

        if self.activation:
            out = self.activation(out)

        out = GraphVertexData.matrix_to_graph(out, X.shape[0],
                                              self.output_channels)

        return out
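The normalization referenced in the comment is done in pre-processing, typically the symmetric form from the GCN paper: A_hat = D^-1/2 (A + I) D^-1/2. A minimal NumPy sketch of that step (illustrative; whether self-loops are added is an assumption of this sketch, not something the layer enforces):

import numpy as np

A = np.array([[0., 1., 0.],
              [1., 0., 1.],
              [0., 1., 0.]])
A_hat = A + np.eye(3)                                     # add self-loops (Kipf & Welling convention)
D_inv_sqrt = np.diag(1.0 / np.sqrt(A_hat.sum(axis=1)))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt                  # D^-1/2 A_hat D^-1/2
print(A_norm)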
Example #11
    def message(self,
                node_features,
                neighbor_features,
                edge_features):
        """Update node features and edge features. The Message stage of the
           convolution.
        Args:
            node_features (Layer): A 2D layer of node features of
                                   shape (num_nodes, input_channels)
            neighbor_features (Layer): A 3D layer of node features of
                                       shape (num_edges, 1, input_channels)
            edge_features (Layer): A 2D layer of edge features of
                                   shape (num_edges, edge_features)
        Returns:
            (Layer, Layer): Returns the updated node features and the messages
            for each node.
        """

        ## These reshapes do not change the nn output but enable channelwise partitioning
        ## for distconv channelwiseFC natively
        
        node_features = lbann.Reshape(node_features, dims=str_list([self.num_nodes, 1, self.input_channels]))
        edge_features = lbann.Reshape(edge_features, dims=str_list([self.num_edges, 1, self.edge_input_channels]))

        updated_node_features = self.node_nn(node_features)

        edge_update = None
        for layer in self.edge_nn:

            if edge_update:
                edge_update = layer(edge_update)
            else:
                edge_update = layer(edge_features)

        edge_values = \
            lbann.Reshape(edge_update,
                          dims=str_list([self.num_edges,
                                         self.input_channels,
                                         self.output_channels]),
                          name=self.name+"_edge_mat_reshape")
        edge_values = \
            lbann.MatMul(neighbor_features, edge_values)
        return updated_node_features, edge_values
Example #12
    def forward(self, X, A, activation = lbann.Relu):
        """Apply GIN  Layer. 
        
        Args:
            X (GraphVertexData): LBANN Data object, which is a collection of Layers. Each Layer is of
                                 the shape (1,input_channels) 

            A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)

            activation (Layer): Activation layer for the node features. If None, then no activation is 
                                applied. (default: lbann.Relu) 
        Returns: 
            
            (GraphVertexData): The output after GIN. The output can be passed into another Graph Conv layer
                          directly
        """
        in_channel = X.shape[1]

        # Accumulate Messages from Neighboring Nodes
        out = X.get_mat()
        out = lbann.MatMul(A, out, name=self.name + "_GIN_MATMUL")
        message = GraphVertexData.matrix_to_graph(out, X.shape[0], in_channel)

        # Aggregate Messages into node features  
        eps = lbann.Constant(value=(1 + self.eps), num_neurons=str_list([1, in_channel]))
        for node_feature in range(X.shape[0]):
            eps_val = lbann.Multiply(eps, X[node_feature])
            X[node_feature] = lbann.Sum(message[node_feature], eps_val)
        
        # Transform with the sequence of linear layers
        for layer in self.nn:
            for node_feature in range(X.shape[0]):
                X[node_feature] = layer(X[node_feature])
        
        ## Apply activation 
        if activation:
            for node_feature in range(X.shape[0]):
                X[node_feature] = activation(X[node_feature])
        X.update_num_features(self.output_channels) 
        return X
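In matrix form the aggregation above is H <- (1 + eps) * H + A @ H, after which the shared MLP (self.nn) is applied node by node. A NumPy sketch of the aggregation step (MLP omitted; names are illustrative):

import numpy as np

num_nodes, in_ch, eps = 4, 3, 0.1
X = np.random.rand(num_nodes, in_ch)
A = np.random.randint(0, 2, (num_nodes, num_nodes)).astype(float)

aggregated = A @ X                      # summed neighbor features per node
updated = (1 + eps) * X + aggregated    # GIN update before the shared MLP
print(updated.shape)                    # (num_nodes, in_ch)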
Example #13
def PytorchMatmul(x, x_shape, y, y_shape, return_dims=False):
    if len(x_shape) != len(y_shape):
        raise RuntimeError(
            "Broadcasting not fully implemented, tensors must have same dimension"
        )
    need_reshape = (len(x_shape) > 3) and (len(y_shape) > 3)
    if need_reshape:
        if x_shape[:-2] != y_shape[:-2]:
            raise RuntimeError("The first n-2 dimensions must match")
        new_x_shape = (np.prod(x_shape[:-2]), ) + x_shape[-2:]
        x = lbann.Reshape(x, dims=str_list(new_x_shape))

        new_y_shape = (np.prod(y_shape[:-2]), ) + y_shape[-2:]
        y = lbann.Reshape(y, dims=str_list(new_y_shape))

    z = lbann.MatMul(x, y)

    z_shape = x_shape[:-1] + (y_shape[-1], )
    if need_reshape:
        z = lbann.Reshape(z, dims=str_list(z_shape))

    if return_dims:
        return z, z_shape
    return z
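For tensors with more than three dimensions, the helper collapses every leading dimension into a single batch dimension so a 3-D batched MatMul applies, then restores the full output shape. The shape bookkeeping can be sanity-checked in NumPy (illustrative shapes):

import numpy as np

x_shape, y_shape = (2, 3, 4, 5), (2, 3, 5, 6)
x = np.random.rand(*x_shape)
y = np.random.rand(*y_shape)

# collapse leading dims to one batch dim, multiply, then restore the shape
z = np.matmul(x.reshape((-1,) + x_shape[-2:]), y.reshape((-1,) + y_shape[-2:]))
z = z.reshape(x_shape[:-1] + (y_shape[-1],))
assert z.shape == (2, 3, 4, 6)
assert np.allclose(z, np.matmul(x, y))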
Example #14
def make_model(num_vertices=None,
               node_features=None,
               num_classes=None,
               kernel_type='GCN',
               callbacks=None,
               num_epochs=1):
    '''Construct a model DAG using one of the Graph Kernels

    Args:
        num_vertices (int): Number of vertices of each graph (default: None)
        node_features (int): Number of features per node (default: None)
        num_classes (int): Number of classes as targets (default: None)
        kernel_type (str): Graph Kernel to use in model. Expected one of
                            GCN, or Graph (default: GCN)
        callbacks (list): Callbacks for the model. If set to None the model description,
                          GPU usage, training_output, and timer is reported.
                          (default: None)
        num_epochs (int): Number of epochs to run (default: 1)
    Returns:
        (lbann.Model): A model object with the supplied callbacks, dataset
                               presets, and graph kernels.
    '''

    num_vertices = 100
    num_classes = 2
    node_features = 3

    assert num_vertices is not None
    assert num_classes is not None
    assert node_features is not None

    #----------------------------------
    # Reshape and Slice Input Tensor
    #----------------------------------

    input_ = lbann.Input(data_field='samples')

    # Input dimensions should be (num_vertices * node_features + num_vertices^2 + num_classes )
    # Input should have at least two children since the target is classification

    sample_dims = num_vertices * node_features + (num_vertices**
                                                  2) + num_classes
    graph_dims = num_vertices * node_features + (num_vertices**2)
    feature_matrix_size = num_vertices * node_features

    graph_input = lbann.Slice(input_,
                              axis=0,
                              slice_points=str_list([
                                  0, feature_matrix_size, graph_dims,
                                  sample_dims
                              ]),
                              name="Graph_Input")

    feature_matrix = lbann.Reshape(graph_input,
                                   dims=str_list([num_vertices,
                                                  node_features]),
                                   name="Node_features")

    adj_matrix = lbann.Reshape(graph_input,
                               dims=str_list([num_vertices, num_vertices]),
                               name="Adj_Mat")

    target = lbann.Identity(graph_input, name="Target")
    target = lbann.Reshape(target, dims=str(num_classes))

    #----------------------------------
    # Perform Graph Convolution
    #----------------------------------

    if kernel_type == 'GCN':
        x = DGCN_layer(feature_matrix, adj_matrix, node_features)
    elif kernel_type == 'Graph':
        x = DGraph_Layer(feature_matrix, adj_matrix, node_features)
    else:
        raise ValueError(
            'Invalid Graph kernel specifier "{}" received. Expected one of:\
                    GCN or Graph'.format(kernel_type))
    out_channel = 256
    #----------------------------------
    # Apply Reduction on Node Features
    #----------------------------------

    average_vector = lbann.Constant(value=1 / num_vertices,
                                    num_neurons=str_list([1, num_vertices]),
                                    name="Average_Vector")
    x = lbann.MatMul(average_vector, x, name="Node_Feature_Reduction"
                     )  # X is now a vector with output_channel dimensions

    x = lbann.Reshape(x, dims=str_list([out_channel]), name="Squeeze")
    x = lbann.FullyConnected(x, num_neurons=256, name="hidden_layer_1")
    x = lbann.Relu(x, name="hidden_layer_1_activation")
    x = lbann.FullyConnected(x,
                             num_neurons=num_classes,
                             name="Output_Fully_Connected")

    #----------------------------------
    # Loss Function and Accuracy
    #----------------------------------

    probs = lbann.Softmax(x, name="Softmax")
    loss = lbann.CrossEntropy(probs, target, name="Cross_Entropy_Loss")
    accuracy = lbann.CategoricalAccuracy(probs, target, name="Accuracy")

    layers = lbann.traverse_layer_graph(input_)
    if callbacks is None:
        print_model = lbann.CallbackPrintModelDescription(
        )  #Prints initial Model after Setup
        training_output = lbann.CallbackPrint(
            interval=1,
            print_global_stat_only=False)  #Prints training progress
        gpu_usage = lbann.CallbackGPUMemoryUsage()
        timer = lbann.CallbackTimer()
        callbacks = [print_model, training_output, gpu_usage, timer]
    else:
        if isinstance(callbacks, list):
            callbacks = callbacks
    metrics = [lbann.Metric(accuracy, name='accuracy', unit="%")]

    model = lbann.Model(num_epochs,
                        layers=layers,
                        objective_function=loss,
                        metrics=metrics,
                        callbacks=callbacks)
    return model
Example #15
)
decoder_embeddings = lbann.Embedding(
    input_slice,
    weights=decoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)
encoder_embeddings = lbann.Embedding(
    input_slice,
    weights=encoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)

# Skip-Gram with negative sampling
preds = lbann.MatMul(decoder_embeddings, encoder_embeddings, transpose_b=True)
preds_slice = lbann.Slice(
    preds,
    axis=0,
    slice_points=f'0 {num_negative_samples} {num_negative_samples+1}')
preds_negative = lbann.Identity(preds_slice)
preds_positive = lbann.Identity(preds_slice)
obj_positive = lbann.LogSigmoid(preds_positive)
obj_positive = lbann.Reduction(obj_positive, mode='sum')
obj_negative = lbann.WeightedSum(preds_negative, scaling_factors='-1')
obj_negative = lbann.LogSigmoid(obj_negative)
obj_negative = lbann.Reduction(obj_negative, mode='sum')
obj = [
    lbann.LayerTerm(obj_positive, scale=-1),
    lbann.LayerTerm(obj_negative, scale=-1/num_negative_samples),
]
Example #16
    def compute_loss(self, x, y):

        # y[:, :-1]
        y = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims - 1]),
        )
        y = lbann.Identity(y)

        # x[:, 1:]
        x = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x)

        # Convert indices in x to one-hot representation
        # Note: Ignored indices result in zero vectors
        ignore_mask = lbann.Equal(
            x,
            self.constant(self.label_to_ignore, hint_layer=x),
        )
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))
        x = lbann.Add(
            lbann.Multiply(keep_mask, x),
            lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
        )
        x = lbann.Slice(x,
                        slice_points=str_list(range(self.input_feature_dims)))
        x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
        x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
        x = [
            lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
            for xi in x
        ]
        x = lbann.Concatenation(x, axis=0)

        # recon_loss = F.cross_entropy(
        #     y[:, :-1].contiguous().view(-1, y.size(-1)),
        #     x[:, 1:].contiguous().view(-1),
        #     ignore_index=self.pad
        # )
        # Note: Ideally we'd shift y by y.max(-1) for numerical stability
        shifts = lbann.MatMul(
            lbann.Max(y, self.constant(0, hint_layer=y)),
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)
        z = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        z = lbann.Log(z)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            z,
        )
        recon_loss = lbann.MatMul(
            lbann.Reshape(y, dims=str_list([1, -1])),
            lbann.Reshape(x, dims=str_list([1, -1])),
            transpose_b=True,
        )
        recon_loss = lbann.Subtract(z, recon_loss)
        recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
        recon_loss = lbann.Divide(recon_loss, length)

        return recon_loss
Example #17
    # LBANN implementation
    lbann_x = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(np_x.flatten())), ),
        dims=str_list(np_x.shape),
    )
    lbann_y = FFTShift()(lbann_x, dims)
    lbann_scales = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(np_scales)),
            optimizer=lbann.NoOptimizer(),
        ),
        dims=str_list(np_scales.shape),
    )
    lbann_z = lbann.MatMul(lbann.Reshape(lbann_y, dims=str_list([1, -1])),
                           lbann.Reshape(lbann_scales, dims=str_list([-1, 1])))

    # Construct LBANN model with metric checking and gradient checking
    metric = lbann.Metric(lbann_z, name='metric')
    callbacks = [
        lbann.CallbackCheckMetric(
            metric=metric.name,
            lower_bound=np_z - tol,
            upper_bound=np_z + tol,
            error_on_failure=True,
            execution_modes='test',
        ),
        lbann.CallbackCheckGradients(error_on_failure=True),
    ]
    model = lbann.Model(
        epochs=0,
Example #18
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(
            queries_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_queries_slice',
        )
        keys_slice = lbann.Slice(
            keys_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_keys_slice',
        )
        values_slice = lbann.Slice(
            values_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_values_slice',
        )

        # Compute scaled dot-product attention for each head
        attentions = []
        for head in range(self.num_heads):
            head_name = f'{name}_head{head}'

            # Attention inputs
            q = lbann.Identity(queries_slice)
            k = lbann.Identity(keys_slice)
            v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )
            if mask:
                y = lbann.Add(y, mask, name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            attentions.append(lbann.MatMul(y, v, name=head_name))

        # Concatenate heads and apply fully-connected layer
        attentions = lbann.Concatenation(attentions,
                                         axis=1,
                                         name=f'{name}_heads_concat')
        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
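Each head above is standard scaled dot-product attention, softmax(Q K^T / sqrt(head_dim)) V, with the optional additive mask applied before the softmax. A compact NumPy reference for a single head (illustrative shapes):

import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

num_queries, num_keys, head_dim = 5, 7, 8
Q = np.random.rand(num_queries, head_dim)
K = np.random.rand(num_keys, head_dim)
V = np.random.rand(num_keys, head_dim)

scores = Q @ K.T / np.sqrt(head_dim)      # (num_queries, num_keys)
attention = softmax(scores, axis=-1) @ V  # (num_queries, head_dim)
print(attention.shape)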
Example #19
def make_model(num_vertices=None,
               node_features=None,
               num_classes=None,
               kernel_type='GCN',
               callbacks=None,
               num_epochs=1):
    '''Construct a model DAG using one of the Graph Kernels

    Args:
        num_vertices (int): Number of vertices of each graph (default: None)
        node_features (int): Number of features per node (default: None)
        num_classes (int): Number of classes as targets (default: None)

        kernel_type (str): Graph Kernel to use in model. Expected one of
                            GCN, GIN, Graph, or GatedGraph (default: GCN)
        callbacks (list): Callbacks for the model. If set to None the model description,
                          GPU usage, training_output, and timer is reported.
                          (default: None)
        num_epochs (int): Number of epochs to run (default: 1)
    Returns:
        (lbann.Model): A model object with the supplied callbacks, dataset
                               presets, and graph kernels.
    '''

    num_vertices = 100
    num_classes = 2
    node_feature_size = 3
    max_edges = 415

    #----------------------------------
    # Reshape and Slice Input Tensor
    #----------------------------------

    input_ = lbann.Input(data_field='samples')

    # Input dimensions should be (num_vertices * node_features + num_vertices^2 + num_classes )

    data = Graph_Data_Parser(input_, num_vertices, node_feature_size,
                             max_edges, num_classes)

    feature_matrix = data['node_features']
    source_indices = data['source_indices']
    target_indices = data['target_indices']
    target = data['target']

    #----------------------------------
    # Select Graph Convolution
    #----------------------------------

    output_channels = 16
    graph_kernel_op = None
    if kernel_type == 'GIN':
        graph_kernel_op = GINConvLayer
    elif kernel_type == 'GCN':
        graph_kernel_op = GCNConvLayer
    elif kernel_type == 'Graph':
        graph_kernel_op = GraphConvLayer
    elif kernel_type == 'GatedGraph':
        graph_kernel_op = GATConvLayer
    else:
        raise ValueError(
            'Invalid Graph kernel specifier "{}" received. Expected one of:\
                    GIN,GCN,Graph or GatedGraph'.format(kernel_type))
    #----------------------------------
    # Perform Graph Convolution
    #----------------------------------

    x = graph_kernel_op(feature_matrix, source_indices, target_indices,
                        num_vertices, max_edges, node_feature_size,
                        output_channels)
    #----------------------------------
    # Apply Reduction on Node Features
    #----------------------------------

    average_vector = lbann.Constant(value=1 / num_vertices,
                                    num_neurons=str_list([1, num_vertices]),
                                    name="Average_Vector")

    x = lbann.MatMul(average_vector, x, name="Node_Feature_Reduction")

    # X is now a vector with output_channel dimensions

    x = lbann.Reshape(x, dims=str_list([output_channels]), name="Squeeze")
    x = lbann.FullyConnected(x, num_neurons=64, name="hidden_layer_1")
    x = lbann.Relu(x, name="hidden_layer_1_activation")
    x = lbann.FullyConnected(x,
                             num_neurons=num_classes,
                             name="Output_Fully_Connected")

    #----------------------------------
    # Loss Function and Accuracy
    #----------------------------------

    probs = lbann.Softmax(x, name="Softmax")
    loss = lbann.CrossEntropy(probs, target, name="Cross_Entropy_Loss")
    accuracy = lbann.CategoricalAccuracy(probs, target, name="Accuracy")

    layers = lbann.traverse_layer_graph(input_)

    if callbacks is None:
        print_model = lbann.CallbackPrintModelDescription(
        )  #Prints initial Model after Setup
        training_output = lbann.CallbackPrint(
            interval=1,
            print_global_stat_only=False)  #Prints training progress
        gpu_usage = lbann.CallbackGPUMemoryUsage()
        timer = lbann.CallbackTimer()
        callbacks = [print_model, training_output, gpu_usage, timer]
    else:
        if isinstance(callbacks, list):
            callbacks = callbacks

    metrics = [lbann.Metric(accuracy, name='accuracy', unit="%")]

    model = lbann.Model(num_epochs,
                        layers=layers,
                        objective_function=loss,
                        metrics=metrics,
                        callbacks=callbacks)
    return model
Example #20
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if (ENABLE_SUBGRAPH):
            if (self.num_heads % BRANCHES != 0):
                raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(queries_fc,
                                    axis=1,
                                    slice_points=slice_points,
                                    name=f'{name}_queries_slice',
                                    parallel_strategy={
                                        'sub_branch_tag': 0,
                                        'enable_subgraph': ENABLE_SUBGRAPH
                                    })
        keys_slice = lbann.Slice(keys_fc,
                                 axis=1,
                                 slice_points=slice_points,
                                 name=f'{name}_keys_slice',
                                 parallel_strategy={
                                     'sub_branch_tag': 0,
                                     'enable_subgraph': ENABLE_SUBGRAPH
                                 })
        values_slice = lbann.Slice(values_fc,
                                   axis=1,
                                   slice_points=slice_points,
                                   name=f'{name}_values_slice',
                                   parallel_strategy={
                                       'sub_branch_tag': 0,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })

        # Compute scaled dot-product attention for each head
        attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs

            if (ENABLE_SUBGRAPH):
                if (head % int(self.num_heads / BRANCHES) == 0):
                    tag += 1

                q = lbann.Identity(queries_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                k = lbann.Identity(keys_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                v = lbann.Identity(values_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
            else:
                q = lbann.Identity(queries_slice)
                k = lbann.Identity(keys_slice)
                v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )

            if (ENABLE_SUBGRAPH):
                if mask != None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            else:
                if mask:
                    y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim

            attentions.append(lbann.MatMul(y, v, name=head_name))

            #Strong scaling

        # Concatenate heads and apply fully-connected layer
        if (ENABLE_SUBGRAPH):
            attentions = lbann.Concatenation(attentions,
                                             axis=1,
                                             name=f'{name}_heads_concat',
                                             parallel_strategy={
                                                 'sub_branch_tag': 0,
                                                 'enable_subgraph':
                                                 ENABLE_SUBGRAPH
                                             })
        else:
            attentions = lbann.Concatenation(
                attentions,
                axis=1,
                name=f'{name}_heads_concat',
            )

        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
Example #21
    def forward(self, X, A):
        out = lbann.MatMul(X, self.W, name=self.name + '_weight_mult')
        out = lbann.MatMul(A, out, name=self.name + '_adj_mult')
        return out
Example #22
    def compute_loss(self, x, y):

        # y[:, :-1]
        y = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims-1]),
        )
        y = lbann.Identity(y)

        # x[:, 1:]
        x = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x)

        # Figure out entries in x to ignore
        ignore_mask = lbann.Equal(
            x,
            self.constant(self.label_to_ignore, hint_layer=x),
        )
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))

        # Convert entries in x to indices in y
        # Note: Ignored entries correspond to an index of -1.
        offsets = [
            row*self.dictionary_size
            for row in range(self.input_feature_dims-1)
        ]
        offsets = lbann.Weights(
            initializer=lbann.ValueInitializer(values=str_list(offsets)),
            optimizer=lbann.NoOptimizer(),
        )
        offsets = lbann.WeightsLayer(
            dims=str_list([self.input_feature_dims-1]),
            weights=offsets,
        )
        y_inds = lbann.Add(x, offsets)
        y_inds = lbann.Add(
            lbann.Multiply(keep_mask, y_inds),
            lbann.Multiply(
                ignore_mask,
                self.constant(-1, hint_layer=y_inds),
            ),
        )

        # recon_loss = F.cross_entropy(
        #     y[:, :-1].contiguous().view(-1, y.size(-1)),
        #     x[:, 1:].contiguous().view(-1),
        #     ignore_index=self.pad
        # )

        # Shift y for numerical stability
        # Note: We'd prefer to shift by y.max(-1)
        shifts = lbann.MatMul(
            lbann.Max(y, self.constant(0, hint_layer=y)),
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)

        # Compute log of softmax denominator and sum
        z = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        z = lbann.Log(z)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            z,
        )
        z = lbann.Reshape(z, dims=str_list([1]))

        # Compute cross entropy
        recon_loss = lbann.Gather(
            lbann.Reshape(y, dims=str_list([-1])),
            y_inds,
        )
        recon_loss = lbann.Reduction(recon_loss, mode='sum')
        recon_loss = lbann.Subtract(z, recon_loss)
        recon_loss = lbann.Divide(recon_loss, length)

        return recon_loss
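The chain of layers above amounts to a masked cross-entropy: after the stability shift, the per-token loss is log(sum_k exp(y_k)) - y[target], summed over the non-ignored positions and divided by the number of kept tokens. A NumPy sketch of the same quantity (illustrative sizes; pad stands in for label_to_ignore):

import numpy as np

seq_len, dictionary_size, pad = 6, 10, 0
y = np.random.rand(seq_len, dictionary_size)          # unnormalized log-probabilities
x = np.random.randint(1, dictionary_size, seq_len)    # target indices
x[-2:] = pad                                          # pretend the tail is padding

keep = (x != pad).astype(float)
y = y - y.max(axis=1, keepdims=True)                  # shift for numerical stability
log_z = np.log(np.exp(y).sum(axis=1))                 # log softmax denominator per token
picked = y[np.arange(seq_len), x]                     # y[i, target_i]
loss = np.sum(keep * (log_z - picked)) / max(keep.sum(), 1)
print(loss)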
Example #23
def make_model(num_vertices=None,
               node_features=None,
               num_classes=None,
               dataset=None,
               kernel_type='GCN',
               callbacks=None,
               num_epochs=1):
    '''Construct a model DAG using one of the Graph Kernels

    Args:
        num_vertices (int): Number of vertices of each graph (default: None)
        node_features (int): Number of features per node (default: None)
        num_classes (int): Number of classes as targets (default: None)
        dataset (str): Preset data set to use. Either a dataset parameter has to be
                       supplied or all of num_vertices, node_features, and
                       num_classes have to be supplied. (default: None)
        kernel_type (str): Graph Kernel to use in model. Expected one of
                            GCN, GIN, Graph, or GatedGraph (default: GCN)
        callbacks (list): Callbacks for the model. If set to None the model description,
                          GPU usage, training_output, and timer is reported.
                          (default: None)
        num_epochs (int): Number of epochs to run (default: 1)
    Returns:
        (lbann.Model): A model object with the supplied callbacks, dataset
                               presets, and graph kernels.
    '''

    assert num_vertices != dataset  # Ensure at least one of the values is set

    if dataset is not None:
        assert num_vertices is None

        if dataset == 'MNIST':
            num_vertices = 75
            num_classes = 10
            node_features = 1

        elif dataset == 'PROTEINS':
            num_vertices = 100
            num_classes = 2
            node_features = 3
        else:
            raise Exception("Unkown Dataset")

    assert num_vertices is not None
    assert num_classes is not None
    assert node_features is not None

    #----------------------------------
    # Reshape and Slice Input Tensor
    #----------------------------------

    input_ = lbann.Input(target_mode='classification')

    # Input dimensions should be (num_vertices * node_features + num_vertices^2 + num_classes )
    # Input should have at least two children since the target is classification

    data = lbann_Graph_Data(input_, num_vertices, node_features, num_classes)

    feature_matrix = data.x
    adj_matrix = data.adj
    target = data.y

    #----------------------------------
    # Perform Graph Convolution
    #----------------------------------

    if kernel_type == 'GIN':
        x = GINConvLayer(feature_matrix, adj_matrix)
    elif kernel_type == 'GCN':
        x = GCNConvLayer(feature_matrix, adj_matrix)
    elif kernel_type == 'Graph':
        x = GraphConvLayer(feature_matrix, adj_matrix)
    elif kernel_type == 'GatedGraph':
        x = GATConvLayer(feature_matrix, adj_matrix)
    else:
        raise ValueError(
            'Invalid Graph kernel specifier "{}" received. Expected one of:\
                    GIN, GCN, Graph or GatedGraph'.format(kernel_type))

    out_channel = x.shape[1]
    #----------------------------------
    # Apply Reduction on Node Features
    #----------------------------------

    average_vector = lbann.Constant(value=1 / num_vertices,
                                    num_neurons=str_list([1, num_vertices]),
                                    name="Average_Vector")
    x = x.get_mat(out_channel)

    x = lbann.MatMul(average_vector, x, name="Node_Feature_Reduction")

    # X is now a vector with output_channel dimensions

    x = lbann.Reshape(x, dims=str_list([out_channel]), name="Squeeze")
    x = lbann.FullyConnected(x, num_neurons=64, name="hidden_layer_1")
    x = lbann.Relu(x, name="hidden_layer_1_activation")
    x = lbann.FullyConnected(x,
                             num_neurons=num_classes,
                             name="Output_Fully_Connected")

    #----------------------------------
    # Loss Function and Accuracy
    #----------------------------------

    probs = lbann.Softmax(x, name="Softmax")
    loss = lbann.CrossEntropy(probs, target, name="Cross_Entropy_Loss")
    accuracy = lbann.CategoricalAccuracy(probs, target, name="Accuracy")

    layers = lbann.traverse_layer_graph(input_)

    if callbacks is None:
        print_model = lbann.CallbackPrintModelDescription(
        )  #Prints initial Model after Setup
        training_output = lbann.CallbackPrint(
            interval=1,
            print_global_stat_only=False)  #Prints training progress
        gpu_usage = lbann.CallbackGPUMemoryUsage()
        timer = lbann.CallbackTimer()
        callbacks = [print_model, training_output, gpu_usage, timer]
    else:
        if isinstance(callbacks, list):
            callbacks = callbacks

    metrics = [lbann.Metric(accuracy, name='accuracy', unit="%")]

    model = lbann.Model(num_epochs,
                        layers=layers,
                        objective_function=loss,
                        metrics=metrics,
                        callbacks=callbacks)
    return model
Example #24
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if (ENABLE_SUBGRAPH):
            if (self.num_heads % BRANCHES != 0):
                raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = []
        keys_fc = []
        values_fc = []

        # Slice embedding vectors for each head
        slice_points = str_list(
            self.head_dim * i
            for i in range(int(self.num_heads / self.BRANCHES) + 1))

        #Queries strong scaling in CFC
        attentions = []
        for count, query in enumerate(queries):
            temp = lbann.ChannelwiseFullyConnected(
                query,
                weights=self.query_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_queries_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_queries_slice',
            )

            queries_fc.append(temp)

        #keys strong scaling in CFC

        attentions = []
        for count, key in enumerate(keys):
            temp = lbann.ChannelwiseFullyConnected(
                key,
                weights=self.key_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_keys_fc',
            )

            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):

            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_keys_slice',
            )

            keys_fc.append(temp)

        #Values strong scaling in CFC
        attentions = []

        for count, value in enumerate(values):
            temp = lbann.ChannelwiseFullyConnected(
                value,
                weights=self.value_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_values_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_values_slice',
            )
            values_fc.append(temp)

        queries_slice = []
        keys_slice = []
        values_slice = []

        for branch in range(self.BRANCHES):
            querie_slice = queries_fc[branch]
            key_slice = keys_fc[branch]
            value_slice = values_fc[branch]

            for head in range(int(self.num_heads / self.BRANCHES)):
                queries_slice.append(lbann.Identity(querie_slice))
                keys_slice.append(lbann.Identity(key_slice))
                values_slice.append(lbann.Identity(value_slice))

        # Compute scaled dot-product attention for each head
        attentions = []

        #variable to combine heads locally in sub-grids
        temp_attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs
            if (head % int(self.num_heads / BRANCHES) == 0):
                temp_attentions.append([])
                tag += 1

            q = lbann.Identity(queries_slice[head])
            k = lbann.Identity(keys_slice[head])
            v = lbann.Identity(values_slice[head])

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )

            if (ENABLE_SUBGRAPH):
                if mask != None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            else:
                if mask:
                    y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            y = lbann.MatMul(y, v, name=head_name)
            # attentions.append(lbann.MatMul(y, v, name=head_name))

            temp_attentions[-1].append(y)

        for count, temp_attention in enumerate(temp_attentions):

            if (self.BRANCHES == self.num_heads):
                # No need to concat the heads at subgrid level
                # if number of subgrids is equal to number of heads
                attention_single_subgrid = temp_attentions[count][0]
            else:
                attention_single_subgrid = lbann.Concatenation(
                    temp_attention,
                    axis=1,
                    name=f'{name}_subgrid_heads_concat{count}',
                    parallel_strategy={
                        'sub_branch_tag': 0,
                        'enable_subgraph': False
                    })

            attention_single_subgrid = lbann.ChannelwiseFullyConnected(
                attention_single_subgrid,
                weights=self.output_weights[count],
                output_channel_dims=[self.embed_dim],
                name=f'{name}_cfc_{count}',
            )

            attentions.append(attention_single_subgrid)

        #Strong scaling

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        return attentions