Exemplo n.º 1
0
    def forward(self, data):
        """Apply the GNN layer adaptively, halting nodes by accumulated confidence.

        Each iteration builds the next input with ``self.next_x``, applies
        ``self.gnn_layer_module``, reads out a (presumably per-graph) feature
        with ``self.readout_module`` and maps it to a confidence with
        ``self.confidence_module``; that confidence is gathered back to the
        nodes by indexing with ``data.batch``. Confidence is subtracted from a
        per-node budget (``left_confidence``); once a node's budget drops
        below 1e-7 it is masked out and its remaining (residual) confidence is
        used as the weight of its final update.

        Args:
            data: batched graph object exposing at least ``x`` and ``batch``;
                every other attribute in ``data.__dict__`` is forwarded
                unchanged to each intermediate ``Batch``.

        Returns:
            Tuple ``(x, iter_num, residual_confidence)``: the accumulated node
            features, the index of the last executed iteration (0 when
            ``self.layer_num == 0``), and the per-node residual confidence as
            an ``(num_nodes, 1)`` column.
        """
        if self.layer_num == 0:
            # No layers: input unchanged, zero residual confidence.
            return data.x, 0, torch.zeros_like(data.x[:, 0:1])
        x, batch = data.x, data.batch
        # Copy all other attributes so they can be re-attached to each Batch.
        kwargs = {k: v for k, v in data.__dict__.items()}
        kwargs.pop('x')
        new_x = x

        # Per-node (N, 1) columns: remaining budget, residual confidence and
        # a mask of nodes that have already halted.
        left_confidence = torch.ones_like(x[:, 0:1])
        residual_confidence = torch.ones_like(x[:, 0:1])
        zero_mask = torch.zeros_like(x[:, 0:1])
        for iter_num in range(self.layer_num):
            data = Batch(x=self.next_x(x, new_x, left_confidence,
                                       self.decreasing_ratio),
                         **kwargs)
            new_x = self.gnn_layer_module(data)
            global_feat = self.readout_module(Batch(x=new_x, **kwargs))
            # Indexing by the batch vector gives every node its graph's value.
            current_confidence = self.confidence_module(global_feat)[batch]

            # Only nodes that had not halted before this step spend budget.
            left_confidence = left_confidence - current_confidence * (
                1 - zero_mask)
            # Nodes whose budget is (numerically) exhausted are halted now.
            current_zero_mask = (left_confidence < 1e-7).type(torch.float)
            residual_confidence = residual_confidence - current_confidence * (
                1 - current_zero_mask)
            # Still-running nodes add new_x weighted by this step's
            # confidence; nodes halting at exactly this step add it weighted
            # by their residual confidence; already-halted nodes add nothing.
            x = x + (current_confidence * (1 - current_zero_mask) +
                     residual_confidence * current_zero_mask *
                     (1 - zero_mask)) * new_x
            zero_mask = current_zero_mask
            # Stop early once every node has halted.
            if torch.min(zero_mask).item() > 0.5:
                break
        return x, iter_num, residual_confidence
Exemplo n.º 2
0
    def forward(self, data):
        """Apply the GNN layer repeatedly while any confidence budget remains.

        Each step feeds ``self.next_x(...)`` to ``self.gnn_layer_module``,
        derives a confidence from the read-out, and folds the new features and
        the budget forward with ``self.update_x`` / ``self.update_confidence``.
        The loop stops early once the largest remaining budget is no longer
        above 1e-7.

        Returns:
            ``(x, iter_num)`` — the accumulated node features and the index of
            the last loop iteration reached (0 when ``self.layer_num == 0``).
        """
        if self.layer_num == 0:
            return data.x, 0
        x, batch = data.x, data.batch
        extra_attrs = dict(data.__dict__)
        extra_attrs.pop('x')
        new_x = x

        left_confidence = torch.ones_like(x[:, 0:1])
        for iter_num in range(self.layer_num):
            # Written as "not >" so a NaN budget also stops the loop.
            if not torch.max(left_confidence).item() > 1e-7:
                break
            step_input = self.next_x(x, new_x, left_confidence,
                                     self.decreasing_ratio)
            new_x = self.gnn_layer_module(Batch(x=step_input, **extra_attrs))
            global_feat = self.readout_module(Batch(x=new_x, **extra_attrs))
            current_confidence = self.confidence_module(global_feat)[batch]
            # The first update accumulates from zeros rather than the input.
            previous_x = torch.zeros_like(x) if iter_num == 0 else x
            x = self.update_x(previous_x, new_x, left_confidence,
                              current_confidence, self.decreasing_ratio)
            left_confidence = self.update_confidence(
                left_confidence, current_confidence, self.decreasing_ratio)

        return x, iter_num
