Exemple #1
0
def scatter_logsumexp(src: torch.Tensor,
                      index: torch.Tensor,
                      dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    """Compute the log-sum-exp of ``src`` grouped by ``index`` along ``dim``.

    The per-group maximum is subtracted before exponentiation and added back
    after the logarithm, so large scores do not overflow in ``exp``.

    Args:
        src: Floating point tensor holding the values to reduce.
        index: Group id of every element of ``src``; it is passed through
            ``broadcast(index, src, dim)`` before use.
        dim: Axis along which the grouping is applied.
        out: Optional destination tensor. When given, its size along ``dim``
            defines the number of groups and it is modified in place.
        dim_size: Number of groups along ``dim`` when ``out`` is not given.
            Defaults to ``index.max() + 1``.
        eps: Constant added to the summed exponentials before the logarithm.

    Returns:
        The per-group log-sum-exp, as produced by ``scatter_sum`` and shifted
        back by the per-group maxima.

    Raises:
        ValueError: If ``src`` is not a floating point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    elif dim_size is None:
        dim_size = int(index.max()) + 1

    # ``torch.Size`` is an immutable tuple, so copy it into a list before
    # overwriting the reduced dimension.
    size = list(src.size())
    size[dim] = dim_size
    max_value_per_index = torch.full(size,
                                     float('-inf'),
                                     dtype=src.dtype,
                                     device=src.device)
    # scatter_max writes the per-group maxima into ``max_value_per_index``
    # through its ``out`` argument; the returned tuple is not needed.
    scatter_max(src, index, dim, max_value_per_index, dim_size)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element

    if out is not None:
        # ``out`` has one entry per group, so it has to be shifted by the
        # per-group maxima (same shape), not by the per-element ones.
        out = out.sub_(max_value_per_index).exp_()

    sum_per_index = scatter_sum(recentered_scores.exp_(), index, dim, out,
                                dim_size)

    return sum_per_index.add_(eps).log_().add_(max_value_per_index)
Exemple #2
0
    def message_step(self, node_network, edge_network, x, start, end, e):
        """Run one round of node and edge updates with residual connections.

        Edge features ``e`` are aggregated onto the nodes indexed by ``end``
        according to ``self.hparams["aggregation"]`` (``"sum"``, ``"max"`` or
        ``"sum_max"``); any other value falls back to sum aggregation.

        Args:
            node_network: Callable applied to ``[x, aggregated messages]``.
            edge_network: Callable applied to ``[x[start], x[end], e]``.
            x: Node features; ``x.shape[0]`` is used as the number of nodes.
            start: Index of the first endpoint of every edge.
            end: Index of the second endpoint of every edge.
            e: Edge features.

        Returns:
            Tuple ``(x_out, e_out)`` of updated node and edge features.
        """
        num_nodes = x.shape[0]
        aggregation = self.hparams["aggregation"]

        # Compute new node features. Each aggregation is evaluated exactly
        # once; the sum is the default for "sum" and unrecognised values.
        if aggregation == "max":
            edge_messages = scatter_max(e, end, dim=0, dim_size=num_nodes)[0]
        elif aggregation == "sum_max":
            edge_messages = torch.cat(
                [
                    scatter_max(e, end, dim=0, dim_size=num_nodes)[0],
                    scatter_add(e, end, dim=0, dim_size=num_nodes),
                ],
                dim=-1,
            )
        else:
            edge_messages = scatter_add(e, end, dim=0, dim_size=num_nodes)

        node_inputs = torch.cat([x, edge_messages], dim=-1)
        x_out = node_network(node_inputs)
        # Residual connection (in place on the network output).
        x_out += x

        # Compute new edge features from the *input* node features.
        edge_inputs = torch.cat([x[start], x[end], e], dim=-1)
        e_out = edge_network(edge_inputs)
        e_out += e

        return x_out, e_out
Exemple #3
0
    def message_step(self, x, start, end):
        """Gate neighbour features per edge, aggregate them and update nodes.

        An edge score in (0, 1) is produced by ``self.edge_network`` followed
        by a sigmoid, and used to weight ``x[start]`` before it is scattered
        onto the ``end`` nodes. The aggregation is selected by
        ``self.hparams["aggregation"]`` (``"sum"``, ``"max"``, ``"sum_max"``).

        Returns:
            Output of ``self.node_network`` plus the residual ``x``.
        """
        pair_features = torch.cat([x[start], x[end]], dim=1)
        gate = torch.sigmoid(self.edge_network(pair_features))

        mode = self.hparams["aggregation"]
        n_nodes = x.shape[0]

        if mode == "sum":
            weighted = gate * x[start]
            messages = scatter_add(weighted, end, dim=0, dim_size=n_nodes)
        elif mode == "max":
            weighted = gate * x[start]
            messages = scatter_max(weighted, end, dim=0, dim_size=n_nodes)[0]
        elif mode == "sum_max":
            weighted = gate * x[start]
            pooled_max = scatter_max(weighted, end, dim=0,
                                     dim_size=n_nodes)[0]
            pooled_sum = scatter_add(weighted, end, dim=0, dim_size=n_nodes)
            messages = torch.cat([pooled_max, pooled_sum], dim=-1)

        x_out = self.node_network(torch.cat([messages, x], dim=1))
        # Residual connection.
        x_out += x
        return x_out
Exemple #4
0
def get_aggregation(aggregation):
    """Return the aggregation callable registered under ``aggregation``.

    Every callable takes ``(e, end, x)`` and scatters ``e`` along dim 0 onto
    ``x.shape[0]`` slots indexed by ``end``. Supported names are ``"sum"``,
    ``"mean"``, ``"max"`` and the concatenated variants ``"sum_max"``,
    ``"mean_sum"`` and ``"mean_max"``. Unknown names raise ``KeyError``.
    """

    def _sum(e, end, x):
        return scatter_add(e, end, dim=0, dim_size=x.shape[0])

    def _mean(e, end, x):
        return scatter_mean(e, end, dim=0, dim_size=x.shape[0])

    def _max(e, end, x):
        # scatter_max returns (values, argmax); only the values are used.
        return scatter_max(e, end, dim=0, dim_size=x.shape[0])[0]

    def _concat(first, second):
        # Concatenate two aggregations along the feature axis, first one
        # leading.
        def combined(e, end, x):
            return torch.cat([first(e, end, x), second(e, end, x)], dim=-1)

        return combined

    registry = {
        "sum": _sum,
        "mean": _mean,
        "max": _max,
        "sum_max": _concat(_max, _sum),
        "mean_sum": _concat(_mean, _sum),
        "mean_max": _concat(_max, _mean),
    }

    return registry[aggregation]
Exemple #5
0
    def forward(self, x, edge_index, edge_attr, u, batch):
        """Update node features from aggregated neighbour/edge messages.

        Shape notes carried over from the original author:
            x: [N, F_x], where N is the number of nodes.
            edge_index: [2, E] with max entry N - 1.
            edge_attr: [E, F_e]
            u: [B, F_u]
            batch: [N] with max entry B - 1.

        Raises:
            ValueError: If ``self.aggregation`` is not one of ``"mean"``,
                ``"min"``, ``"max"`` or ``"minmax"``.
        """
        row, col = edge_index
        num_nodes = x.size(0)

        # Per-edge message built from the ``col`` node and the edge attribute.
        messages = self.node_mlp_1(torch.cat([x[col], edge_attr], dim=1))

        mode = self.aggregation
        if mode == "mean":
            pooled = scatter_mean(messages, row, dim=0, dim_size=num_nodes)
        elif mode == "min":
            pooled, _ = scatter_min(messages, row, dim=0, dim_size=num_nodes)
        elif mode == "max":
            pooled, _ = scatter_max(messages, row, dim=0, dim_size=num_nodes)
        elif mode == "minmax":
            lowest = scatter_min(messages, row, dim=0, dim_size=num_nodes)[0]
            highest = scatter_max(messages, row, dim=0, dim_size=num_nodes)[0]
            pooled = torch.cat([lowest, highest], dim=1)
        else:
            raise ValueError("Unknown aggregation type: {}".format(self.aggregation))

        # Combine node features, pooled messages and per-graph globals.
        return self.node_mlp_2(torch.cat([x, pooled, u[batch]], dim=1))
Exemple #6
0
    def forward(self, x, x_privileged, action_masks):
        """Compute action probabilities and state values.

        ``x_privileged`` is accepted but not used by this implementation.

        Returns:
            Tuple ``(probs, values)``: the padded per-agent action
            probabilities and the flattened value-head output.
        """
        batch_size = x.size()[0]
        x, active_agents, (pitems, pmask) = self.latents(x, action_masks)

        # Zero-initialised buffer that scatter_max pools agent latents into.
        pooled_width = self.d_agent * self.hps.dff_ratio
        if x.is_cuda:
            vin = torch.cuda.FloatTensor(batch_size, pooled_width).fill_(0)
        else:
            vin = torch.zeros(batch_size, pooled_width)
        scatter_max(x, index=active_agents.batch_index, dim=0, out=vin)

        if self.hps.use_privileged:
            # Push masked items down by 1000 so they do not win the max.
            penalty = 1000.0 * pmask.float().unsqueeze(-1)
            pitems_max = (pitems - penalty).max(dim=1).values
            pitems_max[pitems_max == -1000.0] = 0.0
            # Average over unmasked items, guarding against division by zero.
            valid_count = torch.clamp_min((~pmask).float().sum(dim=1), min=1)
            pitems_avg = pitems.sum(dim=1) / valid_count.unsqueeze(-1)
            vin = torch.cat([vin, pitems_max, pitems_avg], dim=1)
        values = self.value_head(vin).view(-1)

        logits = self.policy_head(x)
        flat_masks = action_masks.reshape(-1, self.naction)
        disallowed = flat_masks[active_agents.flat_index] == 0
        logits = logits.masked_fill(disallowed, float('-inf'))
        probs = active_agents.pad(F.softmax(logits, dim=1))
        return probs, values
Exemple #7
0
    def forward(self, x, edge_index):
        """Encode nodes, run max-aggregation message passing, score edges.

        For ``self.hparams["n_graph_iters"]`` iterations every node receives
        the element-wise maximum over the features of its neighbours in both
        edge directions, which is concatenated with the node's own features,
        passed through ``self.node_network`` and added back residually.

        Returns:
            Output of ``self.edge_network`` on ``[x[start], x[end]]``.
        """
        x = self.node_encoder(x)
        start, end = edge_index

        for _ in range(self.hparams["n_graph_iters"]):
            residual = x
            n_nodes = x.shape[0]

            incoming = scatter_max(x[start], end, dim=0, dim_size=n_nodes)[0]
            outgoing = scatter_max(x[end], start, dim=0, dim_size=n_nodes)[0]

            # Element-wise maximum over the two directions.
            both_directions = torch.stack((incoming, outgoing), dim=0)
            messages = torch.max(both_directions, 0)[0]

            x = self.node_network(torch.cat([x, messages], dim=-1)) + residual

        return self.edge_network(torch.cat([x[start], x[end]], dim=1))
Exemple #8
0
    def forward(self, x, x_privileged, action_masks):
        """Compute per-agent action probabilities and state values.

        Returns:
            Tuple ``(probs, values)`` where ``probs`` is viewed as
            ``(batch_size, self.agents, self.naction)`` and ``values`` is the
            flattened value-head output.
        """
        batch_size = x.size()[0]
        x, indices, groups, (pitems, pmask) = self.latents(
            x, x_privileged, action_masks)

        # Zero-initialised buffer that scatter_max pools the latents into.
        pooled_width = self.d_agent * self.hps.dff_ratio
        if x.is_cuda:
            vin = torch.cuda.FloatTensor(batch_size, pooled_width).fill_(0)
        else:
            vin = torch.zeros(batch_size, pooled_width)
        scatter_max(x, index=groups, dim=0, out=vin)

        if self.hps.use_privileged:
            # Push masked items down by 1000 so they do not win the max.
            penalty = 1000.0 * pmask.float().unsqueeze(-1)
            pitems_max = (pitems - penalty).max(dim=1).values
            pitems_max[pitems_max == -1000.0] = 0.0
            # Average over unmasked items, guarding against division by zero.
            valid_count = torch.clamp_min((~pmask).float().sum(dim=1), min=1)
            pitems_avg = pitems.sum(dim=1) / valid_count.unsqueeze(-1)
            vin = torch.cat([vin, pitems_max, pitems_avg], dim=1)
        values = self.value_head(vin).view(-1)

        logits = self.policy_head(x)

        # Scatter the logits into a dense (batch * agents, naction) buffer;
        # rows that receive nothing stay at zero.
        n_slots = batch_size * self.agents
        if x.is_cuda:
            padded_logits = torch.cuda.FloatTensor(n_slots,
                                                   self.naction).fill_(0)
        else:
            padded_logits = torch.zeros(n_slots, self.naction)
        scatter_add(logits, index=indices, dim=0, out=padded_logits)

        per_agent = padded_logits.view(batch_size, self.agents, self.naction)
        return F.softmax(per_agent, dim=2), values
Exemple #9
0
    def pathat_layer(self, input, pathM, pathlens, eluF=True):
        """Aggregate path features into nodes with attention.

        For every requested path length the features of the path's nodes
        (columns ``1..`` of the ``'values'`` tensor) are max-pooled, scored
        against the features of the node given by row 0 of ``'indices'``,
        normalised with a per-node softmax and summed per node. When two path
        lengths are processed, their per-node results are combined by a
        second per-node attention (``self.lenAtt``).

        Args:
            input: 2-D node feature matrix; ``input.size()[0]`` is N.
            pathM: Mapping from path length to a dict with the tensors
                ``'indices'`` and ``'values'``.
            pathlens: Path lengths to process. Forced to ``[2]`` when
                ``self.concat`` is false (last layer).
            eluF: Apply ELU to the result when ``self.concat`` is also true.

        Returns:
            Tensor with N rows holding the path-based node embedding.
        """
        N = input.size()[0]
        pathh = torch.mm(input, self.pathW)
        pathh = pathh + self.pathbias                # h: N x out
        # Same float32 epsilon as before, created on the features' device
        # instead of unconditionally on the current CUDA device.
        eps = torch.Tensor([9e-15]).to(pathh.device)

        if not self.concat:  # if the last layer
            pathlens = [2]

        pathfeat_all = None
        for pathlen_iter in pathlens:
            idx = pathM[pathlen_iter]['indices']
            v = pathM[pathlen_iter]['values']
            target = idx[0, :]
            featlen = pathh.shape[1]
            pathlen = v.shape[1]
            pathfeat = tuple(pathh[v[:, k], :] for k in range(1, pathlen))
            pathfeat = torch.cat(pathfeat, dim=1)
            pathfeat = pathfeat.view(-1, pathlen - 1, featlen)
            pathfeat, _ = torch.max(pathfeat, dim=1)    # seems max is better?
            att_feat = torch.cat((pathfeat, pathh[target, :]), dim=1).t()
            if pathlen_iter == 2:
                path_att = self.leakyrelu(self.patha_2.mm(att_feat).squeeze())
            else:
                path_att = self.leakyrelu(self.patha_3.mm(att_feat).squeeze())
            # softmax of p_a -> p_a_e (shifted by the per-node max)
            path_att = path_att - scatter_max(path_att, target, dim=0, dim_size=N)[0][target]
            path_att = path_att.exp()
            path_att = path_att / (scatter_add(path_att, target, dim=0, dim_size=N)[target] + eps)
            path_att = path_att.view(-1, 1)
            path_att = self.dropout(path_att)         # add dropout here of p_a_e
            w_pathfeat = torch.mul(pathfeat, path_att)
            # dim_size=N keeps one row per node even when the highest-numbered
            # nodes have no path; without it the output would be truncated to
            # ``target.max() + 1`` rows.
            h_path_prime = scatter_add(w_pathfeat, target, dim=0, dim_size=N)
            # h_path_prime is the feature embedded from paths  N*feat

            if pathfeat_all is None:
                pathfeat_all = h_path_prime
            else:
                pathfeat_all = torch.cat((pathfeat_all, h_path_prime), dim=0)

        if len(pathlens) == 2:
            # Node id of every row of pathfeat_all: 0..N-1 followed by 0..N-1.
            leni = torch.arange(N, device=pathh.device).repeat(2)

            att_feat = torch.cat((pathfeat_all, pathh[leni, :]), dim=1).t()
            path_att = self.leakyrelu(self.lenAtt.mm(att_feat).squeeze())
            # softmax of p_a -> p_a_e
            path_att = path_att - scatter_max(path_att, leni, dim=0, dim_size=N)[0][leni]
            path_att = path_att.exp()
            path_att = path_att / (scatter_add(path_att, leni, dim=0, dim_size=N)[leni] + eps)
            path_att = path_att.view(-1, 1)
            w_pathfeat = torch.mul(pathfeat_all, path_att)
            h_path_prime = scatter_add(w_pathfeat, leni, dim=0, dim_size=N)

        if self.concat and eluF:
            return F.elu(h_path_prime)
        else:
            return h_path_prime
Exemple #10
0
    def multi_hop(self, triple_prob, distance, head, tail, concept_label, triple_label, gamma=0.8, iteration=3,
                       method="avg"):
        """Propagate scores from source concepts along triples for several hops.

        Shapes, as stated by the original author:
            triple_prob: bsz x L x mem_t
            distance: bsz x mem
            head, tail: bsz x mem_t
            concept_label: bsz x mem
            triple_label: bsz x mem_t

        Concepts with ``distance == 0`` start with score 1 and all others
        with 0 (expanded to bsz x L x mem). In every iteration the score of
        each triple's head is gathered, scaled by ``gamma``, added to
        ``triple_prob`` and scattered onto the triple's tail with
        ``scatter_max`` (``method="max"``) or ``scatter_mean``
        (``method="avg"``). Any other ``method`` leaves that step's scores at
        zero. Entries whose label is ``-1`` are zeroed.

        Returns:
            bsz x L x mem tensor: the sum of the per-hop scores (the initial
            mask excluded) with ``-1e5`` added at the source concepts.
        """
        concept_probs = []

        cpt_size = (triple_prob.size(0), triple_prob.size(1), distance.size(1))
        init_mask = torch.zeros_like(distance).unsqueeze(1).expand(*cpt_size).to(distance.device).float()
        init_mask.masked_fill_((distance == 0).unsqueeze(1), 1)
        # Keep a copy that still marks padded source concepts.
        final_mask = init_mask.clone()

        init_mask.masked_fill_((concept_label == -1).unsqueeze(1), 0)
        concept_probs.append(init_mask)

        head = head.unsqueeze(1).expand(triple_prob.size(0), triple_prob.size(1), -1)
        tail = tail.unsqueeze(1).expand(triple_prob.size(0), triple_prob.size(1), -1)

        for step in range(iteration):
            # Calculate triple head score
            node_score = concept_probs[-1]
            triple_head_score = node_score.gather(2, head)
            triple_head_score.masked_fill_((triple_label == -1).unsqueeze(1), 0)
            # Method:
            #     - avg:
            #         s(v) = Avg_{u in N(v)} gamma * s(u) + R(u->v)
            #     - max:
            #         s(v) = max_{u in N(v)} gamma * s(u) + R(u->v)
            update_value = triple_head_score * gamma + triple_prob
            out = torch.zeros_like(node_score).to(node_score.device).float()
            if method == "max":
                scatter_max(update_value, tail, dim=-1, out=out)
            elif method == "avg":
                scatter_mean(update_value, tail, dim=-1, out=out)
            out.masked_fill_((concept_label == -1).unsqueeze(1), 0)

            concept_probs.append(out)

        # Natural decay of concept that is multi-hop away from source:
        # source concepts are pushed down by -1e5, hop scores are summed.
        total_concept_prob = final_mask * -1e5
        for prob in concept_probs[1:]:
            total_concept_prob += prob
        # bsz x L x mem

        return total_concept_prob
    def aggregate_multi(self, ind_lst, pos_lst, row_rel, query_rel):
        """Attention-weighted aggregation of row relations.

        Args:
            ind_lst (list) -- List of indices of entity (tuple) to which each row of row_rel belongs.
            pos_lst (list) -- List of indices of batch to which each row of row_rel targets.
            row_rel (Tensor) -- (n_rel x emb_dim)
            query_rel (Tensor) -- (n_batch x emb_dim)
        Return:
            out (Tensor) -- (n_batch x emb_dim)

        ``self.args.aggregation`` selects plain dot-product attention
        (``"attention"``) or the variant scaled by ``sqrt(emb_dim)``
        (``"scaled-attention"``).
        """
        device = row_rel.device
        mode = self.args.aggregation

        ind_lst = torch.LongTensor(ind_lst).to(device)
        pos_lst = torch.LongTensor(pos_lst).to(device)

        if mode in ("attention", "scaled-attention"):
            # Scores between each row relation and its query relation.
            scores = torch.sum(row_rel * query_rel[pos_lst], dim=1,
                               keepdim=True)
            if mode == "scaled-attention":
                scores = scores / row_rel.size(1)**0.5  # (n_row x 1)

            # Per-entity maximum, used to avoid overflow in exp. The scores
            # are shifted by their global minimum before scatter_max and the
            # shift is undone afterwards.
            shift = torch.min(scores).item()
            group_max, _ = scatter_max(scores - shift, ind_lst, dim=0)
            group_max = (group_max + shift)[ind_lst]  # (n_row x 1)

            # Softmax weight of each row within its entity.
            unnormalised = torch.exp(scores - group_max)
            normaliser = scatter_add(unnormalised, ind_lst, dim=0)[ind_lst]
            weights = unnormalised / normaliser  # (n_row x 1)

            # Weighted sum of the row relations per entity.
            out = scatter_add(weights * row_rel, ind_lst, dim=0)

        return out
    def run_inner_loop(self, inputs):
        """Run one bidirectional, edge-gated message passing step.

        ``inputs`` is unpacked as ``(x, start, end)``. An edge gate is
        computed by ``self.edge_network`` plus sigmoid; gated features are
        aggregated in both edge directions according to
        ``self.hparams["aggregation"]`` (``"sum"``, ``"max"``, ``"sum_max"``)
        and passed, together with ``x``, through ``self.node_network``.
        """
        x, start, end = inputs

        # Apply edge network
        gate = torch.sigmoid(
            self.edge_network(torch.cat([x[start], x[end]], dim=1)))

        mode = self.hparams["aggregation"]
        n_nodes = x.shape[0]

        def pooled_sum(values, target):
            return scatter_add(values, target, dim=0, dim_size=n_nodes)

        def pooled_max(values, target):
            return scatter_max(values, target, dim=0, dim_size=n_nodes)[0]

        if mode == "sum":
            messages = torch.cat(
                [pooled_sum(gate * x[start], end),
                 pooled_sum(gate * x[end], start)],
                dim=-1,
            )
        elif mode == "max":
            messages = torch.cat(
                [pooled_max(gate * x[start], end),
                 pooled_max(gate * x[end], start)],
                dim=-1,
            )
        elif mode == "sum_max":
            messages = torch.cat(
                [
                    pooled_max(gate * x[start], end),
                    pooled_sum(gate * x[start], end),
                    pooled_max(gate * x[end], start),
                    pooled_sum(gate * x[end], start),
                ],
                dim=-1,
            )

        return self.node_network(torch.cat([messages, x], dim=1))
Exemple #13
0
 def _add_to_memory(self, key, payload, batch_ids=None):
     """Append new items to the memory buffers and update the bound.

     Args:
         key: Tensor of keys; ``key.size(0)`` items are added.
         payload: Iterable of values, one per buffer in ``self._payload``.
         batch_ids: Batch entry of every new item; used only when
             ``self.batch_size > 1``.

     Side effects:
         Writes into ``self._key``, the ``self._payload`` buffers and (for
         batched use) ``self._batch_ids``; advances ``self._num_items``;
         may update ``self._batch_num_items_lb`` and ``self.bound``.
     """
     num_add = key.size(0)
     # Store the new items right after the ones already in memory.
     self._key[self._num_items:self._num_items + num_add] = key
     # NOTE: the loop variable reuses the name ``payload``; zip() has already
     # captured the argument, so the iteration itself is unaffected.
     for (payload, add_pl) in zip(self._payload, payload):
         payload[self._num_items:self._num_items + num_add] = add_pl
     if self.batch_size > 1:
         self._batch_ids[self._num_items:self._num_items +
                         num_add] = batch_ids
     self._num_items += num_add
     if self.batch_size > 1 and (self._batch_num_items_lb <
                                 self.capacity).any():
         # Add number of items per batch entry, and see if for some entries we exceed the capacity
         # for the first time so we can define a bound
         new_batch_num_items = self._batch_num_items_lb + bincount(
             batch_ids, minlength=self.batch_size)
         # True only for entries that cross the capacity with this call.
         new_bound = (new_batch_num_items >= self.capacity) & (
             self._batch_num_items_lb < self.capacity)
         self._batch_num_items_lb = new_batch_num_items
         if new_bound.any():
             # Note: this is not a very tight bound, we can get a tighter bound by taking the max of the first
             # 'capacity' values only per instance, instead of all values (as in the without batch case)
             # Better is to take the k-th largest value, which is not done here (to save computation)
             # but when the queue is actually reduced
             self.bound = torch.where(
                 new_bound,
                 scatter_max(self._key[:self._num_items],
                             self._batch_ids[:self._num_items].long(),
                             dim_size=self.batch_size)[0], self.bound)
     elif self._num_items >= self.capacity and self.bound is None and self.batch_size == 1:
         # As soon as we have more than capacity items for the first time, we can define a bound
         # using the first 'capacity' items (gives slightly better bound than using _num_items items)
         self.bound = self._key[:self.capacity].max().cpu()
Exemple #14
0
    def gat_layer(self, input, adj, genPath=False, eluF=True):
        """Sparse graph-attention layer over the edges of ``adj``.

        Args:
            input: 2-D node feature matrix; ``input.size()[0]`` is N.
            adj: Sparse adjacency tensor; only ``adj._indices()`` is read.
            genPath: Falsy to skip the export. Otherwise it is used as the
                directory into which the per-source-node softmax attention
                matrix is saved as ``attmat_<layerN>.npz``.
            eluF: Apply ELU to the result when ``self.concat`` is also true.

        Returns:
            Attention-weighted aggregation of the projected features,
            normalised by the per-row sum of the attention weights.

        NOTE(review): several tensors are created with ``.cuda()``, so this
        method appears to require a CUDA device - confirm before CPU use.
        """
        N = input.size()[0]
        edge = adj._indices()
        h = torch.mm(input, self.W)
        h = h+self.bias                # h: N x out

        # Self-attention on the nodes - Shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()     # edge_h: 2*D x E
        edge_att = self.a.mm(edge_h).squeeze()
        edge_e_a = self.leakyrelu(edge_att)     # edge_e_a: E   attention score for each edge
        if genPath:
            # Export only: a softmax over the edges sharing the same edge[0]
            # node, computed without tracking gradients.
            with torch.no_grad():
                edge_weight = edge_e_a
                p_a_e = edge_weight - scatter_max(edge_weight, edge[0,:], dim=0, dim_size=N)[0][edge[0,:]]
                p_a_e = p_a_e.exp()
                p_a_e = p_a_e / (scatter_add(p_a_e, edge[0,:], dim=0, dim_size=N)[edge[0,:]]\
                                    +torch.Tensor([9e-15]).cuda())
                
                scisp = convert.to_scipy_sparse_matrix(edge, p_a_e, N)
                scipy.sparse.save_npz(os.path.join(genPath, 'attmat_{:s}.npz'.format(self.layerN)), scisp)

        # Shift by the global maximum before exp for numerical stability.
        edge_e = torch.exp(edge_e_a - torch.max(edge_e_a))                  # edge_e: E
        # The normaliser is computed from the weights *before* dropout.
        e_rowsum = spmm(edge, edge_e, N, torch.ones(size=(N,1)).cuda())     # e_rowsum: N x 1
        edge_e = self.dropout(edge_e)       # add dropout improve from 82.4 to 83.8
        # edge_e: E
        
        h_prime = spmm(edge, edge_e, N, h)
        # The small constant guards against division by zero.
        h_prime = h_prime.div(e_rowsum+torch.Tensor([9e-15]).cuda())        # h_prime: N x out
        
        if self.concat and eluF:
            return F.elu(h_prime)
        else:
            return h_prime
    def forward(self, x_scores, x_relations, berts, edge_indices,
                softmax_edge_indices, n_program, max_y_score_len):
        """Score programs from gathered features and embed the relations.

        Features from ``self.getFeature`` are expanded along
        ``edge_indices[0]``, gathered onto ``edge_indices[1]``, processed,
        concatenated with ``berts`` and scored. Scores are max-scattered into
        ``n_program * max_y_score_len`` slots; slots that stay at 0 are set
        to ``-inf`` before reshaping.

        Returns:
            Tuple ``(score, relation)`` where ``score`` has shape
            ``[n_program, max_y_score_len]``.
        """
        feature = self.getFeature(x_scores)
        if self.reduce == 'max':
            # Append one all-zero row to the feature matrix.
            zero_row = torch.zeros(size=[1, feature.shape[1]]).to(
                feature.device)
            feature = torch.cat([feature, zero_row], dim=0)

        targets = edge_indices[1]
        gathered_feature, _ = self.gatherFeature(
            feature[edge_indices[0]],
            targets,
            dim_size=torch.max(targets).long() + 1,
            dim=0)

        processed_feature = torch.cat(
            [self.processFeature(gathered_feature), berts], dim=-1)
        score = self.getScore(processed_feature)
        # An exact 0 would be indistinguishable from an empty slot below.
        assert not (score == 0).any()
        score, _ = torch_scatter.scatter_max(
            score,
            softmax_edge_indices[1],
            dim_size=n_program * max_y_score_len,
            dim=0)  #! to update here
        minus_inf = torch.tensor(-float('inf')).to(score.device)
        score = torch.where(score != 0, score, minus_inf)
        score = torch.reshape(score, [n_program, max_y_score_len])

        relation = self.getRelation(x_relations)
        return score, relation
Exemple #16
0
 def softmax(self, x, index, num=None):
     """Softmax of ``x`` over the groups given by ``index`` (along dim 0).

     The per-group maximum is subtracted before exponentiation and 1e-16 is
     added to the denominator. ``num`` is forwarded as ``dim_size``.
     """
     group_max = torch_scatter.scatter_max(x, index, dim=0,
                                           dim_size=num)[0]
     exp_shifted = (x - group_max[index]).exp()
     group_sum = torch_scatter.scatter_add(exp_shifted, index, dim=0,
                                           dim_size=num)
     return exp_shifted / (group_sum[index] + 1e-16)
Exemple #17
0
    def forward(self, pt_fea, xy_ind):
        """Max-pool processed point features per unique (batch, grid) index.

        Args:
            pt_fea: Sequence of per-sample point feature tensors.
            xy_ind: Sequence of per-sample grid index tensors.

        Returns:
            Tuple ``(unq, pooled)``: the unique batch-prefixed grid indices
            (int64) and the pooled, optionally compressed, features.
        """
        cur_dev = pt_fea[0].get_device()

        # Prefix every grid index with the index of the sample it came from.
        indexed = [
            F.pad(sample_ind, (1, 0), 'constant', value=sample_id)
            for sample_id, sample_ind in enumerate(xy_ind)
        ]

        cat_pt_fea = torch.cat(pt_fea, dim=0)
        cat_pt_ind = torch.cat(indexed, dim=0)

        # Shuffle points and indices with the same random permutation.
        order = torch.randperm(cat_pt_ind.shape[0], device=cur_dev)
        cat_pt_fea = cat_pt_fea[order, :]
        cat_pt_ind = cat_pt_ind[order, :]

        # Unique grid cells and, per point, the cell it belongs to.
        unq, unq_inv, _ = torch.unique(cat_pt_ind, return_inverse=True, return_counts=True, dim=0)
        unq = unq.type(torch.int64)

        processed = self.PPmodel(cat_pt_fea)
        pooled_data = torch_scatter.scatter_max(processed, unq_inv, dim=0)[0]

        if self.fea_compre:
            return unq, self.fea_compression(pooled_data)
        return unq, pooled_data
def softmax(src, index, num_nodes=None):
    r"""Sparse softmax of the values in :attr:`src`, grouped along the first
    dimension by the ids in :attr:`index`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    if num_nodes is None:
        num_nodes = index.max() + 1

    # Subtract the per-group maximum before exponentiating.
    group_max = scatter_max(src, index, dim=0, dim_size=num_nodes)[0]
    exp_shifted = (src - group_max[index]).exp()

    # 1e-16 keeps the denominator away from zero.
    group_sum = scatter_add(exp_shifted, index, dim=0, dim_size=num_nodes)
    return exp_shifted / (group_sum[index] + 1e-16)
Exemple #19
0
    def forward(self, x, edge_index, edge_weight=None, batch=None):
        """Attention-based cluster formation followed by top-k selection.

        Args:
            x: Node features; a 1-D tensor is treated as a single feature.
            edge_index: Edge indices; row 0 is used as the aggregation target
                and row 1 as the message source.
            edge_weight: Optional edge weights.
            batch: Graph id per node; defaults to all zeros (a single graph).

        Returns:
            Tuple ``(x, edge_index, edge_weight, batch, perm)`` for the
            selected nodes, where ``perm`` is the result of ``topk``.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # NxF
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        # Add Self Loops
        fill_value = 1
        # Number of nodes per graph in the batch.
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        edge_index, edge_weight = add_remaining_self_loops(edge_index=edge_index, edge_weight=edge_weight,
            fill_value=fill_value, num_nodes=num_nodes.sum())

        N = x.size(0) # total num of nodes in batch

        # ExF
        x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index, edge_weight=edge_weight)
        x_pool_j = x_pool[edge_index[1]]
        x_j = x[edge_index[1]]

        #---Master query formation---
        # NxF
        X_q, _ = scatter_max(x_pool_j, edge_index[0], dim=0)
        # NxF
        M_q = self.lin_q(X_q)
        # ExF -- index with the tensor directly; converting it to a Python
        # list first yields the same rows but forces a host round-trip.
        M_q = M_q[edge_index[0]]

        score = self.gat_att(torch.cat((M_q, x_pool_j), dim=-1))
        score = F.leaky_relu(score, self.negative_slope)
        score = softmax(score, edge_index[0], num_nodes=num_nodes.sum())

        # Sample attention coefficients stochastically.
        score = F.dropout(score, p=self.dropout_att, training=self.training)
        # ExF
        v_j = x_j * score.view(-1, 1)
        #---Aggregation---
        # NxF
        out = scatter_add(v_j, edge_index[0], dim=0)

        #---Cluster Selection
        # Nx1
        fitness = torch.sigmoid(self.gnn_score(x=out, edge_index=edge_index)).view(-1)
        perm = topk(x=fitness, ratio=self.ratio, batch=batch)
        x = out[perm] * fitness[perm].view(-1, 1)

        #---Maintaining Graph Connectivity
        batch = batch[perm]
        edge_index, edge_weight = graph_connectivity(
            device = x.device,
            perm=perm,
            edge_index=edge_index,
            edge_weight=edge_weight,
            score=score,
            ratio=self.ratio,
            batch=batch,
            N=N)

        return x, edge_index, edge_weight, batch, perm
Exemple #20
0
    def forward(self, x, pos, batch):
        """Sample cluster centres and max-pool neighbour features onto them.

        Centres are chosen with ``fps``; for each centre ``knn`` returns its
        ``self.k`` nearest points, whose MLP-transformed features are
        max-pooled onto the centre.

        Returns:
            Tuple ``(out, sub_pos, sub_batch)``: pooled features, positions
            and batch ids of the sampled centres (``sub_batch`` is ``None``
            when ``batch`` is ``None``).
        """
        # FPS sampling
        id_clusters = fps(pos, ratio=self.ratio, batch=batch)
        sub_batch = None if batch is None else batch[id_clusters]
        sub_pos = pos[id_clusters]

        # k nearest points of every cluster centre (beware of self loop)
        id_k_neighbor = knn(pos,
                            sub_pos,
                            k=self.k,
                            batch_x=batch,
                            batch_y=sub_batch)
        cluster_idx, point_idx = id_k_neighbor[0], id_k_neighbor[1]

        # transformation of features through a simple MLP
        x = self.mlp(x)

        # Max pool onto each cluster the features from knn in points
        out, _ = scatter_max(x[point_idx],
                             cluster_idx,
                             dim_size=id_clusters.size(0),
                             dim=0)

        return out, sub_pos, sub_batch
Exemple #21
0
    def forward(self, graph_data, correct_candidate_idxs):
        """Return the mean negative log-probability of the correct candidates.

        Candidate nodes are scored against the representation of their slot
        and normalised per slot with ``scatter_log_softmax``. As a side
        effect, the running accuracy counters are updated without tracking
        gradients.

        The code assumes that there is one slot per graph, which is true for
        the original data.
        """
        gnn_output: GnnOutput = self._gnn(**graph_data)
        node_reprs = gnn_output.output_node_representations

        # [num_candidate_nodes, H_out]
        candidate_node_representations = node_reprs[
            gnn_output.node_idx_references["candidate_nodes"]]
        # [num_candidate_nodes]
        candidate_nodes_slot_idx = gnn_output.node_graph_idx_reference[
            "candidate_nodes"]

        # [num_slot_nodes, H_out]
        slot_representations = node_reprs[
            gnn_output.node_idx_references["slot_node_idx"]]
        # [num_candidate_nodes, H]
        slot_representations_per_candidate = slot_representations[
            candidate_nodes_slot_idx]

        scorer_input = torch.cat((candidate_node_representations,
                                  slot_representations_per_candidate),
                                 dim=-1)
        candidate_scores = self.__candidate_scores(scorer_input).squeeze(-1)
        candidate_nodes_logprobs = scatter_log_softmax(
            src=candidate_scores, index=candidate_nodes_slot_idx, eps=0)

        with torch.no_grad():
            # Arg-max candidate per slot, compared with the expected one.
            best_candidate = scatter_max(candidate_scores,
                                         index=candidate_nodes_slot_idx)[1]
            hits = (best_candidate == correct_candidate_idxs).sum()
            self.__sum_acc += int(hits)
            self.__num_samples += int(slot_representations.shape[0])

        return -candidate_nodes_logprobs[correct_candidate_idxs].mean()
    def decode(self, y, ext_x, prev_states, prev_context, encoder_features,
               encoder_mask):
        """Run one decoder step and return ``(logit, states, context)``.

        ``y`` is the previous token batch ([b]). With ``config.use_pointer``
        the vocabulary logits are extended by the out-of-vocabulary slots of
        ``ext_x`` and combined with the max-scattered attention energies.
        """
        # forward one step lstm
        embedded = self.embedding(y.unsqueeze(1))
        lstm_inputs = self.reduce_layer(torch.cat([embedded, prev_context], 2))
        output, states = self.lstm(lstm_inputs, prev_states)

        context, energy = self.attention(output, encoder_features,
                                         encoder_mask)
        merged = torch.cat((output, context), 2).squeeze(1)
        logit = self.logit_layer(torch.tanh(self.concat_layer(merged)))  # [b, |V|]

        if not config.use_pointer:
            return logit, states, context

        batch_size = y.size(0)
        num_oov = max(torch.max(ext_x - self.vocab_size + 1), 0)
        oov_padding = torch.zeros((batch_size, num_oov), device=config.device)
        extended_logit = torch.cat([logit, oov_padding], dim=1)

        # Max-scatter the attention energies onto the extended vocabulary;
        # slots that received nothing (still -INF) are reset to 0.
        copy_scores = torch.zeros_like(extended_logit) - INF
        copy_scores, _ = scatter_max(energy, ext_x, out=copy_scores)
        copy_scores = copy_scores.masked_fill(copy_scores == -INF, 0)

        logit = extended_logit + copy_scores
        logit = logit.masked_fill(logit == -INF, 0)
        # forcing UNK prob 0
        logit[:, UNK_ID] = -INF

        return logit, states, context
Exemple #23
0
    def softmax_weights(self):
        """Normalise the outgoing edge weights of every node with a softmax.

        The per-sender maximum is subtracted first (shift invariance of the
        softmax) for numerical stability. A warning is logged if the result
        contains NaN.

        Returns:
            Graph: new Graph
        """
        senders = self.senders
        n_node = self.n_node

        per_node_max, _ = scatter_max(src=self.edges,
                                      index=senders,
                                      dim=0,
                                      dim_size=n_node)
        exp_weights = (self.edges - per_node_max[senders]).exp()

        normalizer = scatter_add(src=exp_weights,
                                 index=senders,
                                 dim=0,
                                 dim_size=n_node)
        normalized_weights = exp_weights / normalizer[senders]

        if any_nan(normalized_weights):
            logging.warning(
                "NaN weight after normalization in graph `softmax_weights`")

        return self.update(edges=normalized_weights)
Exemple #24
0
    def max_edge_weight_per_node(self) -> torch.Tensor:
        """Find the heaviest outgoing edge of every node.

        Returns:
            The result of ``scatter_max`` over the squeezed edge weights
            grouped by sender, i.e. the ``(weights, indices)`` pair.
        """
        flat_weights = self.edges.squeeze()
        return scatter_max(flat_weights, self.senders)
Exemple #25
0
def real_softmax(src, index, num_nodes=None):
    r"""Softmax evaluated separately inside each index group.

    Entries of :attr:`src` are grouped along the first dimension according
    to :attr:`index`; every group is shifted by its own maximum,
    exponentiated, and divided by the group's sum of exponentials (plus
    :obj:`1e-16` to keep the denominator non-zero).

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): Group index of each element of :attr:`src`.
        num_nodes (int, optional): The number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Per-group maximum, gathered back so it lines up with `src`.
    group_max = scatter_max(src, index, dim=0, dim_size=num_nodes)[0]
    exp_shifted = (src - group_max[index]).exp()
    assert not nan_or_inf(exp_shifted)

    group_sum = scatter_add(exp_shifted, index, dim=0, dim_size=num_nodes)
    probs = exp_shifted / (group_sum[index] + 1e-16)
    assert not nan_or_inf(probs)

    return probs
