def forward(self, pos, batch):
        """Classify a point cloud and score its pooled feature with a discriminator.

        Three radius-graph convolution stages run with radii 0.2, 0.4 and 1;
        farthest-point sampling keeps 50% and then 25% of the points between
        stages.  The max-pooled per-graph feature feeds both the classifier
        head and the discriminator head.

        Returns:
            tuple: (classifier log-probabilities, discriminator log-probabilities).
        """
        # (neighbourhood radius, fps ratio applied after the stage; None = no sampling)
        stage_plan = ((0.2, 0.5), (0.4, 0.25), (1, None))

        h = None  # the first stage gets positions only, no input features
        for stage, (r, keep_ratio) in enumerate(stage_plan):
            graph = radius_graph(pos, r=r, batch=batch)
            h = F.relu(self.features[stage](h, pos, graph))
            if keep_ratio is not None:
                keep = fps(pos, batch, ratio=keep_ratio)
                h, pos, batch = h[keep], pos[keep], batch[keep]

        pooled = global_max_pool(h, batch)

        logits = F.relu(self.classifier[0](pooled))
        logits = F.relu(self.classifier[1](logits))
        logits = F.dropout(logits, p=0.5, training=self.training)
        logits = self.classifier[2](logits)

        disc = F.relu(self.discriminator[0](pooled))
        disc = F.dropout(disc, p=0.5, training=self.training)
        disc = self.discriminator[1](disc)
        return F.log_softmax(logits, dim=-1), F.log_softmax(disc, dim=-1)
    def forward(self, pos, batch):
        """Classify point clouds with three radius-graph convolutions.

        Radii are 0.2, 0.4 and 1; farthest-point sampling keeps 50% and then
        25% of the points between the convolutions.  Features are max-pooled
        per graph and passed through a three-layer MLP head.

        Returns:
            Per-graph class log-probabilities.
        """
        h = None  # first convolution sees positions only
        for conv, r, keep_ratio in (
            (self.conv1, 0.2, 0.5),
            (self.conv2, 0.4, 0.25),
            (self.conv3, 1, None),
        ):
            graph = radius_graph(pos, r=r, batch=batch)
            h = F.relu(conv(h, pos, graph))
            if keep_ratio is not None:
                keep = fps(pos, batch, ratio=keep_ratio)
                h, pos, batch = h[keep], pos[keep], batch[keep]

        h = global_max_pool(h, batch)
        h = F.relu(self.lin1(h))
        h = F.relu(self.lin2(h))
        h = F.dropout(h, p=0.5, training=self.training)
        return F.log_softmax(self.lin3(h), dim=-1)
    def forward(self, pos, batch):
        """Classify point clouds with pseudo-coordinate graph convolutions.

        Every node starts with a constant feature of 1.  Each of three stages
        (radii 0.2, 0.4, 1) builds a radius graph and per-edge pseudo
        coordinates: the relative offset divided by ``2 * radius``, shifted by
        0.5 and clamped to ``[0, 1]``.  Farthest-point sampling keeps 50% and
        then 25% of the points between stages.  Features are mean-pooled per
        graph and passed through a three-layer ELU MLP.

        Returns:
            Per-graph class log-probabilities.
        """
        def _graph_with_pseudo(points, r, graph_batch):
            # Radius graph plus pseudo-coordinates scaled into the unit cube.
            edges = radius_graph(points, r=r, batch=graph_batch)
            offset = points[edges[1]] - points[edges[0]]
            coords = (offset / (2 * r) + 0.5).clamp(min=0, max=1)
            return edges, coords

        h = pos.new_ones((pos.size(0), 1))
        for conv, r, keep_ratio in (
            (self.conv1, 0.2, 0.5),
            (self.conv2, 0.4, 0.25),
            (self.conv3, 1, None),
        ):
            edges, coords = _graph_with_pseudo(pos, r, batch)
            h = F.elu(conv(h, edges, coords))
            if keep_ratio is not None:
                keep = fps(pos, batch, ratio=keep_ratio)
                h, pos, batch = h[keep], pos[keep], batch[keep]

        h = global_mean_pool(h, batch)
        h = F.elu(self.lin1(h))
        h = F.elu(self.lin2(h))
        h = F.dropout(h, p=0.5, training=self.training)
        return F.log_softmax(self.lin3(h), dim=-1)