Exemplo n.º 3
0
    def update(self, memory):
        """Optimise the policy and the exploration critic on one rollout.

        Computes Monte-Carlo discounted returns from ``memory.rewards`` and
        ``memory.terminals``. It then collates states, next states and
        candidates into PyG batches, one per view (``1 + self.use_3d`` views).
        Finally it runs ``self.k_epochs`` updates of ``self.policy`` and
        ``self.explore_critic``, logging the losses every 10 epochs.

        Args:
            memory: rollout buffer providing ``rewards``, ``terminals``,
                ``candidates``, ``states``, ``states_next``, ``actions`` and
                ``logprobs``.
        """
        # Monte Carlo estimate of returns, accumulated back to front; a
        # terminal flag resets the running return at an episode boundary.
        rewards = []
        discounted_reward = 0
        for reward, terminal in zip(reversed(memory.rewards),
                                    reversed(memory.terminals)):
            if terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.append(discounted_reward)
        # Append + a single reverse is O(n); inserting at index 0 is O(n^2).
        rewards.reverse()

        # Returns are used as-is: no normalisation is applied here.
        rewards = torch.tensor(rewards).to(self.device)

        # Index of the owning state for every candidate (presumably used to
        # group per-candidate outputs by state inside policy.update).
        batch_idx = []
        for i, cands in enumerate(memory.candidates):
            batch_idx.extend([i] * len(cands))
        batch_idx = torch.LongTensor(batch_idx).to(self.device)

        # Each stored state / candidate holds one graph per view; collate
        # every view into its own Batch.
        num_views = 1 + self.use_3d
        states = [
            Batch().from_data_list([state[i] for state in memory.states
                                    ]).to(self.device)
            for i in range(num_views)
        ]
        states_next = [
            Batch().from_data_list(
                [state_next[i]
                 for state_next in memory.states_next]).to(self.device)
            for i in range(num_views)
        ]
        candidates = [
            Batch().from_data_list(
                [item[i] for sublist in memory.candidates
                 for item in sublist]).to(self.device)
            for i in range(num_views)
        ]
        actions = torch.tensor(memory.actions).to(self.device)

        old_logprobs = torch.tensor(memory.logprobs).to(self.device)
        old_values = self.policy.get_value(states)

        # Optimize policy for k epochs:
        logging.info("Optimizing...")

        for i in range(1, self.k_epochs + 1):
            loss, baseline_loss = self.policy.update(states, candidates,
                                                     actions, rewards,
                                                     old_logprobs, old_values,
                                                     batch_idx)
            rnd_loss = self.explore_critic.update(states_next)
            if (i % 10) == 0:
                logging.info(
                    "  {:3d}: Actor Loss: {:7.3f}, Critic Loss: {:7.3f}, RND Loss: {:7.3f}"
                    .format(i, loss, baseline_loss, rnd_loss))
Exemplo n.º 4
0
 def set_input(self, data, device):
     """Store the batched input (and optional target cloud) on ``device``.

     ``self.input`` is built from ``data.pos``/``data.x``/``data.batch``.
     When ``data`` has ``pos_target``, ``self.input_target`` is built from
     ``pos_target``/``x_target``/``batch_target``. ``self.match`` and
     ``self.size_match`` are always taken from ``data.pair_ind`` and
     ``data.size_pair_ind``, cast to ``torch.long``.
     """
     self.input = Batch(pos=data.pos, x=data.x, batch=data.batch).to(device)
     if hasattr(data, "pos_target"):
         self.input_target = Batch(pos=data.pos_target,
                                   x=data.x_target,
                                   batch=data.batch_target).to(device)
     # Pair indices are needed with or without a target cloud.
     self.match = data.pair_ind.to(torch.long).to(device)
     self.size_match = data.size_pair_ind.to(torch.long).to(device)
