Пример #1
0
    def compute_grad_pen(self,
                         expert_state,
                         expert_action,
                         policy_state,
                         policy_action,
                         lambda_=10):
        """Gradient penalty of the discriminator w.r.t. mixed expert/policy votes.

        The expert and policy batches are merged node-wise into one batch laid
        out as ``[all expert nodes; all policy nodes]``. The ``batch`` vectors
        are concatenated as well, so sample ``i``'s expert graph and policy
        graph together form (disconnected) graph ``i``. Policy edge indices are
        shifted by the number of expert nodes. Edge attributes are not carried
        over.

        On the merged batch, a per-node vote column is built. For each graph
        ``i``, with ``alpha_i ~ U(0, 1)``:

        - the expert's chosen node gets ``alpha_i``;
        - the policy's chosen node gets ``1 - alpha_i``;
        - every other node gets 0.

        The penalty is
        ``lambda_ * mean((||d disc / d votes||_2 - 1) ** 2)``, where the norm
        is taken over the single vote column per node.

        Args:
            expert_state: Batched expert graphs with at least ``batch`` and
                ``edge_index``. NOTE(review): assumes the ``batch`` vector is
                sorted by graph id (contiguous nodes per graph), as PyG
                batching produces. TODO confirm.
            expert_action: Chosen node per graph, indexed locally within that
                graph's expert nodes (presumably shape ``(B,)`` or ``(B, 1)``).
            policy_state: Batched policy graphs, same layout assumptions as
                ``expert_state``.
            policy_action: Chosen node per graph, indexed locally within that
                graph's policy nodes.
            lambda_: Weight of the penalty.

        Returns:
            Scalar tensor with the gradient penalty.
        """
        # Merge node-level attributes: all expert nodes first, then all policy nodes.
        mixup_state = Batch()
        for key, _ in expert_state:
            assert isinstance(key, str), str(key)
            if key in ("edge_index", "edge_attr"):
                continue
            mixup_state[key] = torch.cat([expert_state[key], policy_state[key]])
        num_expert_nodes = expert_state.batch.shape[0]
        num_policy_nodes = policy_state.batch.shape[0]
        mixup_state.edge_index = torch.cat(
            [
                expert_state.edge_index,
                policy_state.edge_index + num_expert_nodes,
            ],
            dim=1,
        )

        device = expert_state.batch.device
        num_graphs = int(expert_state.batch.max().item()) + 1
        alpha = torch.rand(expert_action.size(0), device=device)[:num_graphs]

        # Per-graph node counts and the position of each graph's first node in
        # the merged layout (expert block first, then the policy block).
        expert_counts = torch.bincount(expert_state.batch,
                                       minlength=num_graphs)[:num_graphs]
        policy_counts = torch.bincount(policy_state.batch,
                                       minlength=num_graphs)[:num_graphs]
        # Every merged graph must contain at least one node.
        assert bool(((expert_counts + policy_counts) > 0).all())
        expert_offset = torch.cumsum(expert_counts, dim=0) - expert_counts
        policy_offset = (num_expert_nodes + torch.cumsum(policy_counts, dim=0) -
                         policy_counts)

        # Local action indices -> global row in the merged node layout, so that
        # each vote sits on the node it refers to in ``mixup_state``.
        expert_idx = (expert_offset.view(-1, 1) +
                      expert_action.reshape(num_graphs, -1).to(device))
        policy_idx = (policy_offset.view(-1, 1) +
                      policy_action.reshape(num_graphs, -1).to(device))

        votes = torch.zeros(num_expert_nodes + num_policy_nodes, 1, device=device)
        votes[expert_idx.reshape(-1), 0] = (
            alpha.view(-1, 1).expand_as(expert_idx).reshape(-1))
        votes[policy_idx.reshape(-1), 0] = (
            (1 - alpha).view(-1, 1).expand_as(policy_idx).reshape(-1))
        mixup_action = votes.requires_grad_()

        disc = self.forward(mixup_state, mixup_action)
        # Only the gradient w.r.t. the vote column enters the penalty.
        grad = autograd.grad(
            outputs=disc,
            inputs=mixup_action,
            grad_outputs=torch.ones_like(disc),
            create_graph=True,
            retain_graph=True,
        )[0]

        grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()
        return grad_pen
