def forward(self, x, dims):
    """Apply fftshift.

    The shift is expressed as a gather: NumPy applies fftshift to a
    grid of flat indices while the graph is being constructed, and the
    resulting index table is stored in a weights layer that has no
    optimizer. Each channel (dim 0) is shifted independently.

    Args:
        x (lbann.Layer): Input tensor
        dims (tuple of int): Dimensions of x (dim 0 corresponds to
            channel)

    Returns:
        Layer: Output tensor

    """
    num_channels = dims[0]
    spatial_dims = dims[1:]
    channel_stride = np.prod(spatial_dims)

    # fftshift a grid that holds the flat index of every spatial
    # position; after the shift, each cell names the source entry that
    # must be gathered into that position.
    shifted_inds = np.fft.fftshift(
        np.arange(channel_stride).reshape(spatial_dims)
    )

    # Flat offset of the first entry of each channel, shaped so that it
    # broadcasts against the spatial index grid.
    offset_shape = [-1] + [1] * shifted_inds.ndim
    channel_starts = np.arange(
        0, num_channels * channel_stride, channel_stride
    ).reshape(offset_shape)
    gather_inds = np.expand_dims(shifted_inds, 0) + channel_starts

    # Construct LBANN layer graph: flatten, gather, restore shape
    total_size = np.prod(dims)
    flat_x = lbann.Reshape(x, dims=str_list([total_size]))
    inds_initializer = lbann.ValueInitializer(
        values=str_list(gather_inds.flatten())
    )
    inds_weights = lbann.Weights(
        inds_initializer,
        optimizer=lbann.NoOptimizer(),
    )
    inds_layer = lbann.WeightsLayer(
        weights=inds_weights,
        dims=str_list([total_size]),
    )
    shifted = lbann.Gather(flat_x, inds_layer)
    return lbann.Reshape(shifted, dims=str_list(dims))
def Permute(x, dims, axes=None, name="", return_dims=False):
    """Permute the dimensions of a tensor, implemented with a gather.

    The input is flattened, its entries are gathered in the order given
    by the transposed index grid, and the result is reshaped to the
    permuted dimensions. The weights layer holding the gather indices
    is built once per ``(dims, axes)`` pair and kept in the
    module-level ``_permute_cache``.

    Args:
        x (lbann.Layer): Input tensor.
        dims (sequence of int): Dimensions of ``x``.
        axes (sequence of int, optional): Permutation of the
            dimensions, with the same meaning as the argument of
            ``numpy.ndarray.transpose``. ``None`` reverses the order of
            the dimensions.
        name (str): Name given to the final reshape layer.
        return_dims (bool): If true, also return the permuted
            dimensions.

    Returns:
        lbann.Layer: The permuted tensor. If ``return_dims`` is true, a
        ``(layer, dims)`` pair where ``dims`` is a tuple holding the
        permuted dimensions.

    """
    # Normalize to tuples so that lists / arrays are usable as cache
    # keys (a list inside the key would raise TypeError: unhashable).
    dims = tuple(dims)
    if axes is not None:
        axes = tuple(axes)
    key = (dims, axes)
    size = np.prod(dims)

    if key not in _permute_cache:
        # Construct gather indices: entry i of the flattened output is
        # read from position inds[i] of the flattened input.
        inds = np.arange(size).reshape(dims, order="C").transpose(axes)
        inds = lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=str_list(np.nditer(inds, order="C")),
            ),
            optimizer=lbann.NoOptimizer(),
        )
        inds = lbann.WeightsLayer(dims=str_list([size]), weights=inds)
        _permute_cache[key] = inds

    # Apply transpose with gather
    inds = _permute_cache[key]
    # Identity test, not ``==``: comparing an array-like ``axes`` with
    # ``== None`` is elementwise and has no single truth value.
    if axes is None:
        new_dims = dims[::-1]
    else:
        # Plain indexing keeps the entries as the caller's own values
        # instead of converting them to NumPy scalars.
        new_dims = tuple(dims[axis] for axis in axes)
    x = lbann.Reshape(x, dims=str_list([size]))
    y = lbann.Gather(x, inds)
    y = lbann.Reshape(y, dims=str_list(list(new_dims)), name=name)

    if return_dims:
        return y, tuple(new_dims)
    return y