Exemplo n.º 5
0
def to_data_list(data):
    """Split a batched graph into a list of single-graph ``Batch`` objects.

    Each graph is rebuilt from ``x``, ``y``, ``edge_index`` (re-indexed to
    start at 0), ``edge_attr`` and an all-zero ``batch`` vector. Any other
    attributes are dropped.

    Assumes each graph's nodes are contiguous in ``data.x``, that edges
    belong to the graph of their source node, and that ``data.y`` holds one
    row per graph indexed by graph id. TODO confirm against callers.

    Returns:
        list of per-graph ``Batch`` objects in ascending graph-id order.
    """
    # NOTE(review): methods live on the class, not in the instance
    # ``__dict__``, so this shortcut only fires if an instance attribute
    # named ``to_data_list`` was assigned explicitly.
    if 'to_data_list' in data.__dict__:
        return data.to_data_list()
    # Sorted, so output order does not depend on set iteration order.
    graph_indexes = sorted(set(data.batch.tolist()))
    all_nodes = torch.arange(data.x.size(0), device=data.x.device)
    data_list = []
    for gi in graph_indexes:
        node_indexes = all_nodes[data.batch == gi]
        node_index_min = torch.min(node_indexes)
        node_index_max = torch.max(node_indexes)
        # Select edges whose source node lies in this graph's node range.
        edge_mask = ((data.edge_index[0] >= node_index_min)
                     & (data.edge_index[0] <= node_index_max))

        x = data.x[node_indexes]
        edge_index = data.edge_index[:, edge_mask] - node_index_min
        edge_attr = data.edge_attr[edge_mask]
        y = data.y[gi:gi + 1]
        batch = torch.zeros_like(node_indexes)

        data_list.append(Batch(
            x=x, y=y, batch=batch,
            edge_index=edge_index, edge_attr=edge_attr,
        ))
    # Sanity checks: the split must account for every node and edge and
    # keep feature widths unchanged.
    assert data.x.size(0) == sum(d.x.size(0) for d in data_list)
    assert all(d.x.size(0) == d.batch.size(0) for d in data_list)
    assert data.edge_index.size(1) == sum(d.edge_index.size(1) for d in data_list)
    assert all(d.edge_index.size(1) == d.edge_attr.size(0) for d in data_list)
    assert all(data.x.size(1) == d.x.size(1) for d in data_list)
    assert all(data.edge_attr.size(1) == d.edge_attr.size(1) for d in data_list)
    return data_list
    def forward(self, data):
        """Encode a batch of graphs and return graph-level outputs.

        Runs ``conv1`` followed by ``self.convs``, optionally coarsening the
        graph after each conv with graclus clustering + ``max_pool``. The
        graph-level mean of every layer is collected in ``xs``. The readout is
        ``self.jump(xs)``, or a final mean pool when ``self.no_cat`` is set,
        followed by ``lin1`` -> ReLU -> ``lin2``.

        Raises:
            ValueError: if ``self.pooling_type`` is not ``'none'``,
                ``'complement'`` or ``'graclus'``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if self.encode_edge:
            x = self.atom_encoder(x)
            x = self.conv1(x, edge_index, data.edge_attr)
        else:
            x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Graph-level readout after every conv, for the jumping-knowledge head.
        xs = [global_mean_pool(x, batch)]
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs.append(global_mean_pool(x, batch))
            if self.pooling_type == 'none':
                continue
            if self.pooling_type == 'complement':
                # Cluster on the graph built by batched_negative_edges
                # (presumably the within-graph non-edges).
                complement = batched_negative_edges(
                    edge_index=edge_index, batch=batch, force_undirected=True)
                cluster = graclus(complement, num_nodes=x.size(0))
            elif self.pooling_type == 'graclus':
                cluster = graclus(edge_index, num_nodes=x.size(0))
            else:
                # Previously fell through with ``cluster`` unbound.
                raise ValueError(
                    "Unknown pooling_type {!r}; expected 'none', "
                    "'complement' or 'graclus'".format(self.pooling_type))
            pooled = max_pool(cluster,
                              Batch(x=x, edge_index=edge_index, batch=batch))
            x, edge_index, batch = pooled.x, pooled.edge_index, pooled.batch

        if not self.no_cat:
            x = self.jump(xs)
        else:
            x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x
Exemplo n.º 7
0
 def forward(self, data):
     """Apply ``self.gnn_layer_module`` ``self.layer_num`` times.

     Every attribute of ``data`` other than ``x`` is carried through
     unchanged to each layer application.

     Returns:
         ``(x, self.layer_num)``: the final node features and the number of
         layer applications.
     """
     attrs = dict(data.__dict__)
     for _ in range(self.layer_num):
         node_feat = self.gnn_layer_module(Batch(**attrs))
         attrs['x'] = node_feat
         # Sanity check against NaNs (skipped under ``python -O``).
         assert not torch.isnan(node_feat).any()
     return attrs['x'], self.layer_num
Exemplo n.º 8
0
 def forward(self, data):
     """Apply every layer in ``self.layers`` in order.

     Every attribute of ``data`` other than ``x`` is carried through
     unchanged to each layer.

     Returns:
         ``(x, len(self.layers))``: the final node features and the number
         of layers applied.
     """
     attrs = dict(data.__dict__)
     for layer in self.layers:
         node_feat = layer(Batch(**attrs))
         attrs['x'] = node_feat
         # Sanity check against NaNs (skipped under ``python -O``).
         assert not torch.isnan(node_feat).any()
     return attrs['x'], len(self.layers)
Exemplo n.º 9
0
def convert_data_to_batch(x):
    """Wrap each position tensor in ``x`` as ``Data(pos=...)`` and batch them.

    Args:
        x: iterable of position tensors, one per graph.

    Returns:
        The ``Batch`` produced by ``from_data_list``.
    """
    graphs = [Data(pos=points) for points in x]
    return Batch().from_data_list(graphs)
Exemplo n.º 10
0
def x_pos_batch_to_pair_biggraph_pair(cloud_s_all, cloud_t_all, lss, lst):
    """Pack padded source/target point clouds into two batched graphs.

    For each pair ``i``, the first ``lss[i]`` rows of ``cloud_s_all[i]`` and
    the first ``lst[i]`` rows of ``cloud_t_all[i]`` are kept, concatenated,
    and given matching ``batch`` vectors of graph ids.

    Columns ``0:3`` become ``pos`` and column ``2`` becomes the
    single-channel node feature ``x``.

    Returns:
        ``(graph_s, graph_t)``, two ``Batch`` objects.
    """
    x_pos_s_all, x_pos_t_all, batch_s, batch_t = [], [], [], []
    for i, (ls, lt, cloud_s, cloud_t) in enumerate(zip(lss, lst, cloud_s_all, cloud_t_all)):
        ls, lt = int(ls), int(lt)
        x_pos_s_all.append(cloud_s[:ls, :])
        x_pos_t_all.append(cloud_t[:lt, :])
        # Build 1-D id vectors directly. An unsqueeze/squeeze round trip would
        # collapse to a 0-d tensor when only one point survives in total.
        batch_s.append(torch.full((ls,), i, dtype=torch.long, device=lss.device))
        batch_t.append(torch.full((lt,), i, dtype=torch.long, device=lst.device))

    x_pos_s_all = torch.cat(x_pos_s_all, dim=0)
    x_pos_t_all = torch.cat(x_pos_t_all, dim=0)
    batch_s = torch.cat(batch_s, dim=0)
    batch_t = torch.cat(batch_t, dim=0)

    graph_s = Batch(x=x_pos_s_all[:, 2:3], pos=x_pos_s_all[:, :3], batch=batch_s)
    graph_t = Batch(x=x_pos_t_all[:, 2:3], pos=x_pos_t_all[:, :3], batch=batch_t)
    return graph_s, graph_t
Exemplo n.º 11
0
  def _prepare_batch(self, batch):
    """Create batch data for the MEGNet model.

    Converts every ``deepchem.feat.GraphData`` in the batch to a PyG graph
    and collates them into one ``torch_geometric.data.Batch``. Labels and
    weights go through the parent ``_prepare_batch``.

    Note
    ----
    Ideally only ``default_generator`` would be overridden. However,
    ``TorchModel._prepare_batch`` only supports non-graph data types, so it
    is overridden here instead. This should be revisited in the future.
    """
    try:
      from torch_geometric.data import Batch
    except ModuleNotFoundError:
      raise ImportError("This module requires PyTorch Geometric")

    graphs, labels, weights = batch
    # default_generator nests the array of GraphData objects inside a list,
    # so the graphs themselves live at index 0.
    pyg_graphs = [graph.to_pyg_graph() for graph in graphs[0]]
    pyg_batch = Batch().from_data_list(pyg_graphs)

    # Only labels and weights need the parent's conversion.
    _, labels, weights = super(MEGNetModel, self)._prepare_batch(
        ([], labels, weights))
    return pyg_batch, labels, weights
Exemplo n.º 12
0
def tg_transform(args, X):
    """Convert a dense batch of point sets into a radius-graph ``Batch``.

    Args:
        args: namespace providing ``num_hits`` (points per sample),
            ``cutoff`` (neighbour radius in un-scaled coordinates) and
            ``device``.
        X: tensor of shape ``(batch_size, num_hits, >=3)``. Columns 0-1 are
            positions and column 2 is the node feature.

    Returns:
        ``Batch`` containing:

        - ``x``: feature + 0.5.
        - ``pos``: positions * 28 + 14.
        - ``edge_index``: directed edges between distinct points of the same
          sample that are closer than ``args.cutoff``.
        - ``edge_attr``: the scaled position offsets.
        - ``batch``: graph ids on ``args.device``.
    """
    batch_size = X.size(0)
    num_hits = args.num_hits

    pos = X[:, :, :2]

    # All ordered point pairs per sample: row i * num_hits + j of x1 holds
    # pos_i and of x2 holds pos_j.
    x1 = pos.repeat(1, 1, num_hits).reshape(batch_size, num_hits * num_hits, 2)
    x2 = pos.repeat(1, num_hits, 1)

    diff_norms = torch.norm(x2 - x1 + 1e-12, dim=2)

    norms = diff_norms.reshape(batch_size, num_hits, num_hits)
    neighborhood = torch.nonzero(norms < args.cutoff, as_tuple=False)

    neighborhood = neighborhood[neighborhood[:, 1] != neighborhood[:, 2]]  # remove self-loops
    # Shift per-sample node indices into the flattened batch numbering.
    edge_index = (neighborhood[:, 1:] + (neighborhood[:, 0] * num_hits).view(-1, 1)).transpose(0, 1)

    x = X[:, :, 2].reshape(batch_size * num_hits, 1) + 0.5
    pos = 28 * pos.reshape(batch_size * num_hits, 2) + 14

    row, col = edge_index
    edge_attr = (pos[col] - pos[row]) / (2 * 28 * args.cutoff) + 0.5

    # Graph id of every node: [0] * num_hits + [1] * num_hits + ...
    batch = torch.arange(batch_size, device=args.device).repeat_interleave(num_hits)

    return Batch(batch=batch, x=x, edge_index=edge_index.contiguous(), edge_attr=edge_attr, y=None, pos=pos)
Exemplo n.º 13
0
def from_data_list_token(data_list, follow_batch=None):
    """Collate ``data_list`` into a ``Batch`` without shifting negative entries.

    This is essentially a copy of PyTorch Geometric's
    ``Batch.from_data_list``. The difference is that tensor entries that are
    negative (presumably sentinel values such as ``-1``) are not incremented
    by the per-key offset.

    Args:
        data_list: non-empty list of ``Data`` objects of the same class.
        follow_batch: keys for which an extra ``<key>_batch`` assignment
            vector is created. ``None`` means no keys.

    Returns:
        The collated, contiguous ``Batch``. The input ``Data`` objects are
        not modified.

    Raises:
        ValueError: for attributes that are neither tensors nor int/float.
    """
    if follow_batch is None:
        follow_batch = []

    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []

    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                # Shift only non-negative entries. Work on a copy so the
                # caller's Data objects are not mutated, which would corrupt
                # them if the same list were collated again.
                item = item.clone()
                mask = item >= 0
                item[mask] = item[mask] + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch["{}_batch".format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # Decided by the last graph in data_list, as in the upstream code.
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
Exemplo n.º 14
0
def sample_batch_pyg(data, sample_config):
    """
    Perturb the structure and node attributes.

    Parameters
    ----------
    data: torch_geometric.data.Batch
        Batched graphs providing ``x``, ``edge_index`` and ``batch``.
    sample_config: dict
        Sampling probabilities. Reads ``pf_plus_adj``, ``pf_plus_att``,
        ``pf_minus_adj`` and ``pf_minus_att``; each defaults to 0.

    Returns
    -------
    per_data: torch_geometric.data.Batch
        Batch with perturbed ``x`` and ``edge_index`` (perturbed as undirected)
        and the original ``batch`` vector.
    """
    per_x = binary_perturb(data.x,
                           sample_config.get('pf_minus_att', 0),
                           sample_config.get('pf_plus_att', 0))

    # Node count of every graph in the batch.
    nodes_per_graph = torch.bincount(data.batch)
    per_edge_index = sparse_perturb_adj_batch(
        data_idx=data.edge_index,
        nnodes=nodes_per_graph,
        pf_minus=sample_config.get('pf_minus_adj', 0),
        pf_plus=sample_config.get('pf_plus_adj', 0),
        undirected=True)

    return Batch(batch=data.batch, x=per_x, edge_index=per_edge_index)
Exemplo n.º 15
0
def avg_pool(cluster, data, transform=None):
    r"""Coarsen a graph by pooling the nodes of each cluster.

    Node features are averaged per cluster (via ``_avg_pool_x``). Edges,
    batch assignment and positions are pooled with ``pool_edge``,
    ``pool_batch`` and ``pool_pos``. Attributes that are ``None`` on
    ``data`` stay ``None``.
    See :meth:`torch_geometric.nn.pool.max_pool` for more details.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.
        data (Data): Graph data object.
        transform (callable, optional): A function/transform applied to the
            pooled object before it is returned. (default: :obj:`None`)

    :rtype: :class:`torch_geometric.data.Batch` (or whatever ``transform``
        returns)
    """
    cluster, perm = consecutive_cluster(cluster)

    x = data.x
    if x is not None:
        x = _avg_pool_x(cluster, x)
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)
    batch = data.batch
    if batch is not None:
        batch = pool_batch(perm, batch)
    pos = data.pos
    if pos is not None:
        pos = pool_pos(cluster, pos)

    pooled = Batch(batch=batch, x=x, edge_index=edge_index,
                   edge_attr=edge_attr, pos=pos)
    return pooled if transform is None else transform(pooled)
Exemplo n.º 16
0
 def forward(self, data):
     """Weighted average of per-subgraph outputs, aggregated per graph.

     The steps are:

     1. Sample subgraphs with ``subgraph_loader``, using the free names
        ``k``, ``super_node_size``, ``num_tours`` and ``num_cpus``
        (presumably module-level settings).
     2. Embed them with ``self.gnn_layer`` in chunks of 500 and project with
        ``self.output_layer``.
     3. Multiply by the sampling weights, sum per graph, and divide by the
        per-graph sum of weights.

     Tensors are moved with ``.cuda()`` when the module's first parameter
     is on a GPU.
     """
     subgraph_data = subgraph_loader(data, k, super_node_size, num_tours,
                                     num_cpus)
     subgraphs = [
         get_subgraph(data[subgraph_data.batch[j].item()],
                      subgraph_data.subgraphs[j].squeeze())
         for j in range(len(subgraph_data.subgraphs))
     ]
     on_gpu = next(self.parameters()).get_device() != -1
     chunk_size = 500
     embeddings = []
     for start in range(0, len(subgraphs), chunk_size):
         chunk = Batch().from_data_list(subgraphs[start:start + chunk_size])
         feats, edges, owners = chunk.x, chunk.edge_index, chunk.batch
         if on_gpu:
             feats, edges, owners = feats.cuda(), edges.cuda(), owners.cuda()
         embeddings.append(self.gnn_layer(feats, edges, owners))
     outputs = self.output_layer(torch.cat(embeddings, dim=0))
     weights, batch = subgraph_data.weights, subgraph_data.batch
     if on_gpu:
         weights, batch = weights.cuda(), batch.cuda()
     outputs = outputs * weights
     norm = global_add_pool(weights, batch)
     energy = global_add_pool(outputs, batch)
     return energy / norm
Exemplo n.º 17
0
    def forward(self, data):
        """Multi-scale convolution from all points onto sampled centre points.

        Centre indices come from ``data.idx_<index>`` when multi-scale data is
        precomputed, otherwise from ``self.sampler``. For every scale, edges
        are either read from ``data.edge_index_<index>_<scale>`` or found with
        ``self.neighbour_finder``, and ``self.conv`` is applied. The
        per-scale outputs are concatenated on the last dim.

        Returns:
            A ``Batch`` with ``x``, ``pos`` and ``batch`` of the centre points,
            plus whatever ``copy_from_to`` carries over from ``data``.
        """
        batch_obj = Batch()
        x, pos, batch = data.x, data.pos, data.batch
        precomputed = self._precompute_multi_scale
        if not precomputed:
            idx = self.sampler(pos, batch)
            batch_obj.idx = idx
        else:
            idx = getattr(data, "idx_{}".format(self._index), None)

        per_scale = []
        for scale_idx in range(self.neighbour_finder.num_scales):
            if not precomputed:
                row, col = self.neighbour_finder(
                    pos, pos[idx], batch_x=batch, batch_y=batch[idx],
                    scale_idx=scale_idx)
                edge_index = torch.stack([col, row], dim=0)
            else:
                edge_key = "edge_index_{}_{}".format(self._index, scale_idx)
                edge_index = getattr(data, edge_key, None)
            per_scale.append(self.conv(x, (pos, pos[idx]), edge_index, batch))

        batch_obj.x = torch.cat(per_scale, -1)
        batch_obj.pos = pos[idx]
        batch_obj.batch = batch[idx]
        copy_from_to(data, batch_obj)
        return batch_obj
Exemplo n.º 18
0
    def forward(self, data):
        """Map features of a coarse level onto its skip-connection points.

        Args:
            data: pair ``(data, data_skip)``, each exposing ``x``, ``pos`` and
                ``batch``.

        Returns:
            A ``Batch`` whose ``x`` is the ``self.conv`` output. When
            ``self._skip`` is set and ``x_skip`` exists, the skip features are
            concatenated on dim 1. ``self.nn`` is applied if it exists. Other
            attributes are copied from ``data_skip``.
        """
        batch_obj = Batch()
        coarse, skip = data
        x, pos, batch = coarse.x, coarse.pos, coarse.batch
        x_skip, pos_skip, batch_skip = skip.x, skip.pos, skip.batch

        edge_index = None
        if self.neighbour_finder is not None:
            if not self._precompute_multi_scale:
                row, col = self.neighbour_finder(pos, pos_skip, batch,
                                                 batch_skip)
                edge_index = torch.stack([col, row], dim=0)
            else:
                # TODO For now, it uses the one calculated during down steps
                col, row = getattr(skip, "edge_index_{}".format(self._index),
                                   None)
                edge_index = torch.stack([row, col], dim=0)

        x = self.conv(x, pos, pos_skip, batch, batch_skip, edge_index)

        if x_skip is not None and self._skip:
            x = torch.cat([x, x_skip], dim=1)

        batch_obj.x = self.nn(x) if hasattr(self, "nn") else x
        copy_from_to(skip, batch_obj)
        return batch_obj
Exemplo n.º 19
0
def buildGraph(feat, label):
    """Build a batch of fully connected graphs from per-sample node features.

    Args:
        feat: tensor whose dims 0 and 1 are the batch size and the number of
            nodes. Each sample is reshaped to ``(num_nodes, -1)``.
        label: per-sample labels; ``label[i].view(-1)`` becomes ``y``.

    Returns:
        ``Batch`` with one graph per sample. Every graph has one directed
        edge per ordered pair of distinct nodes, and each edge's
        ``edge_attr`` is the Euclidean distance between the two nodes'
        features.
    """
    B = feat.shape[0]
    num_nodes = feat.shape[1]
    # Ordered (src, dst) pairs with src != dst. The reshape keeps a (0, 2)
    # shape when there is a single node, so edge_index is still (2, 0).
    pairs = np.array(list(itertools.permutations(range(num_nodes), 2)),
                     dtype=np.int64).reshape(-1, 2)
    edge_index = torch.as_tensor(pairs, dtype=torch.long).T
    src, dst = pairs[:, 0], pairs[:, 1]
    data_list = []
    for i in range(B):
        feat_arr = feat[i].detach().cpu().numpy().reshape(num_nodes, -1)
        # One distance per edge, in edge_index order. Dropping zero distances
        # instead would also drop edges between distinct nodes with identical
        # features and misalign edge_attr with edge_index.
        edge_attr = np.linalg.norm(feat_arr[src] - feat_arr[dst], axis=1)
        data = Data(x=torch.Tensor(feat_arr),
                    edge_index=edge_index,
                    edge_attr=torch.Tensor(edge_attr).view(-1),
                    y=label[i].view(-1))
        data_list.append(data)
    return Batch().from_data_list(data_list)
Exemplo n.º 20
0
    def forward(self, data, *args, **kwargs):
        """Encode a sparse input with the down modules and an optional head.

        Parameters:
        -----------
        data
            A SparseTensor holding the data and its metadata. Should contain
                F -- Features [N, C]
                coords -- Coords [N, 4]

        Returns
        --------
        data:
            A ``Batch`` whose ``x`` comes from the last down module's ``F``,
            optionally passed through ``self.inner_modules[0]`` and
            ``self.mlp``. Its ``batch`` is the first coordinate column. The
            original note gives ``x`` as ``[1, output_nc]``; with several
            samples this is presumably ``[B, output_nc]`` (TODO confirm).
        """
        self._set_input(data)
        encoded = self.input
        for down in self.down_modules:
            encoded = down(encoded)

        sample_ids = encoded.C[:, 0].long().to(encoded.F.device)
        out = Batch(x=encoded.F, batch=sample_ids)
        inner = self.inner_modules[0]
        if not isinstance(inner, Identity):
            out = inner(out)

        if self.has_mlp_head:
            out.x = self.mlp(out.x)
        return out
Exemplo n.º 21
0
def mols_to_pyg_batch(mols, idm=False, ratio=2., device=None):
    """Convert one or more molecules into batched PyG graphs.

    Args:
        mols: a single molecule or a list of molecules.
        idm: if True, the second graph returned by ``mol_to_pyg_graph`` for
            each molecule is batched as well.
        ratio: forwarded to ``mol_to_pyg_graph``.
        device: if given, every returned batch is moved there.

    Returns:
        ``[g1, g2]``, where ``g2`` is ``None`` unless ``idm`` is True.
    """
    if not isinstance(mols, list):
        mols = [mols]
    graphs = [mol_to_pyg_graph(mol, idm, ratio) for mol in mols]

    g1 = Batch().from_data_list([graph[0] for graph in graphs])
    if device is not None:
        g1 = g1.to(device)

    g2 = None
    if idm:
        g2 = Batch().from_data_list([graph[1] for graph in graphs])
        # Same device handling as g1: move only when a device is requested.
        if device is not None:
            g2 = g2.to(device)
    return [g1, g2]
    def forward(self, data):
        """Sample centre points and convolve over their neighbourhoods.

        A single extra "shadow" row is appended to the features (filled with
        ``self.shadow_features_fill``) and to the positions (filled with
        ``self.shadow_points_fill_``). Presumably this lets padded neighbour
        indices equal to N resolve to it; TODO confirm against
        ``self.neighbour_finder``.

        Returns:
            A ``Batch`` with the conv output as ``x`` and the ``pos``/``batch``
            of the sampled points, plus whatever ``copy_from_to`` carries
            over from ``data``.
        """
        batch_obj = Batch()
        x, pos, batch = data.x, data.pos, data.batch
        idx_sampler = self.sampler(pos=pos, x=x, batch=batch)

        idx_neighbour, _ = self.neighbour_finder(pos, pos, batch_x=batch,
                                                 batch_y=batch)

        pad_feat = torch.full((1, ) + x.shape[1:], self.shadow_features_fill,
                              device=x.device)
        pad_point = torch.full((1, ) + pos.shape[1:], self.shadow_points_fill_,
                               device=x.device)
        x = torch.cat([x, pad_feat], dim=0)
        pos = torch.cat([pos, pad_point], dim=0)

        x_neighbour = x[idx_neighbour]
        # Neighbour positions relative to their centre point; the shadow row
        # is excluded from the centres.
        centres = pos[:-1].unsqueeze(1)
        pos_centered_neighbour = pos[idx_neighbour] - centres

        batch_obj.x = self.conv(x, pos, x_neighbour, pos_centered_neighbour,
                                idx_neighbour, idx_sampler)
        batch_obj.pos = pos[idx_sampler]
        batch_obj.batch = batch[idx_sampler]
        copy_from_to(data, batch_obj)
        return batch_obj
Exemplo n.º 23
0
def test_graphnet_for_graphs_in_batch():
    """GraphNetwork keeps node, edge and global feature shapes on a batch."""
    try:
        from torch_geometric.data import Batch
    except ModuleNotFoundError:
        raise ImportError("Tests require pytorch geometric to be installed")

    n_node_features, n_edge_features, n_global_features = 3, 4, 5
    generator = FakeGraphGenerator(min_nodes=8,
                                   max_nodes=12,
                                   n_node_features=n_node_features,
                                   avg_degree=10,
                                   n_edge_features=n_edge_features,
                                   n_classes=2,
                                   task='graph',
                                   z=n_global_features)
    graphs = generator.sample(n_graphs=10)
    graphnet = GraphNetwork(n_node_features, n_edge_features,
                            n_global_features)

    pyg_graphs = [graph.to_pyg_graph() for graph in graphs.X]
    graph_batch = Batch().from_data_list(pyg_graphs)

    node_out, edge_out, global_out = graphnet(graph_batch.x,
                                              graph_batch.edge_index,
                                              graph_batch.edge_attr,
                                              graph_batch.z,
                                              graph_batch.batch)
    assert graph_batch.x.size() == node_out.size()
    assert graph_batch.edge_attr.size() == edge_out.size()
    assert graph_batch.z.size() == global_out.size()
Exemplo n.º 24
0
def test_single_voxel_grid():
    """voxel_grid clusters points (per batch) and avg_pool averages each cluster."""
    pos = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
    edge_index = torch.tensor([[0, 0, 3], [1, 2, 4]])
    batch = torch.tensor([0, 0, 0, 1, 1])
    x = torch.randn(5, 16)

    cluster = voxel_grid(pos, size=5, batch=batch)
    assert cluster.tolist() == [0, 0, 0, 1, 1]

    data = Batch(x=x, edge_index=edge_index, pos=pos, batch=batch)
    data = avg_pool(cluster, data)
    # Check the pooled graph, not just that pooling runs.
    expected_x = torch.stack([x[:3].mean(dim=0), x[3:].mean(dim=0)])
    assert torch.allclose(data.x, expected_x)
    assert data.pos.tolist() == [[1, 1], [3.5, 3.5]]
    assert data.batch.tolist() == [0, 1]
    # Every edge stays inside one cluster, so only self-loops would remain,
    # and pooling drops those.
    assert data.edge_index.size(1) == 0

    cluster_no_batch = voxel_grid(pos, size=5)
    assert cluster_no_batch.tolist() == [0, 0, 0, 0, 0]

    data_no_batch = Batch(x=x, edge_index=edge_index, pos=pos)
    data_no_batch = avg_pool(cluster_no_batch, data_no_batch)
    assert torch.allclose(data_no_batch.x, x.mean(dim=0, keepdim=True))
    assert data_no_batch.pos.tolist() == [[2, 2]]
Exemplo n.º 25
0
def get_final_reward(state, env, surrogate_model, device):
    """Return the negated docking score that ``surrogate_model`` predicts for ``state``.

    ``state`` is converted with ``mol_to_pyg_graph`` into a single-graph
    batch on ``device``, and the prediction runs without gradients. ``env``
    is not used by this function.
    """
    graph = Batch().from_data_list([mol_to_pyg_graph(state)]).to(device)
    with torch.no_grad():
        pred_docking_score = surrogate_model(graph, None)
    return -pred_docking_score.item()
Exemplo n.º 26
0
def tg_transform(args, X):
    """Convert a dense batch of point sets into a radius-graph ``Batch``.

    Args:
        args: namespace providing ``num_hits`` (points per sample),
            ``cutoff`` (neighbour radius in un-scaled coordinates) and
            ``device``.
        X: tensor of shape ``(batch_size, num_hits, >=3)``. Columns 0-1 are
            positions and column 2 is the node feature.

    Returns:
        ``Batch`` containing:

        - ``x``: feature + 0.5.
        - ``pos``: positions * 28 + 14.
        - ``edge_index``: directed edges between distinct points of the same
          sample that are closer than ``args.cutoff``.
        - ``edge_attr``: the scaled position offsets.
        - ``batch``: graph ids on ``args.device``.
    """
    batch_size = X.size(0)
    num_hits = args.num_hits

    pos = X[:, :, :2]

    # All ordered point pairs per sample: row i * num_hits + j of x1 holds
    # pos_i and of x2 holds pos_j.
    x1 = pos.repeat(1, 1, num_hits).reshape(batch_size,
                                            num_hits * num_hits, 2)
    x2 = pos.repeat(1, num_hits, 1)

    diff_norms = torch.norm(x2 - x1 + 1e-12, dim=2)

    norms = diff_norms.reshape(batch_size, num_hits, num_hits)
    neighborhood = torch.nonzero(norms < args.cutoff, as_tuple=False)

    neighborhood = neighborhood[neighborhood[:, 1] !=
                                neighborhood[:, 2]]  # remove self-loops
    # Shift per-sample node indices into the flattened batch numbering.
    edge_index = (neighborhood[:, 1:] +
                  (neighborhood[:, 0] * num_hits).view(-1, 1)).transpose(
                      0, 1)

    x = X[:, :, 2].reshape(batch_size * num_hits, 1) + 0.5
    pos = 28 * pos.reshape(batch_size * num_hits, 2) + 14

    # Offsets from source to target, scaled by the cutoff and shifted by 0.5.
    row, col = edge_index
    edge_attr = (pos[col] - pos[row]) / (2 * 28 * args.cutoff) + 0.5

    # Graph id of every node: [0] * num_hits + [1] * num_hits + ...
    batch = torch.arange(batch_size,
                         device=args.device).repeat_interleave(num_hits)

    return Batch(batch=batch,
                 x=x,
                 edge_index=edge_index.contiguous(),
                 edge_attr=edge_attr,
                 y=None,
                 pos=pos)
Exemplo n.º 27
0
 def embedding(self, subgraphs):
     """Embed ``subgraphs`` with ``self.gnn_layer`` in chunks of 500, without gradients.

     Inputs are moved with ``.cuda()`` when the module's first parameter is
     on a GPU. Returns the chunk outputs concatenated along dim 0.
     """
     chunk_size = 500
     on_gpu = next(self.parameters()).get_device() != -1
     with torch.no_grad():
         outputs = []
         for start in range(0, len(subgraphs), chunk_size):
             chunk = Batch().from_data_list(subgraphs[start:start + chunk_size])
             tensors = (chunk.x, chunk.edge_index, chunk.batch)
             if on_gpu:
                 tensors = tuple(t.cuda() for t in tensors)
             outputs.append(self.gnn_layer(*tensors))
         return torch.cat(outputs, dim=0)
 def forward(self, data, **kwargs):
     """Global pooling step that yields one output row per pooled graph.

     Features are concatenated with positions, passed through ``self.nn``
     and pooled by ``self.pool`` using ``data.batch``. The result carries
     zero positions of width 3 and ``batch = arange(num_rows)``, plus
     whatever ``copy_from_to`` carries over from ``data``.
     """
     x, pos, batch = data.x, data.pos, data.batch
     pooled = self.pool(self.nn(torch.cat([x, pos], dim=1)), batch)
     num_rows = pooled.size(0)
     batch_obj = Batch()
     batch_obj.x = pooled
     batch_obj.pos = pos.new_zeros((num_rows, 3))
     batch_obj.batch = torch.arange(num_rows, device=batch.device)
     copy_from_to(data, batch_obj)
     return batch_obj
Exemplo n.º 29
0
def convert_to_batch(args, data, batch_size):
    """Wrap already-flattened graph data in a ``Batch`` with a batch vector.

    Assumes ``data`` stores ``batch_size`` graphs of exactly
    ``args.num_hits`` nodes each, one after another. TODO confirm against
    callers.

    Returns:
        ``Batch`` with ``x``, ``pos``, ``edge_index`` and ``edge_attr`` from
        ``data``, and ``batch = [0] * num_hits + [1] * num_hits + ...`` on
        ``args.device``.
    """
    # repeat_interleave also handles num_hits == 0; a scatter-then-cumsum
    # construction would index out of range in that case.
    batch = torch.arange(batch_size,
                         device=args.device).repeat_interleave(args.num_hits)

    return Batch(batch=batch,
                 x=data.x,
                 pos=data.pos,
                 edge_index=data.edge_index,
                 edge_attr=data.edge_attr)
Exemplo n.º 30
0
    def forward(self,
                data,
                output_node_feat_flag=False,
                output_layer_num_flag=False,
                output_residual_confidence_flag=False):
        """Embed the input, run the GNN stack, read out and apply the head.

        Args:
            data: batched graph. ``data.x`` is embedded. Every other attribute,
                plus ``input_x`` (the raw input features), is forwarded to
                each ``Batch`` built for the sub-modules.
            output_node_feat_flag: also return the node features produced by
                the last GNN module (before any pointwise head).
            output_layer_num_flag: also return the total layer count reported
                by all GNN modules.
            output_residual_confidence_flag: also return the sum of the
                residual confidences of all modules whose class name contains
                ``'ACT'``, or ``None`` when there are no such modules.

        Returns:
            Tuple starting with the head output, followed by the requested
            extras in the order of the flags above.
        """
        kwargs = {k: v for k, v in data.__dict__.items()}
        kwargs['input_x'] = x = kwargs['x']
        kwargs.pop('x')

        x = self.embedding_module(x)
        layer_num = 0
        x_list = []
        residual_confidence_list = []
        for gnn_module in self.gnn_module_list:
            # ACT-style modules additionally return a residual confidence.
            if 'ACT' in gnn_module.__class__.__name__:
                x, cur_layer_num, cur_residual_confidence = gnn_module(
                    Batch(x=x, **kwargs))
                residual_confidence_list.append(cur_residual_confidence)
            else:
                x, cur_layer_num = gnn_module(Batch(x=x, **kwargs))
            layer_num += cur_layer_num
            x_list.append(x)
        # Initialised so the residual-confidence output works even without
        # ACT modules (it would otherwise be an unbound name).
        residual_confidence = None
        if residual_confidence_list:
            residual_confidence = torch.sum(torch.stack(
                residual_confidence_list, dim=0),
                                            dim=0)
        if self.pointwise_head_layer_flag:
            x_list = [self.head_module(feat) for feat in x_list]
        global_feat = self.readout([Batch(x=feat, **kwargs) for feat in x_list])
        # To avoid information-leak between nodes, we perform pointwise head-module for the physical simulation task
        if not self.pointwise_head_layer_flag:
            out = self.head_module(global_feat)
        else:
            out = global_feat

        output = (out, )
        if output_node_feat_flag:
            output = output + (x, )
        if output_layer_num_flag:
            output = output + (layer_num, )
        if output_residual_confidence_flag:
            output = output + (residual_confidence, )
        return output