Exemple #26
0
    def __call__(self, data, norm=True):
        """Attach five degree statistics per node to ``data.x``.

        For every node the stacked columns are: its own degree, then the
        min, max, mean and std of the degrees gathered through
        ``data.edge_index`` and reduced per first-row index.

        Args:
            data: Object exposing ``edge_index``, ``num_nodes`` and ``x``.
            norm: When true, degrees are divided by the largest degree
                before the statistics are computed.

        Returns:
            The same ``data`` object. ``data.x`` is replaced by the
            statistics, or has them concatenated on the last dimension when
            it already held features.
        """
        row_idx, col_idx = data.edge_index
        num_nodes = data.num_nodes

        node_deg = degree(row_idx, num_nodes, dtype=torch.float)
        if norm:
            node_deg = node_deg / node_deg.max()
        gathered_deg = node_deg[col_idx]

        # The +-10000 resets presumably clear sentinel fill values left in
        # empty groups by the scatter ops -- TODO confirm against the
        # torch_scatter version in use.
        lowest = scatter_min(gathered_deg, row_idx, dim_size=num_nodes)[0]
        lowest[lowest > 10000] = 0
        highest = scatter_max(gathered_deg, row_idx, dim_size=num_nodes)[0]
        highest[highest < -10000] = 0
        average = scatter_mean(gathered_deg, row_idx, dim_size=num_nodes)
        spread = scatter_std(gathered_deg, row_idx, dim_size=num_nodes)

        profile = torch.stack([node_deg, lowest, highest, average, spread],
                              dim=1)

        features = data.x
        if features is None:
            data.x = profile
        else:
            # Promote a flat feature vector to a single column first.
            if features.dim() == 1:
                features = features.view(-1, 1)
            data.x = torch.cat([features, profile], dim=-1)

        return data