Exemple #4
0
    def get(self, idx):
        """Load sample ``idx`` as zero-padded dense adjacency and feature tensors.

        With ``self.dynamic_graph`` the stored sample is re-sampled on every
        call via ``self._sampling``.  Otherwise a pre-computed sample for
        ``self.epoch`` is loaded from ``self.processed_fix_data_root``.  In
        both modes edges are rebuilt with ``radius_graph``.  The dense
        adjacency is binarized and copied into a
        ``max_num_nodes x max_num_nodes`` zero matrix.

        Returns:
            dict with keys 'adj', 'feats', 'label', 'num_nodes', 'patch_idx'.

        Raises:
            NotImplementedError: if ``self.graph_sampler`` is not ``'knn'``.
        """
        adj_all = torch.zeros((self.max_num_nodes, self.max_num_nodes),
                              dtype=torch.float)
        if self.dynamic_graph:
            data = torch.load(osp.join(self.processed_root, self.idxlist[idx]))
            total_nodes = data.num_nodes
            dist_path = os.path.join(self.root, 'proto', 'distance',
                                     self.setting.dataset, self.idxlist[idx])
            choice, num_nodes = self._sampling(total_nodes,
                                               self.sampling_ratio,
                                               dist_path)
            # Apply the same node selection to every per-node tensor attribute.
            for key, item in data:
                if torch.is_tensor(item) and item.size(0) == total_nodes:
                    data[key] = item[choice]
        else:
            data = torch.load(
                osp.join(self.processed_fix_data_root, str(self.epoch),
                         self.idxlist[idx]))

        if self.graph_sampler != 'knn':
            raise NotImplementedError
        # Positional args: (pos, r, batch=None, loop=True, max_num_neighbors).
        edge_index = radius_graph(data.pos, self.max_edge_distance,
                                  None, True, self.max_neighbours)
        adj = sparse_to_dense(edge_index)
        # Binarize in both modes so static and dynamic samples agree.
        adj[adj > 0] = 1
        if not self.dynamic_graph:
            num_nodes = adj.shape[0]
        adj_all[:num_nodes, :num_nodes] = adj

        feature = data.x
        if self.feature_type == 'c':
            feature = feature[:, -2:]
        elif self.feature_type == 'a':
            feature = feature[:, :-2]
        num_feature = feature.shape[1]
        feature = (feature - self.mean) / self.std
        feature_all = torch.zeros((self.max_num_nodes, num_feature),
                                  dtype=torch.float)
        feature_all[:num_nodes] = feature
        return {
            'adj': adj_all,
            'feats': feature_all,
            'label': data.y,
            'num_nodes': num_nodes,
            'patch_idx': torch.tensor(idx),
        }
Exemple #5
0
    def forward(self, batch_data):
        """Encode a batch of structures into graph-level features.

        Builds a radius graph (``self.cutoff``) over ``batch_data.pos`` and
        derives distances, angles and torsions with ``xyztodat``.  It then
        initialises edge (``e``), node (``v``) and graph (``u``) features and
        refines them with the paired ``update_es``/``update_vs``/``update_us``
        blocks.

        Returns:
            The graph-level feature ``u`` after the last update block.
        """
        z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch
        if self.energy_and_force:
            pos.requires_grad_()

        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        n_nodes = z.size(0)
        geometry = xyztodat(pos, edge_index, n_nodes, use_torsion=True)
        dist, angle, torsion, i, j, idx_kj, idx_ji = geometry

        emb = self.emb(dist, angle, torsion, idx_kj)

        # Initial edge, node and graph features; the graph feature starts from
        # zeros shaped like scatter(v, batch, dim=0).
        e = self.init_e(z, emb, i, j)
        v = self.init_v(e, i, num_nodes=pos.size(0))
        u_seed = torch.zeros_like(scatter(v, batch, dim=0))
        u = self.init_u(u_seed, v, batch)

        blocks = zip(self.update_es, self.update_vs, self.update_us)
        for step_e, step_v, step_u in blocks:
            e = step_e(e, emb, idx_kj, idx_ji)
            v = step_v(e, i)
            u = step_u(u, v, batch)

        return u
Exemple #6
0
    def forward(self, pos, p, f=None, batch=None):
        """Refine per-node label probabilities with mean-field iterations.

        Args:
            pos: node positions used to build the radius graph (radius
                ``self.radius``, at most ``self.kernel_size`` neighbours).
            p: initial probabilities ``[N, L]``; ``-log(p)`` is the unary term
                and ``p`` also seeds the iteration.
            f: per-node features for the pairwise kernels.  Although the
                signature defaults to ``None``, it is required.
            batch: optional batch vector for the radius graph.

        Returns:
            q: probabilities ``[N, L]`` after ``self.steps`` iterations.

        Raises:
            ValueError: if ``f`` is ``None``.
        """
        if f is None:
            # Fail early with a clear message instead of an AttributeError
            # from ``None.unsqueeze`` after the graph has been built.
            raise ValueError(
                "forward() requires node features `f` for the pairwise kernels")

        num_nodes = pos.shape[0]
        col, row = radius_graph(pos,
                                r=self.radius,
                                batch=batch,
                                max_num_neighbors=self.kernel_size)

        unary = -torch.log(p)

        # Pairwise edge weights from K projected feature kernels.
        f = f.unsqueeze(0).repeat(self.num_kernels, 1, 1)
        f = torch.bmm(f, self.F)  # [K, N, H]
        f = f.permute((1, 0, 2))  # [N, K, H]
        diff = f[col] - f[row]  # [E, K, H]
        w = torch.exp(-torch.sum(diff ** 2, dim=-1))  # [E, K]
        w = torch.mm(w, self.W)  # [E, 1]

        q = p  # [N, L]
        for _ in range(self.steps):
            # Message passing: weighted sum of neighbour beliefs per node.
            q = scatter_add(q[col] * w, row, dim=0, dim_size=num_nodes)
            q = torch.mm(q, self.C)  # compatibility transform [N, L]
            q = torch.softmax(-unary - q, dim=-1)

        return q
    def forward(self, data, phase):
        """Run three graph-convolution layers on a radius graph and pool per graph.

        ``data.edge_index`` is ignored.  Edges are rebuilt with a fixed radius
        of 4 over ``data.pos`` (cast to float).  Dropout (p=0.8) before the
        second and third convolutions is active only when ``phase == 'train'``.

        Returns:
            dict with the graph-level output under key ``'A2'``.
        """
        batch = data.batch
        is_training = phase == 'train'

        edge_index = radius_graph(data.pos.float(), r=4, batch=batch)
        h = self.lin01(data.x)
        h = self.graph_norm_1(F.relu(self.conv1(h, edge_index)), batch)

        for conv, norm in ((self.conv2, self.graph_norm_2),
                           (self.conv3, self.graph_norm_3)):
            h = F.dropout(h, p=0.8, training=is_training)
            h = norm(F.relu(conv(h, edge_index)), batch)

        pooled = self.pool(h, batch)
        return {'A2': self.lin3(self.lin2(pooled))}