Пример #2
0
    def forward(self, data, **kwargs):
        """Downsample ``data`` and convolve features onto the retained points.

        ``self.sampler`` selects the indices of the points to keep.
        ``self.neighbour_finder`` then pairs the kept points with points of the
        full cloud (restricted by the ``batch`` vectors).

        The returned ``Batch`` holds:

        - the sampling indices (``idx``);
        - the edges (``edge_index``, stacked as ``[col, row]``);
        - the convolved features;
        - the kept positions and batch vector.

        Any further attributes are copied from ``data`` via ``copy_from_to``.

        NOTE(review): the position pair handed to ``self.conv`` is
        ``(pos[idx], pos)``. If ``col`` indexes the full cloud (source) and
        ``row`` the kept points (target), the usual PyG ``(source, target)``
        ordering would be ``(pos, pos[idx])``. Confirm against ``self.conv``.
        """
        positions = data.pos
        batch_vec = data.batch

        keep = self.sampler(positions, batch_vec)
        kept_pos = positions[keep]
        kept_batch = batch_vec[keep]

        row, col = self.neighbour_finder(
            positions, kept_pos, batch_x=batch_vec, batch_y=kept_batch)
        edges = torch.stack([col, row], dim=0)

        out = Batch()
        out.idx = keep
        out.edge_index = edges
        out.x = self.conv(data.x, (kept_pos, positions), edges, batch_vec)
        out.pos = kept_pos
        out.batch = kept_batch

        copy_from_to(data, out)
        return out
Пример #3
0
    def __collate__(self, n_id):
        """Build a mini-batch of k-hop ego subgraphs around the seed nodes ``n_id``.

        ``torch_sparse.ego_k_hop_sample_adj`` samples, for every seed, a
        subgraph of depth ``self.depth`` using ``self.num_neighbors`` and
        ``self.replace``.

        The returned ``Batch`` carries:

        - ``batch`` and ``ptr``, describing which sampled node belongs to
          which ego graph;
        - ``root_n_id``, the position of each seed inside the sampled nodes;
        - the connectivity, as ``adj_t`` when ``self.is_sparse_tensor`` and
          as ``edge_index`` otherwise;
        - every other attribute of ``self.data``.

        Attributes of ``self.data`` are handled as follows:

        - node-level tensors are sliced to the sampled nodes;
        - a node-level ``y`` is reduced to the seed nodes only;
        - edge-level tensors are sliced with ``e_id``;
        - anything else, including scalars and non-tensors, is copied as-is.
        """
        n_id = torch.tensor(n_id)

        rowptr, col, value = self.adj_t.csr()
        out = torch.ops.torch_sparse.ego_k_hop_sample_adj(
            rowptr, col, n_id, self.depth, self.num_neighbors, self.replace)
        rowptr, col, n_id, e_id, ptr, root_n_id = out
        num_sampled = n_id.numel()

        adj_t = SparseTensor(rowptr=rowptr,
                             col=col,
                             value=value[e_id] if value is not None else None,
                             sparse_sizes=(num_sampled, num_sampled),
                             is_sorted=True,
                             trust_data=True)

        batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, num_sampled),
                      ptr=ptr)
        batch.root_n_id = root_n_id

        if self.is_sparse_tensor:
            batch.adj_t = adj_t
        else:
            # The third output of ``coo()`` is the edge value, which rebinds
            # ``e_id`` so that it follows the transposed edge order.
            # NOTE(review): this is only a valid edge-id vector if
            # ``self.adj_t`` stores original edge ids as its values (set up
            # outside this view). TODO confirm.
            row, col, e_id = adj_t.t().coo()
            batch.edge_index = torch.stack([row, col], dim=0)

        num_nodes = self.data.num_nodes
        num_edges = self.data.num_edges
        for k, v in self.data:
            if k in ['edge_index', 'adj_t', 'num_nodes']:
                continue
            # Only tensors with a leading dimension can be sliced; calling
            # ``size(0)`` on 0-dim tensors or on non-tensors would raise.
            sliceable = isinstance(v, Tensor) and v.dim() > 0
            if sliceable and k == 'y' and v.size(0) == num_nodes:
                batch[k] = v[n_id][root_n_id]
            elif sliceable and v.size(0) == num_nodes:
                batch[k] = v[n_id]
            elif sliceable and v.size(0) == num_edges:
                batch[k] = v[e_id]
            else:
                batch[k] = v

        return batch