Exemple #27
0
    def forward(self, x, edge_index):
        """Transform node features, then pool them along the edges.

        Self loops in ``edge_index`` are removed and re-added, ``x`` is
        multiplied by ``self.weight`` (plus ``self.bias`` when present) and
        passed through ``self.act``. The transformed rows selected by the
        second edge row are then reduced per first-row index using the
        aggregation named by ``self.pool``.

        Args:
            x: Node features; a 1-D tensor is treated as a single feature
                column.
            edge_index: Edge list that unpacks into ``(row, col)``.

        Returns:
            The pooled features, L2-normalized over the last dimension when
            ``self.normalize`` is set.

        Raises:
            ValueError: If ``self.pool`` is not ``'mean'``, ``'max'`` or
                ``'add'``. Previously this case only printed a message and
                then crashed on the unbound ``out``.
        """
        # Fail fast on an unsupported pooling mode, before any tensor work.
        if self.pool not in ('mean', 'max', 'add'):
            raise ValueError(f'pooling not defined: {self.pool!r}')

        edge_index, _ = remove_self_loops(edge_index)
        # NOTE(review): assumes `add_self_loops` returns a bare edge_index
        # (not an (edge_index, edge_weight) pair) -- confirm for the
        # installed torch_geometric version.
        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x
        row, col = edge_index

        # The linear transform + activation is shared by all pooling modes.
        out = torch.matmul(x, self.weight)
        if self.bias is not None:
            out = out + self.bias
        out = self.act(out)

        num_nodes = out.size(0)
        if self.pool == 'mean':
            out = scatter_mean(out[col], row, dim=0, dim_size=num_nodes)
        elif self.pool == 'max':
            out, _ = scatter_max(out[col], row, dim=0, dim_size=num_nodes)
        else:  # 'add'
            # Aggregate the *transformed* features like the other modes do;
            # the previous code summed the raw `x[col]` and threw `out` away.
            out = scatter_add(out[col], row, dim=0, dim_size=num_nodes)

        if self.normalize:
            out = F.normalize(out, p=2, dim=-1)

        return out