Exemple #8
0
    def forward(self, data):
        """Predict energy, and forces when ``self.regress_forces`` is set.

        Node features are rows of ``self.embedding`` indexed by
        ``atomic_numbers - 1``; the table is moved to the atoms' device if
        needed.  Edges come from a radius graph (``self.cutoff``).  Edge
        lengths are expanded by ``self.distance_expansion`` into
        ``data.edge_attr``.

        Returns:
            ``energy`` or ``(energy, forces)`` with
            ``forces = -d(energy)/d(pos)``.
        """
        atoms_device = data.atomic_numbers.device
        if self.embedding.device != atoms_device:
            self.embedding = self.embedding.to(atoms_device)
        # Shift by one so atomic number 1 maps to embedding row 0.
        data.x = self.embedding[data.atomic_numbers.long() - 1]

        data.edge_index = radius_graph(data.pos, r=self.cutoff, batch=data.batch)
        src, dst = data.edge_index
        pos = data.pos
        if self.regress_forces:
            pos = pos.requires_grad_(True)
        data.edge_weight = (pos[src] - pos[dst]).norm(dim=-1)
        data.edge_attr = self.distance_expansion(data.edge_weight)

        feats = self.conv_to_fc(self._convolve(data))
        if hasattr(self, "fcs"):
            feats = self.fcs(feats)
        energy = self.fc_out(feats)

        if not self.regress_forces:
            return energy
        (grad_pos,) = torch.autograd.grad(
            energy,
            pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )
        return energy, -1 * grad_pos
    def __call__(self, sample):
        """Attach Euclidean radius-graph edges for each hierarchy level of ``sample``.

        Edge-index attributes are matched by name: only ``edge_index`` in
        visualization mode, otherwise every name containing ``edge_index`` but
        not ``dilated``.  Both these names and the ``pos`` attributes are
        sorted, so level ``k`` pairs the k-th edge name with the k-th position
        name and uses ``self._radius[k]``.  The new edges either replace the
        existing attribute (``self._override``) or are stored under the
        ``euclidean_edge_index`` variant of its name.  Unless
        ``self._keep_pos``, all attributes whose name contains ``pos_`` are
        deleted afterwards.

        Returns:
            The same ``sample``, modified in place.
        """
        names = dir(sample)
        if self._visualization:
            edge_keys = sorted(n for n in names if n == 'edge_index')
        else:
            edge_keys = sorted(
                n for n in names
                if 'edge_index' in n and 'dilated' not in n)
        pos_keys = sorted(n for n in names if 'pos' in n)

        for level, edge_key in enumerate(edge_keys):
            edges = radius_graph(sample[pos_keys[level]],
                                 self._radius[level],
                                 max_num_neighbors=self._max_neigh)
            if self._override:
                # A single-conv architecture only needs one edge set.
                target = edge_key
            else:
                target = edge_key.replace('edge_index', 'euclidean_edge_index')
            sample[target] = edges

        if not self._keep_pos:
            # Positions of the later levels are no longer needed.
            for pos_key in sorted(n for n in dir(sample) if 'pos_' in n):
                delattr(sample, pos_key)

        return sample
Exemple #10
0
 def get(self, idx):
     """Load sample ``idx`` with a dense, binarized radius-graph adjacency.

     Only batch size 1 is supported: tensors are not padded to a common size.

     Returns:
         dict with keys 'adj', 'feats', 'label', 'num_nodes', 'patch_idx'.

     Raises:
         NotImplementedError: if ``self.graph_sampler`` is not ``'knn'``.
     """
     data = torch.load(osp.join(self.processed_root, self.idxlist[idx]))
     if self.graph_sampler != 'knn':
         raise NotImplementedError
     # Positional args: (pos, r, batch=None, loop=True, max_num_neighbors).
     edge_index = radius_graph(data.pos, self.max_edge_distance, None,
                               True, self.max_neighbours)
     adj_all = sparse_to_dense(edge_index)
     adj_all[adj_all > 0] = 1
     num_nodes = adj_all.shape[0]

     feature = data.x
     if self.feature_type == 'c':
         feature = feature[:, -2:]
     elif self.feature_type == 'a':
         feature = feature[:, :-2]
     feature = (feature - self.mean) / self.std
     return {
         'adj': adj_all,
         'feats': feature,
         'label': data.y,
         'num_nodes': num_nodes,
         'patch_idx': torch.tensor([idx]),
     }