Пример #4
0
    def forward(self, data):
        """Downsample ``data`` and convolve features onto the retained points.

        If ``self._precompute_multi_scale`` is set, the sampling indices and
        edges are read from ``data.index_<k>`` and ``data.edge_index_<k>``,
        with ``k = self._index``. Otherwise they are computed with
        ``self.sampler`` and ``self.neighbour_finder`` and also stored on the
        output.

        Returns:
            A ``Batch`` with the convolved features, kept positions and batch
            vector. Further attributes are copied from ``data`` via
            ``copy_from_to``.

        Raises:
            ValueError: if precomputed multi-scale data is requested but the
                corresponding attributes are missing from ``data``.
        """
        batch_obj = Batch()
        x, pos, batch = data.x, data.pos, data.batch
        if self._precompute_multi_scale:
            idx = getattr(data, "index_{}".format(self._index), None)
            edge_index = getattr(data, "edge_index_{}".format(self._index),
                                 None)
            # Without this check a missing attribute gives idx=None, and
            # ``pos[None]`` silently adds a dimension instead of failing.
            if idx is None or edge_index is None:
                raise ValueError(
                    "Precomputed multi-scale data is missing 'index_{0}' or "
                    "'edge_index_{0}'".format(self._index))
        else:
            idx = self.sampler(pos, batch)
            row, col = self.neighbour_finder(pos,
                                             pos[idx],
                                             batch_x=batch,
                                             batch_y=batch[idx])
            edge_index = torch.stack([col, row], dim=0)
            batch_obj.idx = idx
            batch_obj.edge_index = edge_index

        batch_obj.x = self.conv(x, (pos, pos[idx]), edge_index, batch)

        batch_obj.pos = pos[idx]
        batch_obj.batch = batch[idx]
        copy_from_to(data, batch_obj)
        return batch_obj
Пример #5
0
def _collate_graphs(graphs, ys):
    """Merge ``graphs`` into one disjoint-union ``Batch``.

    How the fields are built:

    - edge indices are shifted by the number of nodes of all preceding graphs;
    - ``batch`` gives every node the position (0, 1, ...) of its graph in
      ``graphs``;
    - ``ys[k]`` is used as the target of ``graphs[k]``.

    Targets are passed separately so the (possibly shared) input graph
    objects are never mutated.
    """
    merged = Batch()
    shifted_edges, node_to_graph = [], []
    offset = 0
    for position, graph in enumerate(graphs):
        num_nodes = graph.x.size(0)
        shifted_edges.append(graph.edge_index + offset)
        node_to_graph.append(
            torch.full([num_nodes], fill_value=position, dtype=torch.long))
        offset += num_nodes

    merged.edge_index = torch.cat(shifted_edges, dim=1)
    merged.x = torch.cat([graph.x for graph in graphs], dim=0)
    merged.edge_attr = torch.cat([graph.edge_attr for graph in graphs], dim=0)
    merged.y = torch.cat(ys, dim=0)
    merged.batch = torch.cat(node_to_graph, dim=0)
    return merged