def GraphExpand(features, indices, name=None):
    """Expand the features into a matrix with one row per index.

    Row i of the output is the row of ``features`` selected by entry i
    of ``indices``:

        output[i] = features[indices[i]]

    The call counter ``GraphExpand.count`` is incremented on every
    call, including calls that pass an explicit name.

    Args:
        features (Layer): 2D matrix with shape (N, F)
        indices (Layer): 1D matrix with shape (E)
        name (str, optional): Name of the gather layer. Defaults to
            ``graph_expand_<count>``, using the updated call counter.

    Returns:
        (Layer): Matrix of shape (E, F)

    """
    GraphExpand.count += 1
    layer_name = (
        name if name is not None else f"graph_expand_{GraphExpand.count}"
    )
    return lbann.Gather(features, indices, axis=0, name=layer_name)
def graph_data_splitter(_input, NUM_NODES, NUM_EDGES, NUM_NODE_FEATURES,
                        NUM_EDGE_FEATURES, EMBEDDING_DIM,
                        EDGE_EMBEDDING_DIM):
    """Helper function to split the input data into its graph components.

    The input is sliced along axis 0 into consecutive segments, in this
    order: NUM_NODE_FEATURES node feature columns (NUM_NODES entries
    each), NUM_EDGE_FEATURES edge feature columns (NUM_EDGES entries
    each), the source node indices (NUM_EDGES), the target node indices
    (NUM_EDGES), and a single label value.

    Args:
        _input (Layer): The layer holding the flattened graph sample
        NUM_NODES (int): The number of nodes in the largest graph in
            the dataset (51 for LSC-PPQM4M)
        NUM_EDGES (int): The number of edges in the largest graph in
            the dataset (118 for LSC-PPQM4M)
        NUM_NODE_FEATURES (int): The dimensionality of the input node
            features vector (9 for LSC-PPQM4M)
        NUM_EDGE_FEATURES (int): The dimensionality of the input edge
            feature vectors (3 for LSC-PPQM4M)
        EMBEDDING_DIM (int): The embedding dimensionality of the node
            feature vector
        EDGE_EMBEDDING_DIM (int): The embedding dimensionality of the
            edge feature vector

    Returns:
        (Layer, Layer, Layer, Layer, Layer): Returns 5 Layers. The
        embedded node feature matrix, the neighbor nodes feature
        tensor, the embedded edge feature matrix, the source node
        index vector, and the label

    """
    # Length of every consecutive segment of the flattened sample
    segment_lengths = [NUM_NODES for _ in range(NUM_NODE_FEATURES)]
    segment_lengths += [NUM_EDGES for _ in range(NUM_EDGE_FEATURES)]
    segment_lengths += [NUM_EDGES, NUM_EDGES, 1]

    # Slice points are the running totals of the lengths, starting at 0
    split_indices = [0]
    running_total = 0
    for segment_length in segment_lengths:
        running_total = segment_length + running_total
        split_indices.append(running_total)

    graph_input = lbann.Slice(_input,
                              axis=0,
                              slice_points=str_list(split_indices))

    def next_segment(length, name):
        # NOTE(review): presumably each Identity attached to the slice
        # layer receives the next segment in creation order, so the
        # order of the calls below must match the slice points --
        # confirm against the LBANN Slice documentation.
        return lbann.Reshape(lbann.Identity(graph_input),
                             dims=str_list([length]),
                             name=name)

    node_feature_columns = [
        next_segment(NUM_NODES, "node_ft_{}_col".format(column))
        for column in range(NUM_NODE_FEATURES)
    ]
    edge_feature_columns = [
        next_segment(NUM_EDGES, "edge_ft_{}_col".format(column))
        for column in range(NUM_EDGE_FEATURES)
    ]
    source_nodes = next_segment(NUM_EDGES, "source_nodes")
    target_nodes = next_segment(NUM_EDGES, "target_nodes")
    label = next_segment(1, "Graph_Label")

    embedded_node_features = AtomEncoder(node_feature_columns,
                                         EMBEDDING_DIM)
    embedded_edge_features = BondEncoder(edge_feature_columns,
                                         EDGE_EMBEDDING_DIM)

    # Look up the embedded features of the target node of every edge
    neighbor_features = lbann.Gather(embedded_node_features,
                                     target_nodes,
                                     axis=0)
    neighbor_feature_mat = lbann.Reshape(
        neighbor_features,
        dims=str_list([NUM_EDGES, 1, EMBEDDING_DIM]),
    )
    return (
        embedded_node_features,
        neighbor_feature_mat,
        embedded_edge_features,
        source_nodes,
        label,
    )