Exemple #11
0
    def forward(self, atomic_ns, coords, batch_node_vec):
        """Partial DimeNet-style forward pass (work in progress; returns None).

        Builds a radius graph (``self.cutoff_val``) and derives interatomic
        distances and neighbour angles with ``xyz_to_dg``.  It then embeds
        them with ``self.emb``.  The later stages are not implemented, so the
        embeddings are currently discarded.

        Planned sequence:
            dist -> rbf;  dist + angle -> sbf;
            rbf -> embedding;  rbf + sbf -> interaction;
            sum(embedding, interactions) -> output MLP.
        """
        edge_index = radius_graph(coords, self.cutoff_val, batch_node_vec)
        n_nodes = atomic_ns.size(0)

        geometry = xyz_to_dg(coords, edge_index, n_nodes)
        dists, angles, _node_is, _node_js, kj, _ji = geometry

        # TODO: feed these into edge/node/graph feature initialisation.
        _dist_embs, _angle_embs = self.emb(dists, angles, kj)

        return None
Exemple #12
0
    def forward(self, z, pos, batch=None):
        """Predict graph-level targets with DimeNet.

        Args:
            z: atomic numbers, one entry per node.
            pos: node coordinates.
            batch: optional batch vector; ``None`` means a single graph.

        Returns:
            Output-block contributions summed over all nodes (``batch`` is
            ``None``) or scattered per graph over ``batch``.
        """
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)

        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
            edge_index, num_nodes=z.size(0))

        # Calculate distances.
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        # Calculate angles.
        pos_i = pos[idx_i]
        pos_ji, pos_ki = pos[idx_j] - pos_i, pos[idx_k] - pos_i
        a = (pos_ji * pos_ki).sum(dim=-1)
        # Explicit dim: without it torch.cross uses the first size-3 dimension,
        # which is the triplet dimension when there are exactly 3 triplets.
        b = torch.cross(pos_ji, pos_ki, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        num_nodes = pos.size(0)

        # Embedding block.
        x = self.emb(z, rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=num_nodes)

        # Interaction blocks.  Pass num_nodes to every output block (as for the
        # first one) so all contributions have the same number of rows even
        # when the highest-indexed nodes have no edges.
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P += output_block(x, rbf, i, num_nodes=num_nodes)

        return P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)
def gen_graph(
    node_feature_path: str,
    node_feature_cols: List[str],
    label: int,
    max_neighbours: int = 8,
    radius: float = 50,
) -> Data:
    """Build a PyTorch Geometric graph for one tile from its node-feature CSV.

    Each CSV row is a node.  The selected feature columns are mean-centred
    and divided by their range (max - min).  NaNs, e.g. from constant columns
    where the range is 0, become 0.  The
    ``diagnostics_Mask-original_CenterOfMass`` column supplies node positions.
    Edges join nodes within ``radius`` of each other, self-loops included, with
    at most ``max_neighbours`` neighbours per node.

    Args:
        node_feature_path: path to the tile's node-feature CSV.
        node_feature_cols: feature columns to use.
        label: graph label, 1 or 0 for cancer / no cancer.
        max_neighbours: neighbour cap per node (default 8).
        radius: maximum edge length in pixels (default 50).

    Returns:
        ``Data`` with ``x``, ``pos``, ``y``, ``edge_index`` and ``num_classes=2``.
    """
    centroid_col = "diagnostics_Mask-original_CenterOfMass"
    table = pd.read_csv(
        node_feature_path,
        converters={centroid_col: ast.literal_eval},
    )

    selected = table[node_feature_cols]
    centred = selected - selected.mean()
    value_range = selected.max() - selected.min()
    normalized = (centred / value_range).fillna(0)

    positions = torch.tensor(table[centroid_col].tolist())
    node_x = torch.tensor(normalized.astype(float).values, dtype=torch.float32)

    graph = Data(
        x=node_x,
        pos=positions,
        y=torch.tensor([label], dtype=torch.long),
        num_classes=2,
    )
    graph.edge_index = radius_graph(
        graph.pos,
        r=radius,
        batch=None,
        loop=True,
        max_num_neighbors=max_neighbours,
    )
    return graph
Exemple #14
0
def pt_to_gexf(numpyfile,
               savepath,
               sample=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1)):
    """Export randomly sub-sampled radius graphs of a point set as GEXF files.

    Coordinates are loaded from ``numpyfile`` and divided by 1500.  For each
    ratio in ``sample``:

    * ``int(num_nodes * ratio)`` points are drawn with ``np.random.choice``;
    * they are connected by a radius graph (radius 100, self-loops, at most
      8 neighbours);
    * the graph is written to ``<savepath>/<ratio><basename>.gexf`` with
      ``x``/``y`` node attributes taken from the first two coordinate columns.

    Args:
        numpyfile: path to a ``.npy`` array of point coordinates.
        savepath: output directory.
        sample: iterable of sampling ratios (a tuple, so the default is never
            shared mutable state).
    """
    coordinates = np.load(numpyfile) / 1500.
    coordinates = torch.from_numpy(coordinates).to(torch.float)
    num_nodes = coordinates.shape[0]
    out_name = os.path.basename(numpyfile).replace('.npy', '.gexf')

    for sampling_ratio in sample:
        num_sample = int(num_nodes * sampling_ratio)
        # NOTE(review): sampling is with replacement, so coincident duplicate
        # points are possible; confirm this is intended.
        choice = np.random.choice(num_nodes, num_sample, replace=True)
        coords = coordinates[choice]

        # Positional args: (pos, r, batch=None, loop=True, max_num_neighbors).
        edge_index = radius_graph(coords, 100, None, True, 8)
        adj = sparse_to_dense(edge_index).numpy()
        coords = coords.numpy()

        # from_numpy_matrix was removed in NetworkX 3.0.
        G = nx.from_numpy_array(adj)
        nx.set_node_attributes(G, dict(enumerate(coords[:, 0].tolist())), 'x')
        nx.set_node_attributes(G, dict(enumerate(coords[:, 1].tolist())), 'y')
        nx.write_gexf(G, os.path.join(savepath, str(sampling_ratio) + out_name))