def data_loader(database, latent_graph_dict, latent_graph_list, batch_size=4, shuffle=True):
    """Build mini-batches of (molecule 1, molecule 2, membrane) graph triples.

    Run this on either the training or the validation set.

    Args:
        database: pandas DataFrame with one row per sample, addressed by
            position. Columns read here:
            - "Molecule1", "Molecule2", "Membrane": keys into
              ``latent_graph_dict``;
            - "molp1", "molp2", "molpm": per-graph ``y`` values;
            - "2" ... "16": target values, each multiplied by 10;
            - "Temperature".
        latent_graph_dict: Maps a key from the columns above to an index into
            ``latent_graph_list``.
        latent_graph_list: Graph objects exposing ``x``, ``edge_index`` and
            ``edge_attr``. They are not modified.
        batch_size: Samples per mini-batch. The last batch may be smaller.
        shuffle: Random sample order if True, otherwise row order.

    Returns:
        ``[batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]``:
        five lists with one entry per mini-batch. Each entry is one of:
        - a ``Batch`` (fields ``edge_index``, ``x``, ``edge_attr``, ``y``,
          ``batch``) for each of the three graph lists;
        - a ``(B, 15)`` float target tensor;
        - a ``(B,)`` float temperature tensor.
    """
    num_rows = len(database)

    def cell(column, row):
        # Rows are enumerated by position, so use positional (not label) access.
        return database[column].iloc[row]

    graphs_p1, graphs_p2, graphs_pm = [], [], []
    ys_p1, ys_p2, ys_pm = [], [], []
    targets, temperatures = [], []
    for row in range(num_rows):
        graphs_p1.append(latent_graph_list[latent_graph_dict[cell("Molecule1", row)]])
        graphs_p2.append(latent_graph_list[latent_graph_dict[cell("Molecule2", row)]])
        graphs_pm.append(latent_graph_list[latent_graph_dict[cell("Membrane", row)]])

        # The same graph object can back many rows. Keep per-row ``y`` values
        # apart instead of writing them onto the shared graph; otherwise the
        # last row would overwrite every earlier one.
        ys_p1.append(torch.tensor([float(cell("molp1", row))]))
        ys_p2.append(torch.tensor([float(cell("molp2", row))]))
        ys_pm.append(torch.tensor([float(cell("molpm", row))]))

        targets.append(torch.tensor([cell(str(i), row) * 10 for i in range(2, 17)],
                                    dtype=torch.float))
        temperatures.append(cell("Temperature", row))

    order = torch.randperm(num_rows).tolist() if shuffle else list(range(num_rows))

    batch_p1, batch_p2, batch_pm = [], [], []
    batch_target, batch_Temperature = [], []
    for start in range(0, num_rows, batch_size):
        chunk = order[start:start + batch_size]
        batch_p1.append(_collate_graphs([graphs_p1[k] for k in chunk],
                                        [ys_p1[k] for k in chunk]))
        batch_p2.append(_collate_graphs([graphs_p2[k] for k in chunk],
                                        [ys_p2[k] for k in chunk]))
        batch_pm.append(_collate_graphs([graphs_pm[k] for k in chunk],
                                        [ys_pm[k] for k in chunk]))
        batch_target.append(torch.stack([targets[k] for k in chunk], dim=0))
        # Temperatures follow the same (possibly shuffled) order as the graphs
        # and targets.
        batch_Temperature.append(
            torch.tensor([float(temperatures[k]) for k in chunk], dtype=torch.float))

    return [batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]