def compute_loss(self, x, y):
    """Build the layers that compute the mean reconstruction loss.

    This is the layer-graph counterpart of

        recon_loss = F.cross_entropy(
            y[:, :-1].contiguous().view(-1, y.size(-1)),
            x[:, 1:].contiguous().view(-1),
            ignore_index=self.pad
        )

    Entries of the shifted ``x`` that equal ``self.label_to_ignore``
    are excluded, and the summed loss is divided by the number of
    remaining entries (at least 1).

    Args:
        x (lbann.Layer): Target labels. Assumes
            ``self.input_feature_dims`` entries, each an index into the
            dictionary -- TODO confirm against the caller.
        y (lbann.Layer): Predicted scores. Assumes shape
            ``(self.input_feature_dims, self.dictionary_size)`` -- TODO
            confirm against the caller.

    Returns:
        lbann.Layer: The reconstruction loss.

    """
    seq_len = self.input_feature_dims
    vocab_size = self.dictionary_size

    # y[:, :-1]
    scores = lbann.Identity(
        lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, seq_len - 1]),
        )
    )

    # x[:, 1:]
    targets = lbann.Identity(
        lbann.Slice(
            x,
            slice_points=str_list([1, seq_len]),
        )
    )

    # Figure out entries in x to ignore
    ignored_label = self.constant(self.label_to_ignore, hint_layer=targets)
    ignore_mask = lbann.Equal(targets, ignored_label)
    keep_mask = lbann.LogicalNot(ignore_mask)
    kept_count = lbann.Reduction(keep_mask, mode='sum')
    # Clamp to 1 so the final division is defined when all is ignored
    length = lbann.Max(kept_count, self.constant(1, [1]))

    # Convert entries in x to indices in y
    # Note: Ignored entries correspond to an index of -1.
    row_starts = [row * vocab_size for row in range(seq_len - 1)]
    row_start_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(values=str_list(row_starts)),
        optimizer=lbann.NoOptimizer(),
    )
    row_start_layer = lbann.WeightsLayer(
        dims=str_list([seq_len - 1]),
        weights=row_start_weights,
    )
    flat_inds = lbann.Add(targets, row_start_layer)
    kept_inds = lbann.Multiply(keep_mask, flat_inds)
    minus_one = self.constant(-1, hint_layer=flat_inds)
    ignored_inds = lbann.Multiply(ignore_mask, minus_one)
    y_inds = lbann.Add(kept_inds, ignored_inds)

    # Shift y for numerical stability
    # Note: We'd prefer to shift by y.max(-1)
    nonneg_scores = lbann.Max(
        scores,
        self.constant(0, hint_layer=scores),
    )
    shift_matrix = self.constant(
        1 / math.sqrt(vocab_size),
        [vocab_size, vocab_size],
    )
    shifts = lbann.MatMul(nonneg_scores, shift_matrix)
    scores = lbann.Subtract(scores, shifts)

    # Compute log of softmax denominator and sum over the kept rows
    exp_scores = lbann.Exp(scores)
    row_sums = lbann.MatMul(
        exp_scores,
        self.constant(1, [vocab_size, 1]),
    )
    log_denoms = lbann.Log(row_sums)
    keep_row = lbann.Reshape(keep_mask, dims=str_list([1, -1]))
    log_denom_sum = lbann.MatMul(keep_row, log_denoms)
    log_denom_sum = lbann.Reshape(log_denom_sum, dims=str_list([1]))

    # Compute cross entropy
    flat_scores = lbann.Reshape(scores, dims=str_list([-1]))
    target_scores = lbann.Gather(flat_scores, y_inds)
    target_score_sum = lbann.Reduction(target_scores, mode='sum')
    recon_loss = lbann.Subtract(log_denom_sum, target_score_sum)
    return lbann.Divide(recon_loss, length)