Exemple #15
0
 def get(self, idx):
     """Load sample ``idx`` as a ``Data`` object with a freshly built radius graph.

     Dynamic mode loads from ``self.processed_root`` and re-samples nodes with
     ``self._sampling``.  Otherwise a pre-computed sample is loaded for the
     train epoch (``self.epoch``) or the validation epoch
     (``self.val_epoch``).  Features may be restricted by
     ``self.feature_type`` and are then normalized with ``self.mean`` /
     ``self.std``.

     Raises:
         NotImplementedError: if ``self.graph_sampler`` is not ``'knn'``.
     """
     epoch = self.epoch if self.split == 'train' else self.val_epoch
     sample_name = self.idxlist[idx]
     if self.dynamic_graph:
         sample_path = osp.join(self.processed_root, sample_name)
     else:
         sample_path = osp.join(self.processed_fix_data_root, str(epoch),
                                sample_name)
     data = torch.load(sample_path)

     if self.feature_type == 'c':
         data.x = data.x[:, -2:]
     elif self.feature_type == 'a':
         data.x = data.x[:, :-2]

     if self.dynamic_graph:
         total = data.num_nodes
         dist_path = os.path.join(self.root, 'proto', 'distance',
                                  self.setting.dataset, sample_name)
         choice, _ = self._sampling(total, self.sampling_ratio, dist_path)
         # Sub-sample every per-node tensor attribute with the same choice.
         for key, item in data:
             if torch.is_tensor(item) and item.size(0) == total:
                 data[key] = item[choice]

     if self.graph_sampler != 'knn':
         raise NotImplementedError
     # Positional args: (pos, r, batch=None, loop=True, max_num_neighbors).
     data.edge_index = radius_graph(data.pos, self.max_edge_distance, None,
                                    True, self.max_neighbours)

     data.patch_idx = torch.tensor([idx])
     data.x = (data.x - self.mean) / self.std
     return data
 def __call__(self, data):
     """Replace ``data.edge_index`` with an undirected radius graph over ``data.pos``.

     Uses radius ``self.radius`` and at most 64 neighbours per node.  No batch
     vector is passed, so all nodes are treated as one graph.
     """
     node_count = data.pos.size(0)
     directed = radius_graph(data.pos, self.radius, None,
                             max_num_neighbors=64)
     data.edge_index = to_undirected(directed, num_nodes=node_count)
     return data
Exemple #17
0
    def forward(self, data):
        """Build (optionally periodic) edges, filter them and run the network.

        Side effects: sets ``self.device``, ``self.num_atoms`` and
        ``self.batch_size``.  With ``self.otf_graph`` it also writes
        ``edge_index``, ``cell_offsets`` and ``neighbors`` onto ``data``.
        When ``self.show_timing_info is True`` it prints CUDA memory
        statistics.

        Returns:
            Whatever ``self._forward_helper`` returns.
        """
        self.device = data.pos.device
        self.num_atoms = len(data.batch)
        self.batch_size = len(data.natoms)

        atomic_numbers = data.atomic_numbers.long()
        pos = data.pos
        if self.regress_forces:
            pos = pos.requires_grad_(True)

        if self.otf_graph:
            data.edge_index, data.cell_offsets, data.neighbors = (
                radius_graph_pbc(data, self.cutoff, 100))

        if self.use_pbc:
            assert (atomic_numbers.dim() == 1
                    and atomic_numbers.dtype == torch.long)
            pbc = get_pbc_distances(
                pos,
                data.edge_index,
                data.cell,
                data.cell_offsets,
                data.neighbors,
                return_distance_vec=True,
            )
            edge_index = pbc["edge_index"]
            edge_distance = pbc["distances"]
            edge_distance_vec = pbc["distance_vec"]
        else:
            edge_index = radius_graph(pos, r=self.cutoff, batch=data.batch)
            src, dst = edge_index
            edge_distance_vec = pos[src] - pos[dst]
            edge_distance = edge_distance_vec.norm(dim=-1)

        edge_index, edge_distance, edge_distance_vec = self._filter_edges(
            edge_index, edge_distance, edge_distance_vec,
            self.max_num_neighbors)

        outputs = self._forward_helper(data, edge_index, edge_distance,
                                       edge_distance_vec)

        if self.show_timing_info is True:
            torch.cuda.synchronize()
            num_edges = len(edge_index[0])
            print("Memory: {}\t{}\t{}".format(
                num_edges,
                torch.cuda.memory_allocated() / (1000 * num_edges),
                torch.cuda.max_memory_allocated() / 1000000,
            ))

        return outputs
    def __call__(self, data):
        """Rebuild ``data.edge_index`` as a radius graph and clear ``edge_attr``.

        Uses ``self.r``, ``self.loop`` and ``self.max_num_neighbors``.  It is
        batched by ``data.batch`` when present, otherwise all nodes form one
        graph.
        """
        data.edge_attr = None
        batch = None
        if 'batch' in data:
            batch = data.batch
        data.edge_index = radius_graph(data.pos, self.r, batch, self.loop,
                                       self.max_num_neighbors)
        return data