Пример #6
0
    def forward(self, plist):
        """Predict outputs for a mini-batch of (molecule 1, molecule 2, membrane) triples.

        Level 1 runs ``self.no_mp_one`` rounds of ``self.meta1`` on each of the
        three molecular graphs, producing one graph-level vector per sample
        and graph.

        Level 2 then:

        - builds, per sample, a 3-node graph with nodes p1, p2, pm;
        - connects it with bidirectional edges p1<->pm and p2<->pm, whose
          features encode the temperature and the ratio ``y_p / y_pm``;
        - runs ``self.no_mp_two`` rounds of ``self.meta3``;
        - maps the final graph-level vector through ``self.mlp_last``.

        Args:
            plist: ``[p1, p2, pm, Temperature]``. The three graph batches
                expose ``x``, ``edge_index``, ``edge_attr``, ``y`` and
                ``batch``. ``Temperature`` has one entry per sample.

        Returns:
            The output of ``self.mlp_last``.
        """
        p1, p2, pm, Temperature = plist
        # Take the number of samples from the input instead of self.batch_size,
        # so that a trailing, smaller mini-batch is handled too.
        num_graphs = Temperature.size(0)

        x_p1, ei_p1, ea_p1, u_p1, btc_p1 = p1.x, p1.edge_index, p1.edge_attr, p1.y, p1.batch
        x_p2, ei_p2, ea_p2, u_p2, btc_p2 = p2.x, p2.edge_index, p2.edge_attr, p2.y, p2.batch
        x_pm, ei_pm, ea_pm, u_pm, btc_pm = pm.x, pm.edge_index, pm.edge_attr, pm.y, pm.batch

        # Embed node and edge features (encoders shared by the three graphs).
        enc_x_p1 = self.encoding_node_1(x_p1)
        enc_x_p2 = self.encoding_node_1(x_p2)
        enc_x_pm = self.encoding_node_1(x_pm)

        enc_ea_p1 = self.encoding_edge_1(ea_p1)
        enc_ea_p2 = self.encoding_edge_1(ea_p2)
        enc_ea_pm = self.encoding_edge_1(ea_pm)

        # Initial graph-level state (0.1 everywhere) for every molecular graph.
        u1 = torch.full((num_graphs, NO_GRAPH_FEATURES_ONE), 0.1,
                        dtype=torch.float, device=device)
        u2 = torch.full((num_graphs, NO_GRAPH_FEATURES_ONE), 0.1,
                        dtype=torch.float, device=device)
        um = torch.full((num_graphs, NO_GRAPH_FEATURES_ONE), 0.1,
                        dtype=torch.float, device=device)

        # --- Level 1: message passing inside each molecular graph.
        for _ in range(self.no_mp_one):
            enc_x_p1, enc_ea_p1, u1 = self.meta1(x=enc_x_p1,
                                                 edge_index=ei_p1,
                                                 edge_attr=enc_ea_p1,
                                                 u=u1,
                                                 batch=btc_p1)
            enc_x_p2, enc_ea_p2, u2 = self.meta1(x=enc_x_p2,
                                                 edge_index=ei_p2,
                                                 edge_attr=enc_ea_p2,
                                                 u=u2,
                                                 batch=btc_p2)
            enc_x_pm, enc_ea_pm, um = self.meta1(x=enc_x_pm,
                                                 edge_index=ei_pm,
                                                 edge_attr=enc_ea_pm,
                                                 u=um,
                                                 batch=btc_pm)

        # --- Level 2: the graph-level vectors become nodes of a small graph.
        u1 = self.encoding_node_2(u1)
        u2 = self.encoding_node_2(u2)
        um = self.encoding_node_2(um)

        # Edge features: temperature paired with each molecule's y relative to
        # the membrane's y.
        p1_div_pm = u_p1 / u_pm
        p2_div_pm = u_p2 / u_pm
        concat_T_p1pm = self.encoding_edge_2(torch.stack([Temperature, p1_div_pm], dim=1))
        concat_T_p2pm = self.encoding_edge_2(torch.stack([Temperature, p2_div_pm], dim=1))

        nu_Batch = Batch()
        # Per sample: nodes 0 (p1), 1 (p2), 2 (pm); edges 0->2, 2->0, 1->2, 2->1
        # (no p1-p2 edge). Sample b's node ids are shifted by 3 * b.
        template_ei = torch.tensor([[0, 2, 1, 2], [2, 0, 2, 1]], dtype=torch.long)
        node_offset = 3 * torch.arange(num_graphs, dtype=torch.long)
        nu_Batch.edge_index = (template_ei.unsqueeze(1) +
                               node_offset.view(1, -1, 1)).reshape(2, -1)
        # One feature row per edge, in the template order above (bidirectional
        # edges share the same feature).
        nu_Batch.edge_attr = torch.stack(
            [concat_T_p1pm, concat_T_p1pm, concat_T_p2pm, concat_T_p2pm],
            dim=1).flatten(0, 1)
        # Node features, sample by sample: [p1, p2, pm].
        nu_Batch.x = torch.stack([u1, u2, um], dim=1).flatten(0, 1)
        # Initial graph-level state of the second-level graphs.
        nu_Batch.y = torch.full((num_graphs, NO_GRAPH_FEATURES_TWO), 0.1,
                                dtype=torch.float)
        nu_Batch.batch = torch.arange(num_graphs, dtype=torch.long).repeat_interleave(3)

        nu_Batch = nu_Batch.to(device)

        # --- Level 2 message passing.
        for _ in range(self.no_mp_two):
            nu_Batch.x, nu_Batch.edge_attr, nu_Batch.y = self.meta3(
                x=nu_Batch.x,
                edge_index=nu_Batch.edge_index,
                edge_attr=nu_Batch.edge_attr,
                u=nu_Batch.y,
                batch=nu_Batch.batch)

        c = self.mlp_last(nu_Batch.y)

        return c
