Example 1
    def forward(self, x, dims):
        """Apply fftshift.

        Args:
            x (lbann.Layer): Input tensor
            dims (tuple of int): Dimensions of x (dim 0 corresponds to
                channel)

        Returns:
            Layer: Output tensor

        """

        # Get gather indices by applying fftshift to tensor filled with indices
        # Note: Independent fftshift for each channel (dim 0)
        spatial_size = np.prod(dims[1:])
        spatial_inds = np.arange(spatial_size).reshape(dims[1:])
        spatial_inds = np.fft.fftshift(spatial_inds)
        channel_offsets = np.arange(0, dims[0] * spatial_size, spatial_size)
        channel_offsets = channel_offsets.reshape([-1] +
                                                  [1] * spatial_inds.ndim)
        inds = np.expand_dims(spatial_inds, 0) + channel_offsets

        # Construct LBANN layer graph
        size = np.prod(dims)
        x = lbann.Reshape(x, dims=str_list([size]))
        inds = lbann.WeightsLayer(
            weights=lbann.Weights(
                lbann.ValueInitializer(values=str_list(inds.flatten())),
                optimizer=lbann.NoOptimizer(),
            ),
            dims=str_list([size]),
        )
        y = lbann.Gather(x, inds)
        return lbann.Reshape(y, dims=str_list(dims))
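
The layer graph above is just a gather with precomputed indices. The index trick can be sanity-checked in plain NumPy, independent of LBANN (a minimal sketch; the shapes are made up):

import numpy as np

dims = (2, 4, 4)                                 # (channel, height, width)
x = np.random.rand(*dims)

# Same construction as forward(): fftshift an index tensor per channel
spatial_size = np.prod(dims[1:])
spatial_inds = np.fft.fftshift(np.arange(spatial_size).reshape(dims[1:]))
channel_offsets = np.arange(0, dims[0] * spatial_size, spatial_size)
channel_offsets = channel_offsets.reshape([-1] + [1] * spatial_inds.ndim)
inds = np.expand_dims(spatial_inds, 0) + channel_offsets

# "Gather" from the flattened tensor, as lbann.Gather does
y = x.flatten()[inds]
assert np.array_equal(y, np.fft.fftshift(x, axes=(1, 2)))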
Example 2
def Permute(x, dims, axes=None, name="", return_dims=False):
    global _permute_cache
    key = (dims, axes)
    size = np.prod(dims)
    if key not in _permute_cache:
        # Construct gather indices
        inds = np.arange(size).reshape(dims, order="C").transpose(axes)
        inds = lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=str_list(np.nditer(inds, order="C"))),
            optimizer=lbann.NoOptimizer(),
        )
        inds = lbann.WeightsLayer(dims=str_list([size]), weights=inds)
        _permute_cache[key] = inds

    # Apply transpose with gather
    inds = _permute_cache[key]
    if axes is None:
        new_dims = dims[::-1]
    else:
        new_dims = np.array(dims)[list(axes)]
    x = lbann.Reshape(x, dims=str_list([size]))
    y = lbann.Gather(x, inds)
    y = lbann.Reshape(y, dims=str_list(list(new_dims)), name=name)

    if return_dims:
        return y, tuple(new_dims)
    return y
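
As in the fftshift example, the permutation reduces to a gather over a transposed index tensor. A minimal NumPy sketch of the equivalence (sizes are illustrative):

import numpy as np

dims, axes = (2, 3, 4), (1, 2, 0)
x = np.random.rand(*dims)

# Same construction as Permute(): transpose an index tensor, then gather
inds = np.arange(x.size).reshape(dims, order="C").transpose(axes)
y = x.flatten()[inds]
assert np.array_equal(y, x.transpose(axes))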
Example 3
def GraphExpand(features, indices, name=None):
    """Places the features according the indices to an expanded matrix

       output[i] = features[indices[i]]

       Args:
            features (Layer) : 2D matrix with shape (N, F)
            indices (Layer): 1D matrix with shape (E)
       returnL (Layer) of shape (E,F)
    """
    GraphExpand.count += 1
    if name is None:
        name = f"graph_expand_{GraphExpand.count}"
    return lbann.Gather(features, indices, axis=0, name=name)
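
For reference, the gather along axis 0 here matches NumPy row indexing. A small sketch with made-up values:

import numpy as np

features = np.arange(12.0).reshape(4, 3)   # (N, F) node feature matrix
indices = np.array([0, 2, 2, 1, 3])        # (E,) one node index per edge

expanded = features[indices]               # output[i] = features[indices[i]]
assert expanded.shape == (5, 3)            # (E, F)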
Example 4
def graph_data_splitter(_input, NUM_NODES, NUM_EDGES, NUM_NODE_FEATURES,
                        NUM_EDGE_FEATURES, EMBEDDING_DIM, EDGE_EMBEDDING_DIM):
    """Helper function to split the input data into

			Args:
				NUM_NODES (int): The number of nodes in the largest graph in the dataset (51 for LSC-PPQM4M)
		      	NUM_EDGES (int): The number of edges in the largest graph in the dataset (118 for LSC-PPQM4M)
		      	NUM_NODE_FEATURES (int): The dimensionality of the input node features vector (9 for LSC-PPQM4M)
		      	NUM_EDGE_FEATURES (int): The dimensionality of the input edge feature vectors (3 for LSC-PPQM4M)
		      	EMBEDDING_DIM (int): The embedding dimensionality of the node feature vector

		      	EDGE_EMBEDDING_DIM (int): The embedding dimensionality of the edge feature vector
			Returns:
				(Layer, Layer, Layer, Layer, Layer): Returns 5 Layers. The embedded node feature matrix, the
													 neighbord nodes feature tensor, the embedded edge feature matrix,
													 the source node index vector, and the label
		"""
    split_indices = []

    start_index = 0
    split_indices.append(start_index)

    # One section per node feature column, then one per edge feature column
    node_features = [NUM_NODES] * NUM_NODE_FEATURES
    split_indices.extend(node_features)

    edge_features = [NUM_EDGES] * NUM_EDGE_FEATURES
    split_indices.extend(edge_features)

    edge_indices_sources = NUM_EDGES
    split_indices.append(edge_indices_sources)

    edge_indices_targets = NUM_EDGES
    split_indices.append(edge_indices_targets)

    target = 1
    split_indices.append(target)

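    # Prefix sum: turn the per-section lengths into cumulative slice points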
    for i in range(1, len(split_indices)):
        split_indices[i] = split_indices[i] + split_indices[i - 1]

    graph_input = lbann.Slice(_input,
                              axis=0,
                              slice_points=str_list(split_indices))

    neighbor_feature_dims = str_list([NUM_EDGES, 1, EMBEDDING_DIM])

    node_feature_columns = [
        lbann.Reshape(lbann.Identity(graph_input),
                      dims=str_list([NUM_NODES]),
                      name="node_ft_{}_col".format(x))
        for x in range(NUM_NODE_FEATURES)
    ]

    edge_feature_columns = [
        lbann.Reshape(lbann.Identity(graph_input),
                      dims=str_list([NUM_EDGES]),
                      name="edge_ft_{}_col".format(x))
        for x in range(NUM_EDGE_FEATURES)
    ]

    source_nodes = lbann.Reshape(lbann.Identity(graph_input),
                                 dims=str_list([NUM_EDGES]),
                                 name="source_nodes")
    target_nodes = lbann.Reshape(lbann.Identity(graph_input),
                                 dims=str_list([NUM_EDGES]),
                                 name="target_nodes")
    label = lbann.Reshape(lbann.Identity(graph_input),
                          dims=str_list([1]),
                          name="Graph_Label")

    embedded_node_features = AtomEncoder(node_feature_columns, EMBEDDING_DIM)

    embedded_edge_features = BondEncoder(edge_feature_columns,
                                         EDGE_EMBEDDING_DIM)

    neighbor_features = lbann.Gather(embedded_node_features,
                                     target_nodes,
                                     axis=0)
    neighbor_feature_mat = lbann.Reshape(neighbor_features,
                                         dims=neighbor_feature_dims)
    return (embedded_node_features, neighbor_feature_mat,
            embedded_edge_features, source_nodes, label)
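
The slice-point construction above is a prefix sum over per-section lengths. A NumPy sketch with made-up sizes (NUM_NODES=3, NUM_EDGES=4, NUM_NODE_FEATURES=2, NUM_EDGE_FEATURES=1):

import numpy as np

# node feature cols, edge feature cols, sources, targets, label
lengths = [3, 3] + [4] + [4, 4] + [1]
slice_points = [0] + np.cumsum(lengths).tolist()
print(slice_points)   # [0, 3, 6, 10, 14, 18, 19]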
Example 5
    def compute_loss(self, x, y):

        # y[:, :-1]
        y = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims-1]),
        )
        y = lbann.Identity(y)

        # x[:, 1:]
        x = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x)

        # Figure out entries in x to ignore
        ignore_mask = lbann.Equal(
            x,
            self.constant(self.label_to_ignore, hint_layer=x),
        )
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))

        # Convert entries in x to indices in y
        # Note: Ignored entries correspond to an index of -1.
        offsets = [
            row*self.dictionary_size
            for row in range(self.input_feature_dims-1)
        ]
        offsets = lbann.Weights(
            initializer=lbann.ValueInitializer(values=str_list(offsets)),
            optimizer=lbann.NoOptimizer(),
        )
        offsets = lbann.WeightsLayer(
            dims=str_list([self.input_feature_dims-1]),
            weights=offsets,
        )
        y_inds = lbann.Add(x, offsets)
        y_inds = lbann.Add(
            lbann.Multiply(keep_mask, y_inds),
            lbann.Multiply(
                ignore_mask,
                self.constant(-1, hint_layer=y_inds),
            ),
        )

        # recon_loss = F.cross_entropy(
        #     y[:, :-1].contiguous().view(-1, y.size(-1)),
        #     x[:, 1:].contiguous().view(-1),
        #     ignore_index=self.pad
        # )

        # Shift y for numerical stability
        # Note: We'd prefer to shift by y.max(-1), but instead use a scaled
        # sum of the non-negative entries, which only needs a matrix product
        shifts = lbann.MatMul(
            lbann.Max(y, self.constant(0, hint_layer=y)),
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)

        # Compute log of softmax denominator and sum
        z = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        z = lbann.Log(z)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            z,
        )
        z = lbann.Reshape(z, dims=str_list([1]))

        # Compute cross entropy
        recon_loss = lbann.Gather(
            lbann.Reshape(y, dims=str_list([-1])),
            y_inds,
        )
        recon_loss = lbann.Reduction(recon_loss, mode='sum')
        recon_loss = lbann.Subtract(z, recon_loss)
        recon_loss = lbann.Divide(recon_loss, length)

        return recon_loss
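
Taken together, the graph computes a length-normalized cross entropy over the non-ignored entries. A NumPy sketch of the same computation (this sketch shifts by the row max, which the note above says would be preferable; sizes are made up):

import numpy as np

L, D, ignore = 4, 5, 0
y = np.random.rand(L, D)               # predictions for y[:, :-1]
x = np.array([2, 0, 4, 1])             # targets from x[:, 1:]; 0 is ignored

keep = (x != ignore).astype(float)     # keep_mask
length = max(keep.sum(), 1.0)          # lbann.Max(length, 1)

y = y - y.max(axis=-1, keepdims=True)  # shift for numerical stability
z = np.log(np.exp(y).sum(axis=-1))     # per-row log softmax denominator
picked = y[np.arange(L), x]            # the flattened-index gather

recon_loss = ((keep * z).sum() - (keep * picked).sum()) / length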