Exemple #19
0
def generate_random_graph(n, radius, max_num_neighbors, batch_size, device):
    """Sample ``n`` random 2-D points and connect them with a radius graph.

    Points are drawn uniformly from ``[0.05, 0.95)^2``, which keeps an equal
    margin inside the unit square on every side.

    Args:
        n: number of points.
        radius: connection radius.
        max_num_neighbors: neighbour cap per node.
        batch_size: unused; kept so existing call sites keep working.
        device: device for the generated tensors.

    Returns:
        ``(x, edge_index, num_nodes, batch)``: coordinates ``[n, 2]``, COO
        edges (self-loops included), a one-element long tensor holding ``n``,
        and ``None`` as the batch vector.
    """
    batch = None
    num_nodes = th.tensor([n], device=device, dtype=th.long)
    # Scale, then shift: 0.90 * U[0, 1) + 0.05 -> [0.05, 0.95).  The previous
    # 0.90 * (U + 0.05) gave asymmetric margins, [0.045, 0.945).
    x = 0.90 * th.rand(n, 2, device=device) + 0.05
    edge_index = radius_graph(
        x, radius, batch=batch, loop=True,
        max_num_neighbors=max_num_neighbors)  # COO format
    return x, edge_index, num_nodes, batch
Exemple #20
0
    def forward(self, data):
        """Predict energy (and optionally forces) with optional periodic boundaries.

        Node features are rows of ``self.embedding`` indexed by
        ``atomic_numbers - 1``.  Edges are built on the fly with
        ``radius_graph_pbc`` when ``self.otf_graph`` is set.  With
        ``self.use_pbc`` distances come from ``get_pbc_distances``; otherwise
        from a plain radius graph.  Distances are expanded into
        ``data.edge_attr``.

        Returns:
            ``energy`` or ``(energy, forces)`` with
            ``forces = -d(energy)/d(pos)``.
        """
        atoms_device = data.atomic_numbers.device
        if self.embedding.device != atoms_device:
            self.embedding = self.embedding.to(atoms_device)
        data.x = self.embedding[data.atomic_numbers.long() - 1]

        pos = data.pos
        if self.regress_forces:
            pos = pos.requires_grad_(True)

        if self.otf_graph:
            data.edge_index, data.cell_offsets, data.neighbors = (
                radius_graph_pbc(data, self.cutoff, 50, data.pos.device))

        if self.use_pbc:
            pbc = get_pbc_distances(
                pos,
                data.edge_index,
                data.cell,
                data.cell_offsets,
                data.neighbors,
            )
            data.edge_index = pbc["edge_index"]
            distances = pbc["distances"]
        else:
            data.edge_index = radius_graph(
                data.pos, r=self.cutoff, batch=data.batch)
            src, dst = data.edge_index
            distances = (pos[src] - pos[dst]).norm(dim=-1)

        data.edge_attr = self.distance_expansion(distances)

        feats = self.conv_to_fc(self._convolve(data))
        if hasattr(self, "fcs"):
            feats = self.fcs(feats)
        energy = self.fc_out(feats)

        if not self.regress_forces:
            return energy
        (grad_pos,) = torch.autograd.grad(
            energy,
            pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )
        return energy, -1 * grad_pos
Exemple #21
0
 def forward(self, data):
     """Predict skinning logits for a mesh.

     When ``self.configs['euc_radius'] > 0`` a Euclidean radius graph over
     ``data.pos`` is stored as ``data.euc_edge_index``; otherwise that
     attribute is set to ``None``.  Only the per-mesh-element features from
     ``self.edgeconv_feat`` are used; its global feature is ignored.
     """
     mesh = data
     euc_radius = self.configs['euc_radius']
     if euc_radius > 0:
         mesh.euc_edge_index = radius_graph(mesh.pos, euc_radius)
     else:
         mesh.euc_edge_index = None
     mesh_feat, _global_feat = self.edgeconv_feat(mesh)
     return self.skinnet(mesh_feat)