Пример #7
0
    def forward(self, primal_graph_batch, dual_graph_batch, pooling_log):
        r"""Undo a dual-primal edge pooling step using the saved pooling log.

        Args:
            primal_graph_batch (torch_geometric.data.batch.Batch): Pooled
                primal graphs to unpool.
            dual_graph_batch (torch_geometric.data.batch.Batch): Pooled dual
                graphs to unpool.
            pooling_log (pd_mesh_net.nn.pool.PoolingInfo): Information stored
                by the pooling layer that is required to invert it.

        Returns:
            new_primal_graph_batch (torch_geometric.data.batch.Batch): Primal
                graphs with the pre-pooling connectivity. Every old node takes
                the feature of the node it was merged into.
            new_dual_graph_batch (torch_geometric.data.batch.Batch): Dual
                graphs with the pre-pooling connectivity. Old dual nodes that
                survived pooling take the feature of their pooled counterpart.
                The others (mapping value -1) take ``self.new_dual_feature``.
            new_primal_edge_to_dual_node_idx_batch (dict): Primal-edge to
                dual-node association taken unchanged from the pooling log.
        """
        # Primal graph: restore the old connectivity and gather features
        # through the old-node -> merged-node mapping.
        primal_out = Batch(batch=pooling_log.old_primal_graph_batch)
        primal_out.edge_index = pooling_log.old_primal_edge_index
        primal_out.x = primal_graph_batch.x[
            pooling_log.old_primal_node_to_new_one]

        # Dual graph: the old-dual -> new-dual mapping is mandatory.
        dual_map = pooling_log.old_dual_node_to_new_one
        assert (dual_map is not None), (
            "The input pooling log does not contain the mapping from dual "
            "nodes before pooling to dual nodes after pooling. Please set the "
            "argument `return_old_dual_node_to_new_dual_node` to True in the "
            "`DualPrimalEdgePooling` layer.")
        dual_out = Batch(batch=pooling_log.old_dual_graph_batch)
        dual_out.edge_index = pooling_log.old_dual_edge_index

        # Start every old dual node from the shared learnable feature, then
        # overwrite those that still have a counterpart after pooling.
        has_match = dual_map != -1
        dual_x = self.new_dual_feature.repeat(len(dual_map), 1)
        dual_x[has_match] = dual_graph_batch.x[dual_map[has_match]]
        dual_out.x = dual_x

        return (primal_out, dual_out,
                pooling_log.old_primal_edge_to_dual_node_index)