Example #1
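All three listings below share the same imports and rely on two project-specific helpers, `get_mask` and `total_var`, plus a module-level `device`, none of which are shown on this page. A plausible import header (using the older torch_geometric API that still provides `dropout_adj`) and a minimal, hypothetical stand-in for `get_mask` might look like this; the stand-in assumes `get_mask(x, edge_index, hops)` marks the `hops`-step receptive field of the nonzero (dirac) seed nodes and preserves the input's shape. A stand-in for `total_var` follows Example #2.

import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d as BN
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import GATConv, GINConv, GraphSizeNorm
from torch_geometric.utils import (add_remaining_self_loops, degree,
                                   dropout_adj, remove_self_loops)
from torch_scatter import scatter_add, scatter_max, scatter_min

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_mask(x, edge_index, hops):
    # Hypothetical stand-in for the project helper: propagate a {0,1} seed
    # mask `hops` steps along the edges, keeping the input's shape so it can
    # be multiplied onto node-feature matrices via broadcasting.
    mask = (x != 0).float()
    row, col = edge_index
    for _ in range(hops):
        neigh = scatter_add(mask.index_select(0, row), col, dim=0,
                            dim_size=mask.size(0))
        mask = ((mask + neigh) > 0).float()
    return mask
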
class clique_MPNN(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden1, hidden2, deltas, elasticity=0.01, num_iterations=30):
        super(clique_MPNN, self).__init__()
        self.hidden1 = hidden1
        self.hidden2 = hidden2
        self.momentum = 0.1
        self.num_iterations = num_iterations
        self.convs = torch.nn.ModuleList()
        self.deltas = deltas
        self.numlayers = num_layers
        self.elasticity = elasticity
        self.heads = 8
        self.concat = True
        
        self.bns = torch.nn.ModuleList()
        for i in range(num_layers-1):
            self.bns.append(BN(self.heads*self.hidden1, momentum=self.momentum))
        for i in range(num_layers - 1):
            self.convs.append(GINConv(Sequential(
                Linear(self.heads * self.hidden1, self.heads * self.hidden1),
                ReLU(),
                Linear(self.heads * self.hidden1, self.heads * self.hidden1),
                ReLU(),
                BN(self.heads * self.hidden1, momentum=self.momentum),
            ), train_eps=True))
        self.bn1 = BN(self.heads * self.hidden1)
        self.conv1 = GINConv(Sequential(
            Linear(self.hidden2, self.heads * self.hidden1),
            ReLU(),
            Linear(self.heads * self.hidden1, self.heads * self.hidden1),
            ReLU(),
            BN(self.heads * self.hidden1, momentum=self.momentum),
        ), train_eps=True)

        if self.concat:
            self.lin1 = Linear(self.heads*self.hidden1, self.hidden1)
        else:
            self.lin1 = Linear(self.hidden1, self.hidden1)
        self.lin2 = Linear(self.hidden1, 1)
        self.gnorm = GraphSizeNorm()

    def reset_parameters(self):
        self.conv1.reset_parameters()
        
        for conv in self.convs:
            conv.reset_parameters() 
        for bn in self.bns:
            bn.reset_parameters()
        self.bn1.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index     
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]

        
        if edge_dropout is not None:
            edge_index = dropout_adj(
                edge_index,
                edge_attr=torch.ones(edge_index.shape[1], device=device).long(),
                p=edge_dropout,
                force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index,
                                                  num_nodes=batch.shape[0])[0]
                
        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = (reduced_num_edges/total_num_edges)
        no_loop_index,_ = remove_self_loops(edge_index)  
        no_loop_row, no_loop_col = no_loop_index

        xinit = x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x, edge_index, 1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))  # +x
        x = x * mask
        x = self.gnorm(x)
        x = self.bn1(x)
        
            
        for conv, bn in zip(self.convs, self.bns):
            if x.dim() > 1:
                x = x + F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()
        x = F.leaky_relu(self.lin1(x))
        x = x * mask

        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x))
        x = x * mask


        #calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        probs = x

        x2 = x.detach()
        deg = degree(row).unsqueeze(-1)
        totalvol = scatter_add(deg.detach() * torch.ones_like(x, device=device), batch, 0) + 1e-6
        totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0) + 1e-6
        x2 = ((x2 - torch.rand_like(x, device=device)) > 0).float()  # hard Bernoulli rounding of the probabilities
        vol_1 = scatter_add(probs * deg, batch, 0) + 1e-6
        card_1 = scatter_add(probs, batch, 0)
        set_size = scatter_add(x2, batch, 0)
        vol_hard = scatter_add(deg * x2, batch, 0, dim_size=batch.max().item() + 1) + 1e-6
        total_vol_ratio = vol_hard / totalvol
        
        
        #calculating the terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device=device)
        for graph in range(num_graphs):
            batch_graph = (batch == graph)
            pairwise_prodsums[graph] = torch.conv1d(
                probs[batch_graph].unsqueeze(-1),
                probs[batch_graph].unsqueeze(-1)).sum() / 2
        
        
        ###calculate loss terms
        self_sums = scatter_add(probs * probs, batch, 0, dim_size=num_graphs)
        expected_weight_G = scatter_add(probs[no_loop_row] * probs[no_loop_col], batch[no_loop_row], 0, dim_size=num_graphs) / 2.
        expected_clique_weight = pairwise_prodsums.unsqueeze(-1) - self_sums
        expected_distance = expected_clique_weight - expected_weight_G

        ###useful numbers
        max_set_weight = (scatter_add(torch.ones_like(x)[no_loop_row], batch[no_loop_row], 0, dim_size=num_graphs) / 2).squeeze(-1)
        set_weight = scatter_add(x2[no_loop_row] * x2[no_loop_col], batch[no_loop_row], 0, dim_size=num_graphs) / 2 + 1e-6
        clique_edges_hard = set_size * (set_size - 1) / 2 + 1e-6
        clique_dist_hard = set_weight / clique_edges_hard
        clique_check = (clique_edges_hard != clique_edges_hard)  # NaN check: x != x is True only for NaN
        setedge_check = (set_weight != set_weight)  # NaN check
        
        assert ((clique_dist_hard>=1.1).sum())<=1e-6, "Invalid set vol/clique vol ratio."

        ###calculate loss
        expected_loss = penalty_coefficient * expected_distance * 0.5 - 0.5 * expected_weight_G
        

        loss = expected_loss


        retdict = {}
        
        retdict["output"] = [probs.squeeze(-1),"hist"]   #output
        retdict["Expected_cardinality"] = [card_1.mean(),"sequence"]
        retdict["Expected_cardinality_hist"] = [card_1,"hist"]
        retdict["losses histogram"] = [loss.squeeze(-1),"hist"]
        retdict["Set sizes"] = [set_size.squeeze(-1),"hist"]
        retdict["volume_hard"] = [vol_hard.mean(),"aux"] #volume2
        retdict["cardinality_hard"] = [set_size[0],"sequence"] #volumeq
        retdict["Expected weight(G)"]= [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [expected_clique_weight.mean(),"sequence"]
        retdict["Expected distance"]= [expected_distance.mean(), "sequence"]
        retdict["Currvol/Cliquevol"] = [clique_dist_hard.mean(),'sequence']
        retdict["Currvol/Cliquevol all graphs in batch"] = [clique_dist_hard.squeeze(-1),'hist']
        retdict["Average ratio of total volume"]= [total_vol_ratio.mean(),'sequence']
        retdict["cardinalities"] = [cardinalities.squeeze(-1),"hist"]
        retdict["Current edge percentage"] = [torch.tensor(current_edge_percentage),'sequence']
        retdict["loss"] = [loss.mean().squeeze(),"sequence"] #final loss

        return retdict
    
    def __repr__(self):
        return self.__class__.__name__
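A minimal usage sketch for this class, assuming the imports and the hypothetical `get_mask` stand-in above; `hidden2` must match the node-feature dimension (1 for a dirac feature), and all other hyperparameters here are illustrative:

from torch_geometric.data import Batch, Data

# A single 4-cycle with a dirac (seed) feature on node 0.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                           [1, 0, 2, 1, 3, 2, 0, 3]])
x = torch.zeros(4)
x[0] = 1.
data = Batch.from_data_list([Data(x=x, edge_index=edge_index)]).to(device)

model = clique_MPNN(dataset=None, num_layers=3, hidden1=16, hidden2=1,
                    deltas=None).to(device)
retdict = model(data, edge_dropout=None, penalty_coefficient=0.25)
print(retdict["loss"][0])    # scalar training loss
print(retdict["output"][0])  # per-node clique membership probabilities
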
Example #2
class cut_MPNN(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden1, hidden2, deltas, elasticity=0.01, num_iterations=30):
        super(cut_MPNN, self).__init__()
        self.hidden1 = hidden1
        self.hidden2 = hidden2
        self.conv1 = GINConv(Sequential(
            Linear(1, self.hidden1),
            ReLU(),
            Linear(self.hidden1, self.hidden1),
            ReLU(),
            BN(self.hidden1),
        ), train_eps=False)
        self.num_iterations = num_iterations
        self.convs = torch.nn.ModuleList()
        self.deltas = deltas
        self.numlayers = num_layers
        self.elasticity = elasticity
        
        self.bns = torch.nn.ModuleList()
        for i in range(num_layers-1):
            self.bns.append(BN( self.hidden1))
        for i in range(num_layers - 1):
            self.convs.append(GINConv(Sequential(
                Linear(self.hidden1, self.hidden1),
                ReLU(),
                Linear(self.hidden1, self.hidden1),
                ReLU(),
                BN(self.hidden1),
            ), train_eps=False))
     
        self.conv2 = GATConv(self.hidden1, self.hidden2, heads=8)
        self.lin1 = Linear(8*self.hidden2, self.hidden1)
        self.bn2 = BN(self.hidden1)
        self.lin2 = Linear(self.hidden1, 1)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters() 
        for conv in self.convs:
            conv.reset_parameters()    
        for bn in self.bns:
            bn.reset_parameters()
        self.lin1.reset_parameters()
        self.bn2.reset_parameters()
        self.lin2.reset_parameters()


    def forward(self, data, tvol = None):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        xinit = x.clone()
        row, col = edge_index
        mask = get_mask(x, edge_index, 1).to(x.dtype).unsqueeze(-1)

        x = self.conv1(x, edge_index)
        xpostconv1 = x.detach() 
        x = x*mask
        for conv, bn in zip(self.convs, self.bns):
            if x.dim() > 1:
                x = x + conv(x, edge_index)
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = bn(x)


        x = self.conv2(x, edge_index)
        mask = get_mask(mask,edge_index,1).to(x.dtype)
        x = x*mask
        xpostconvs = x.detach()
        x = F.leaky_relu(self.lin1(x))
        x = x * mask
        x = self.bn2(x)

        xpostlin1 = x.detach()
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.leaky_relu(self.lin2(x))
        x = x * mask
        

        xprethresh = x.detach()
        N_size = x.shape[0]
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        x = x * mask + mask * 1e-6
        

        #add dirac in the set
        x = x + xinit.unsqueeze(-1)
        
        #calculate
        x2 = x.detach()
        r, c = edge_index
        tv = total_var(x, edge_index, batch)
        deg = degree(r).unsqueeze(-1)
        conduct_1 = tv
        totalvol = scatter_add(deg.detach() * torch.ones_like(x, device=device), batch, 0) + 1e-6
        totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0) + 1e-6
        
                
        #receptive field
        recvol_hard = scatter_add(deg * mask.float(), batch, 0, dim_size=batch.max().item() + 1) + 1e-6
        reccard_hard = scatter_add(mask.float(), batch, 0, dim_size=batch.max().item() + 1) + 1e-6
        
        assert recvol_hard.mean()/totalvol.mean() <=1, "Something went wrong! Receptive field is larger than total volume."
        #generate target vol
        if tvol is None:
            feasible_vols = data.recfield_vol / data.total_vol
            target = torch.rand_like(feasible_vols, device=device) * feasible_vols * 0.85 + 0.1
            target = target.squeeze(-1) * totalvol.squeeze(-1)
        else:
            target = tvol * totalvol.squeeze(-1)
        a = torch.ones((batch.max().item() + 1, 1), device=device)
        xfilt = x
                
        
        ###############################################################################
        #iterative rescaling: find a per-graph scale `a` so the expected volume hits the target
        for _ in range(self.num_iterations):
            keep = ((a[batch] * xfilt) < 1).to(x.dtype)

            x_k, d_k, d_nk = xfilt * keep * mask, deg * keep * mask, deg * (1 - keep) * mask

            diff = target.unsqueeze(-1) - scatter_add(d_nk, batch, 0)
            dot = scatter_add(x_k * d_k, batch, 0)
            a = diff / (dot + 1e-5)
            volcur = scatter_add(torch.clamp(a[batch] * xfilt, max=1., min=0.) * deg, batch, 0)

            targetcheck = torch.abs(volcur.squeeze(-1) - target)
            check = (targetcheck <= self.elasticity * target).to(x.dtype)

            #stop once every graph in the batch is within the elasticity band
            if check.sum() >= batch.max().item() + 1:
                break

        probs = torch.clamp(a[batch] * x * mask, max=1., min=0.)
        ###############################################################################

        #collect useful numbers
        x2 = ((probs - torch.rand_like(x, device=device)) > 0).float()  # hard Bernoulli rounding of the probabilities
        vol_1 = scatter_add(probs * deg, batch, 0) + 1e-6
        card_1 = scatter_add(probs, batch, 0)
        rec_field = scatter_add(mask, batch, 0) + 1e-6
        cut_size = scatter_add(x2, batch, 0)
        tv_hard = total_var(x2, edge_index, batch)
        vol_hard = scatter_add(deg * x2, batch, 0, dim_size=batch.max().item() + 1) + 1e-6
        conduct_hard = tv_hard / vol_hard
        rec_field_ratio = cut_size / rec_field
        rec_field_volratio = vol_hard / recvol_hard
        total_vol_ratio = vol_hard / totalvol
        
        #calculate loss
        expected_cut = scatter_add(probs * deg, batch, 0) - scatter_add(probs[row] * probs[col], batch[row], 0)
        loss = expected_cut


        #return dict 
        retdict = {}
        retdict["output"] = [probs.squeeze(-1),"hist"]   #output
        #retdict["|Expected_vol - Target|"]= [targetcheck, "sequence"] #absolute distance from targetvol
        retdict["Expected_volume"] = [vol_1.mean(),"sequence"] #volume
        retdict["Expected_cardinality"] = [card_1.mean(),"sequence"]
        retdict["volume_hard"] = [vol_hard.mean(),"sequence"] #volume2
        #retdict["cut1"] = [tv.mean(),"sequence"] #cut1
        retdict["cut_hard"] = [tv_hard.mean(),"sequence"] #cut1
        retdict["Average cardinality ratio of receptive field "] = [rec_field_ratio.mean(),"sequence"] 
        retdict["Recfield volume/Total volume"] = [recvol_hard.mean()/totalvol.mean(), "sequence"]
        retdict["Average ratio of receptive field volume"]= [rec_field_volratio.mean(),'sequence']
        retdict["Average ratio of total volume"]= [total_vol_ratio.mean(),'sequence']
        retdict["mask"] = [mask, "aux"] #mask
        retdict["xinit"] = [xinit,"hist"] #layer input diracs
        retdict["xpostlin1"] = [xpostlin1.mean(1),"hist"] #after first linear layer
        retdict["xprethresh"] = [xprethresh.mean(1),"hist"] #pre thresholding activations 195 x 1
        retdict["lossvol"] = [lossvol.mean(),"sequence"] #volume constraint
        retdict["losscard"] = [losscard.mean(),"sequence"] #cardinality constraint
        retdict["loss"] = [loss.mean().squeeze(),"sequence"] #final loss

        return retdict
    
    def __repr__(self):
        return self.__class__.__name__
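This example additionally calls a `total_var` helper that is not shown on this page. A minimal, hypothetical stand-in, assuming it computes the per-graph total variation (the cut value) of a node indicator over an undirected edge list in which each edge appears in both directions, could be:

def total_var(x, edge_index, batch):
    # Hypothetical stand-in: per-graph total variation sum over edges of
    # |x_u - x_v|, halved because each undirected edge appears twice.
    row, col = edge_index
    tv = scatter_add(torch.abs(x[row] - x[col]), batch[row], dim=0,
                     dim_size=batch.max().item() + 1)
    return tv / 2

Note that with `tvol=None` the forward pass reads `data.recfield_vol` and `data.total_vol`, two dataset-specific attributes; calling, e.g., `model(data, tvol=0.5)` with an explicit volume fraction sidesteps them.
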
Example #3
class clique_MPNN(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden1, hidden2, deltas):
        super(clique_MPNN, self).__init__()
        self.hidden1 = hidden1
        self.hidden2 = hidden2
        self.momentum = 0.1
        self.convs = torch.nn.ModuleList()
        self.deltas = deltas
        self.numlayers = num_layers
        self.heads = 8
        self.concat = True

        self.bns = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.bns.append(
                BN(self.heads * self.hidden1, momentum=self.momentum))
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(
                GINConv(Sequential(
                    Linear(self.heads * self.hidden1,
                           self.heads * self.hidden1),
                    ReLU(),
                    Linear(self.heads * self.hidden1,
                           self.heads * self.hidden1),
                    ReLU(),
                    BN(self.heads * self.hidden1, momentum=self.momentum),
                ),
                        train_eps=True))
        self.bn1 = BN(self.heads * self.hidden1)
        self.conv1 = GINConv(Sequential(
            Linear(self.hidden2, self.heads * self.hidden1),
            ReLU(),
            Linear(self.heads * self.hidden1, self.heads * self.hidden1),
            ReLU(),
            BN(self.heads * self.hidden1, momentum=self.momentum),
        ),
                             train_eps=True)

        if self.concat:
            self.lin1 = Linear(self.heads * self.hidden1, self.hidden1)
        else:
            self.lin1 = Linear(self.hidden1, self.hidden1)
        self.lin2 = Linear(self.hidden1, 1)
        self.gnorm = GraphSizeNorm()

    def reset_parameters(self):
        self.conv1.reset_parameters()

        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.bn1.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]

        if edge_dropout is not None:
            edge_index = dropout_adj(
                edge_index,
                edge_attr=(torch.ones(edge_index.shape[1],
                                      device=device)).long(),
                p=edge_dropout,
                force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index,
                                                  num_nodes=batch.shape[0])[0]

        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = (reduced_num_edges / total_num_edges)
        no_loop_index, _ = remove_self_loops(edge_index)
        no_loop_row, no_loop_col = no_loop_index

        xinit = x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x, edge_index, 1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))  # +x
        x = x * mask
        x = self.gnorm(x)
        x = self.bn1(x)

        for conv, bn in zip(self.convs, self.bns):
            if (x.dim() > 1):
                x = x + F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()
        x = F.leaky_relu(self.lin1(x))
        x = x * mask

        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x))
        x = x * mask

        #calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        probs = x

        #calculating the terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device=device)
        for graph in range(num_graphs):
            batch_graph = (batch == graph)
            pairwise_prodsums[graph] = (torch.conv1d(
                probs[batch_graph].unsqueeze(-1),
                probs[batch_graph].unsqueeze(-1))).sum() / 2

        ###calculate loss terms
        self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
        expected_weight_G = scatter_add(
            probs[no_loop_row] * probs[no_loop_col],
            batch[no_loop_row],
            0,
            dim_size=num_graphs) / 2.
        expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) -
                                  self_sums) / 1.
        expected_distance = (expected_clique_weight - expected_weight_G)

        ###calculate loss
        expected_loss = (penalty_coefficient
                         ) * expected_distance * 0.5 - 0.5 * expected_weight_G

        loss = expected_loss

        retdict = {}

        retdict["output"] = [probs.squeeze(-1), "hist"]  #output
        retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
        retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [
            expected_clique_weight.mean(), "sequence"
        ]
        retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
        retdict["loss"] = [loss.mean().squeeze(), "sequence"]  #final loss

        return retdict

    def __repr__(self):
        return self.__class__.__name__
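A usage sketch mirroring Example #1, again hypothetical and assuming the same imports and helper stand-ins; this variant's constructor drops the `elasticity` and `num_iterations` arguments, which Example #1 stores but never uses:

# `data` built as in the Example #1 sketch (hidden2 must equal the feature dim).
model = clique_MPNN(dataset=None, num_layers=2, hidden1=8, hidden2=1,
                    deltas=None).to(device)
model.reset_parameters()
retdict = model(data, edge_dropout=0.1)  # randomly drop 10% of the edges
print(retdict["Expected distance"][0])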