Code Example #1
    def forward(self, G, type_0_features, type_1_features):
        # Compute equivariant weight basis from relative positions
        basis, r = get_basis_and_r(G, self.num_degrees - 1)

        h = {'0': type_0_features, '1': type_1_features}
        for layer in self.block0:
            h = layer(h, G=G, r=r, basis=basis)

        return h
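
A minimal usage sketch for the forward pass above, assuming the common SE(3)-Transformer convention that degree-0 (scalar) features have shape [num_nodes, channels, 1] and degree-1 (vector) features have shape [num_nodes, channels, 3]. `model`, the channel counts, and the pre-built DGL graph `G` (whose edges already carry the relative positions mentioned in the comment) are assumptions, not part of the original example.

# Hedged usage sketch; `model` and `G` are assumed to exist as described above.
import torch

num_nodes = G.num_nodes()
type_0 = torch.randn(num_nodes, 16, 1)   # degree-0 (scalar) features; 16 channels is an assumed value
type_1 = torch.randn(num_nodes, 8, 3)    # degree-1 (vector) features; 8 channels is an assumed value

h = model(G, type_0, type_1)             # dict keyed by degree, e.g. h['0'], h['1']
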
Code Example #2
File: nbody_models.py  Project: jam-ing/Transformers
    def forward(self, G):
        # Compute equivariant weight basis from relative positions
        basis, r = get_basis_and_r(G, self.num_degrees-1)
        h_enc = {'1': G.ndata['v']}
        for layer in self.Gblock:
            h_enc = layer(h_enc, G=G, r=r, basis=basis)

        return h_enc['1']
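
A minimal sketch of how an input graph for this N-body forward pass might be assembled, assuming that `get_basis_and_r` reads relative positions from `G.edata['d']` and that velocities enter as a single type-1 channel per node; the graph layout, the sign convention for the relative positions, and `model` are all assumptions.

# Hedged sketch; every name not appearing in the example above is an assumption.
import dgl
import torch

n = 5
src, dst = zip(*[(i, j) for i in range(n) for j in range(n) if i != j])
G = dgl.graph((torch.tensor(src), torch.tensor(dst)))   # fully connected, no self-loops

x = torch.randn(n, 3)                                   # particle positions
G.ndata['x'] = x
src_idx, dst_idx = G.edges()
G.edata['d'] = x[dst_idx] - x[src_idx]                  # relative positions (sign convention assumed)
G.ndata['v'] = torch.randn(n, 1, 3)                     # one type-1 velocity channel per node

pred = model(G)                                         # equivariant type-1 output, shape [n, channels, 3]
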
Code Example #3
def collate(samples, num_degrees):
    graphs, y = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    # Compute equivariant weight basis from relative positions
    basis, r = get_basis_and_r(batched_graph, num_degrees)
    batched_graph.edata['feat'] = torch.cat([batched_graph.edata['w'], r], -1)

    return batched_graph, torch.tensor(y), basis
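
The collate function above precomputes the equivariant basis once per batch so it can be shared by every layer of the subsequent forward pass. A minimal sketch of plugging it into a PyTorch DataLoader; `dataset`, the batch size, and the `num_degrees` value are assumptions.

# Hedged sketch; `dataset` and the numeric values are assumptions.
from functools import partial
from torch.utils.data import DataLoader

loader = DataLoader(dataset,
                    batch_size=32,
                    shuffle=True,
                    collate_fn=partial(collate, num_degrees=4))

for batched_graph, y, basis in loader:
    pass  # the forward pass can reuse the precomputed basis here
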
Code Example #4
    def forward(self, G, save_steps=False):
        G_steps = []

        # We need to keep a copy of the initial position of all nodes in order to calculate the overall change in
        # position between the input and the output. This information is used for logging purposes.
        original_x = torch.clone(G.ndata['x'])

        # 'features' holds the node inputs to every layer. This dataset provides edge features but no
        # node features, so for the first layer we use dummy node features: a single [1] per node.
        features = {'0': G.ndata['ones']}

        for j, inner_block in enumerate(self.blocks):
            if save_steps:
                G_steps.append(copy_dgl_graph(G))  # for logging only

            num_nodes = _get_number_of_nodes_per_batch_element(G)
            if self.k_neighbors is not None and self.k_neighbors < num_nodes:
                network_input_graph = copy_without_weak_connections(
                    G, K=self.k_neighbors)
            else:
                network_input_graph = G

            # Compute equivariant weight basis for the current graph
            basis, r = get_basis_and_r(
                network_input_graph,
                self.num_degrees - 1,
                compute_gradients=self.compute_gradients)

            if torch.min(r) < 1e-5:
                warnings.warn(
                    "Minimum separation between nodes fell below 1e-5")

            for layer in inner_block:
                features = layer(features,
                                 G=network_input_graph,
                                 r=r,
                                 basis=basis)

            # We arbitrarily use the first type-1 feature for the position updates.
            position_updates = features['1'][:, 0:1, :]
            G.ndata['x'] = G.ndata['x'] + position_updates

            update_relative_positions(G)
            update_potential_values(G)

            # Track the updates at each iteration and store on the graph for logging.
            G.ndata[f'update_norm_{j}'] = torch.sqrt(
                torch.sum(position_updates**2, -1, keepdim=True))

        # Calculate the overall update and store on the graph for logging.
        overall_update = G.ndata['x'] - original_x
        G.ndata['update_norm'] = torch.sqrt(
            torch.sum(overall_update**2, -1, keepdim=True))

        if save_steps:
            G_steps.append(copy_dgl_graph(G))

        return G, G_steps
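
A short hedged sketch of how the per-block and overall update norms stored on the graph above might be read back after a call; `model` is assumed to be an instance of the class this method belongs to, and `G` an appropriately prepared input graph.

# Hedged sketch; `model` and `G` are assumptions.
G_out, G_steps = model(G, save_steps=True)

per_block = [G_out.ndata[f'update_norm_{j}'] for j in range(len(model.blocks))]  # each [num_nodes, 1]
overall = G_out.ndata['update_norm']                                             # [num_nodes, 1]
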
Code Example #5
    def gcn_graph(self, G):
        # Compute equivariant weight basis from relative positions
        basis, r = get_basis_and_r(G, self.num_degrees - 1)

        # encoder (equivariant layers)
        h = {'0': G.ndata['f']}
        for layer in self.Gblock:
            h = layer(h, G=G, r=r, basis=basis)

        return h
Code Example #6
File: models_xyz.py  Project: vwslz/se3_transformer
    def forward(self, G):
        # Compute equivariant weight basis from relative positions
        basis, r = get_basis_and_r(G, self.num_degrees - 1)

        # encoder (equivariant layers)
        h = {
            '0': G.ndata['f'],
            '1': torch.stack([G.ndata['x_n'], G.ndata['x_c']], dim=1)
        }
        for layer in self.Gblock:
            h = layer(h, G=G, r=r, basis=basis)

        for layer in self.FCblock:
            h = layer(h, G)

        return h
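
In the type-1 input above, two coordinate-like node attributes are stacked along a new channel axis. A small shape sketch, assuming both x_n and x_c are [num_nodes, 3] tensors; the sample values are assumptions.

# Hedged shape sketch; x_n and x_c are assumed to be [num_nodes, 3] coordinate tensors.
import torch

num_nodes = 4
x_n = torch.randn(num_nodes, 3)
x_c = torch.randn(num_nodes, 3)

type_1 = torch.stack([x_n, x_c], dim=1)   # -> [num_nodes, 2, 3]: two degree-1 channels per node
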
Code Example #7
    def forward(self, G):
        # Compute equivariant weight basis from relative positions
        basis, r = get_basis_and_r(G, self.num_degrees-1)

        # encoder (equivariant layers)
        h = {'0': G.ndata['feat']}
        for layer in self.block0:
            h = layer(h, G=G, r=r, basis=basis)

        for layer in self.block1:
            h = layer(h, G)

        for layer in self.block2:
            h = layer(h)

        return h