Exemple #28
0
 def forward(self, data):
     """Chain the three GCU layers and max-pool their fused output.

     ``data.pos`` is fed through ``gcu_1``, ``gcu_2`` and ``gcu_3`` in
     sequence, each call also receiving ``data.tpl_edge_index`` and
     ``data.geo_edge_index``. The three intermediate outputs are
     concatenated on dim 1, passed through ``mlp_glb`` and reduced with
     ``scatter_max`` over ``data.batch`` along dim 0.

     Returns:
         The max values from ``scatter_max`` (its argmax output is dropped).
     """
     stage_outputs = []
     hidden = data.pos
     for gcu in (self.gcu_1, self.gcu_2, self.gcu_3):
         hidden = gcu(hidden, data.tpl_edge_index, data.geo_edge_index)
         stage_outputs.append(hidden)

     fused = self.mlp_glb(torch.cat(stage_outputs, dim=1))
     pooled = scatter_max(fused, data.batch, dim=0)[0]
     return pooled
Exemple #29
0
    def __get_updated_memory__(self, n_id):
        """Compute updated memory and last-update values for ``n_id``.

        Messages from both stores are computed, concatenated and aggregated,
        then passed through ``self.gru`` together with the current
        ``self.memory[n_id]``. As a side effect, ``self.__assoc__[n_id]`` is
        overwritten with positions ``0 .. len(n_id) - 1``.

        Returns:
            Tuple ``(memory, last_update)``: the GRU output and the per-node
            maximum of the message timestamps, indexed by ``n_id``.
        """
        num_local = n_id.size(0)
        self.__assoc__[n_id] = torch.arange(num_local, device=n_id.device)

        # Messages (src -> dst); the fourth returned item is not needed.
        msg_fwd, t_fwd, idx_fwd, _ = self.__compute_msg__(
            n_id, self.msg_s_store, self.msg_s_module)

        # Messages (dst -> src).
        msg_bwd, t_bwd, idx_bwd, _ = self.__compute_msg__(
            n_id, self.msg_d_store, self.msg_d_module)

        # Merge both directions before aggregating.
        node_idx = torch.cat([idx_fwd, idx_bwd], dim=0)
        messages = torch.cat([msg_fwd, msg_bwd], dim=0)
        timestamps = torch.cat([t_fwd, t_bwd], dim=0)
        aggregated = self.aggr_module(messages, self.__assoc__[node_idx],
                                      timestamps, num_local)

        # Local copy of the updated memory.
        memory = self.gru(aggregated, self.memory[n_id])

        # Local copy of the updated `last_update`: latest timestamp per node,
        # sized like the stored `last_update` so `n_id` can index it.
        latest_per_node = scatter_max(timestamps, node_idx, dim=0,
                                      dim_size=self.last_update.size(0))[0]
        last_update = latest_per_node[n_id]

        return memory, last_update