Exemple #22
0
    def forward(self, z, pos, batch=None):
        """Predict a graph-level scalar from atomic numbers and positions.

        Builds a radius graph (``self.cutoff``, up to 1000 neighbours) and
        computes per-edge spherical harmonics, radial weights and a cosine
        cutoff.  It then runs ``self.layers`` and reduces the output channels
        to one scalar per node.  Optional de-standardization and atom
        reference terms are applied before a per-graph ``scatter`` readout.

        Args:
            z: 1-D long tensor of atomic numbers (asserted below).
            pos: node positions.
            batch: optional batch vector; ``None`` puts every node in graph 0.

        Returns:
            Per-graph output from ``scatter(..., reduce=self.readout)``,
            optionally multiplied by ``self.scale``.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        # Without a batch vector every node belongs to graph 0.
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)

        edge_index = radius_graph(pos,
                                  r=self.cutoff,
                                  batch=batch,
                                  max_num_neighbors=1000)
        row, col = edge_index
        edge_vec = pos[row] - pos[col]
        # Spherical harmonics of the edge vectors, divided by
        # sqrt(self.num_neighbors).
        edge_sh = o3.spherical_harmonics(self.Rs_sh, edge_vec,
                                         'component') / self.num_neighbors**0.5
        edge_len = edge_vec.norm(dim=1)
        edge_weight = self.radial(edge_len)
        # Cosine cutoff (cos(pi * d / cutoff) + 1) / 2: 1 at d = 0, 0 at d = cutoff.
        edge_c = (pi * edge_len / self.cutoff).cos().add(1).div(2)

        for conv, act, shortcut in self.layers[:-1]:
            with torch.autograd.profiler.record_function("Layer"):
                if shortcut:
                    s = shortcut(h)

                h = conv(h, edge_index, edge_weight, edge_c,
                         edge_sh)  # convolution
                h = act(h)  # gate non linearity

                if shortcut:
                    # Residual mix: channels with mask 1 get
                    # sqrt(.5) * s + sqrt(.5) * h; channels with mask 0 get
                    # sqrt(.5) * s + h.
                    m = shortcut.output_mask
                    h = 0.5**0.5 * s + (1 + (0.5**0.5 - 1) * m) * h

        with torch.autograd.profiler.record_function("Layer"):
            h = self.layers[-1](h, edge_index, edge_weight, edge_c, edge_sh)

        # Collapse output channels to one scalar per node.  Each entry of
        # self.Rs_out is unpacked as (mul, l, p) and must have mul == 1 and
        # l == 0.  Channels with p == 1 are added as-is; channels with
        # p == -1 are squared and halved.
        # NOTE(review): if no channel has p == +-1, s stays the int 0 and
        # s.view fails.
        s = 0
        for i, (mul, l, p) in enumerate(self.Rs_out):
            assert mul == 1 and l == 0
            if p == 1:
                s += h[:, i]
            if p == -1:
                s += h[:, i].pow(2).mul(0.5)  # odd^2 = even
        h = s.view(-1, 1)

        # Undo target standardization when statistics are provided.
        if self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        # Optional per-atom reference contribution.
        if self.atomref is not None:
            h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.scale is not None:
            out = self.scale * out

        return out
Exemple #23
0
    def __call__(self, data):
        """Attach a radius graph and its edge lengths to ``data``.

        If ``data`` already has an ``edge_attr`` attribute it is returned
        unchanged.  Otherwise a radius graph is built with self-loops
        removed; it uses cutoff ``self.cutoff`` and is batched by
        ``data.batch``.  The graph is stored in ``data.edge_index``.  The
        Euclidean edge lengths are stored in ``data.edge_attr`` as an
        ``[E, 1]`` column.

        Returns:
            The same ``data`` object.
        """
        if hasattr(data, "edge_attr"):
            return data

        edge_index = radius_graph(data.pos, r=self.cutoff, batch=data.batch)
        edge_index, _ = remove_self_loops(edge_index)
        row, col = edge_index[0], edge_index[1]
        edge_weight = (data.pos[row] - data.pos[col]).norm(dim=-1)
        # Store the edges too: edge_attr rows must line up with the columns
        # of data.edge_index, which the computed graph previously did not.
        data.edge_index = edge_index
        data.edge_attr = edge_weight.reshape(-1, 1)
        return data
Exemple #24
0
    def forward(self, atomic_ns, edge_index, coords, batch_node_vec):
        """Create initial node and edge embeddings.

        The ``edge_index`` argument is ignored.  Edges are rebuilt from
        ``coords`` with ``radius_graph(coords, self.cutoff, batch_node_vec)``.

        Returns:
            ``(node_embs, edge_embs, edge_weights)``: atom embeddings from
            ``self.node_emb``, distance embeddings from ``self.dist_emb``, and
            the interatomic distances themselves.
        """
        node_embs = self.node_emb(atomic_ns)

        graph = radius_graph(coords, self.cutoff, batch_node_vec)
        src, dst = graph
        edge_weights = (coords[src] - coords[dst]).norm(dim=-1)

        return node_embs, self.dist_emb(edge_weights), edge_weights
 def forward(self, data, phase):
     """One graph-convolution layer over a radius-4 graph, mean-pooled per graph.

     Edges are rebuilt from ``data.pos`` (cast to float).  ``data.edge_index``
     and ``phase`` are not used.

     Returns:
         dict with the graph-level output under key ``'A2'``.
     """
     batch = data.batch
     edges = radius_graph(data.pos.float(), r=4, batch=batch)
     h = self.lin01(data.x)
     h = self.graph_norm_1(F.relu(self.conv1(h, edges)), batch)
     h = global_mean_pool(h, batch)
     return {'A2': self.lin3(self.lin2(h))}
Exemple #26
0
    def forward(self, data):
        """Predict energy (and optionally forces) with DimeNet.

        Edges come from a radius graph over ``data.pos`` (cutoff
        ``self.cutoff``, batched by ``data.batch``).  Angles are computed from
        detached positions, so gradients w.r.t. ``pos`` do not flow through
        the angle computation.

        Returns:
            ``energy``, or ``(energy, forces)`` when ``self.regress_forces``
            is set, with ``forces = -d(energy)/d(pos)``.
        """
        pos = data.pos
        if self.regress_forces:
            pos = pos.requires_grad_(True)
        batch = data.batch
        x = self.embedding(data.atomic_numbers.long())
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)

        j, i = edge_index
        idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
            edge_index, num_nodes=x.size(0))

        # Calculate distances.
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        # Calculate angles (from detached positions).
        pos_i = pos[idx_i].detach()
        pos_ji, pos_ki = (
            pos[idx_j].detach() - pos_i,
            pos[idx_k].detach() - pos_i,
        )
        a = (pos_ji * pos_ki).sum(dim=-1)
        # Explicit dim: without it torch.cross uses the first size-3 dimension,
        # which is the triplet dimension when there are exactly 3 triplets.
        b = torch.cross(pos_ji, pos_ki, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        num_nodes = pos.size(0)

        # Embedding block.
        x = self.emb(x, rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=num_nodes)

        # Interaction blocks.  Pass num_nodes to every output block (as for the
        # first one) so all contributions have the same number of rows.
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P += output_block(x, rbf, i, num_nodes=num_nodes)

        energy = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)

        if not self.regress_forces:
            return energy
        forces = -1 * (torch.autograd.grad(
            energy,
            pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0])
        return energy, forces
    def __call__(self, data):
        """Prepend a global node linked to every original node.

        The global node's feature is the mean of ``data.x``.  It becomes node
        0, so every original node index shifts by one.  ``data.edge_index``
        holds one edge ``(i, 0)`` per original node ``i`` plus the radius
        graph (``self.radius``) over the original positions, re-indexed by +1.
        ``data.pos`` places the global node at the origin and shifts the
        original positions by +1.  The origin row matches the dimensionality,
        dtype and device of ``data.pos``, so 2-D and 3-D inputs on any device
        work.

        Returns:
            The same ``data`` object, modified in place.
        """
        nodes = data.x
        global_node = nodes.mean(dim=0).unsqueeze(0)
        data.x = torch.cat([global_node, nodes])

        pos = data.pos
        edge_index = radius_graph(pos, r=self.radius)

        num_total = data.x.size(0)
        device = edge_index.device
        # Edges (i, 0) for every original node i = 1 .. num_total - 1.
        to_global = torch.stack((
            torch.arange(1, num_total, device=device),
            torch.zeros(num_total - 1, dtype=torch.long, device=device),
        ))
        edge_index = torch.cat((to_global, edge_index + 1), dim=1)

        origin = pos.new_zeros((1, pos.size(1)))
        data.pos = torch.cat((origin, pos + 1))
        data.edge_index = edge_index
        return data
 def forward(self, data, phase):
     """One graph-convolution layer on a radius-4 graph followed by mean pooling.

     Edges are rebuilt from ``data.pos`` (cast to float).  ``data.edge_index``
     and ``phase`` are not used.

     Returns:
         dict mapping ``'A2'`` to the graph-level output of ``lin2`` then ``lin3``.
     """
     node_x, batch = data.x, data.batch
     graph = radius_graph(data.pos.float(), r=4, batch=batch)
     out = F.relu(self.conv1(self.lin01(node_x), graph))
     out = self.graph_norm_1(out, batch)
     out = global_mean_pool(out, batch)
     out = self.lin2(out)
     return {'A2': self.lin3(out)}
Exemple #29
0
 def __call__(self, data, sigma=0.01, clip=0.05):
     """Randomly jitter point positions and rebuild the radius graph.

     Each coordinate gets independent noise ``sigma * N(0, 1)`` (drawn with
     ``np.random``), clipped to ``[-clip, clip]``.  The noise is cast to the
     dtype and device of ``data.pos``, so the jittered positions keep their
     original dtype instead of being promoted to float64.

     Args:
         data: object whose ``pos`` tensor holds the point coordinates.
         sigma: standard deviation of the noise.
         clip: absolute bound on the noise per coordinate.

     Returns:
         ``data`` with ``pos`` jittered and ``edge_index`` rebuilt (radius 0.05).
     """
     pc = data.pos
     noise = np.clip(sigma * np.random.randn(*pc.shape), -1 * clip, clip)
     noise = torch.from_numpy(noise).to(dtype=pc.dtype, device=pc.device)
     data.pos = pc + noise
     # NOTE(review): the graph is built from the positions *before* jitter,
     # as in the original; confirm this is intended.
     data.edge_index = nn.radius_graph(x=pc.detach().clone(), r=0.05)
     return data
Exemple #30
0
    def __call__(self, data: Data) -> Data:
        """Rebuild ``data.edge_index`` as a radius graph that ignores NaN points.

        Nodes whose first coordinate is NaN are moved to 10_000 on a copy of
        ``pos`` so the radius search can run.  Every edge touching such a node
        is then dropped.  ``data.edge_attr`` is cleared and ``data.pos`` is
        left unchanged.

        Returns:
            The same ``data`` object.
        """
        data.edge_attr = None
        nans = torch.isnan(data.pos[:, 0])
        pos = data.pos.clone()
        pos[nans] = 10_000
        batch = data.batch if "batch" in data else None

        edge_index = radius_graph(
            pos, self.r, batch, self.loop, self.max_num_neighbors,
        )
        # Look up each endpoint in the boolean NaN mask: O(E), same result as
        # broadcasting edge_index against torch.nonzero(nans) (O(E * #NaN)).
        drop = nans[edge_index[0]] | nans[edge_index[1]]

        data.edge_index = edge_index[:, ~drop]
        return data