Exemple #30
0
def take_action_deterministic_batch_dqn(target_net, player, batch_instances):
    """Pick one greedy action per batch group from ``target_net`` values.

    ``target_net`` is evaluated under ``torch.no_grad()``. Only rows where
    the first column of ``batch_instances.J`` equals 0 are kept. The
    remaining values are reduced per group of ``G_torch.batch`` with
    ``scatter_min`` when ``player == 1`` and ``scatter_max`` otherwise.

    Returns:
        list: The flattened arg-indices returned by the scatter reduction.
    """
    with torch.no_grad():
        graph = batch_instances.G_torch
        group_of_row = graph.batch
        selectable = batch_instances.J.eq(0)[:, 0]

        scores = target_net(
            graph,
            batch_instances.n_nodes,
            batch_instances.Omegas,
            batch_instances.Phis,
            batch_instances.Lambdas,
            batch_instances.Omegas_norm,
            batch_instances.Phis_norm,
            batch_instances.Lambdas_norm,
            batch_instances.J,
        )
        scores = scores[selectable]
        group_of_row = group_of_row[selectable]

        # Player 1 (the attacker's turn) minimizes; otherwise maximize.
        reducer = scatter_min if player == 1 else scatter_max
        _, actions = reducer(scores, group_of_row, dim=0)

    return actions.view(-1).tolist()