Example #1
0
    def __call__(self, data):
        """Coarsen ``data`` by agglomerative clustering of its node features.

        Graphs with fewer than ``self.min_nodes`` nodes are returned unchanged.
        Otherwise the nodes are clustered with ``AgglomerativeClustering``
        (constrained by the connectivity ``self.to_dense(data).adj``), node
        features and positions are replaced by per-cluster means, and
        ``edge_index`` is rewritten in terms of cluster ids: edges whose two
        endpoints land in the same cluster are dropped and duplicates removed.

        Args:
            data: graph object with ``x``, ``pos`` and ``edge_index``.

        Returns:
            The same ``data`` object, modified in place.
        """
        if len(data.x) < self.min_nodes:
            return data

        connectivity = self.to_dense(data).adj
        clustering = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=self.distance_threshold,
            connectivity=connectivity)
        # scatter_* and tensor indexing need int64 indices
        labels = torch.from_numpy(clustering.fit_predict(data.x)).long()

        data.x = scatter_mean(data.x, labels, dim=0)
        if data.pos is not None:
            data.pos = scatter_mean(data.pos, labels, dim=0)

        if data.edge_index is not None:
            # Re-express every edge in terms of cluster ids.
            src = labels[data.edge_index[0]]
            dst = labels[data.edge_index[1]]
            # Intra-cluster edges would become self-loops; drop them.
            keep = src != dst
            new_edge_index = torch.stack([src[keep], dst[keep]], dim=0)
            # Merging nodes creates parallel edges; keep one of each.
            # (Guarded because unique(dim=...) on an empty tensor is not
            # supported by every torch version.)
            if new_edge_index.size(1) > 0:
                new_edge_index = torch.unique(new_edge_index, dim=1)
            data.edge_index = new_edge_index

        return data
Example #2
0
    def __call__(self, img):
        """Turn an image into a superpixel graph using SLIC.

        Args:
            img: image tensor laid out as ``(channels, height, width)``.

        Returns:
            A ``Data`` object whose ``x`` is the mean pixel value of every
            superpixel and whose ``pos`` is the mean ``(x, y)`` pixel
            coordinate of every superpixel. ``seg`` (the segmentation map) and
            ``img`` (the input image) are attached when ``self.add_seg`` /
            ``self.add_img`` are set.
        """
        from skimage.segmentation import slic

        hwc = img.permute(1, 2, 0)
        height, width, channels = hwc.size()
        num_pixels = height * width

        segments = torch.from_numpy(
            slic(hwc.to(torch.double).numpy(), start_label=0, **self.kwargs))
        flat_segments = segments.view(num_pixels)

        features = scatter_mean(hwc.view(num_pixels, channels), flat_segments,
                                dim=0)

        # Pixel grid in row-major order: y changes slowest, x fastest.
        ys = torch.arange(height, dtype=torch.float).view(-1, 1).repeat(1, width)
        xs = torch.arange(width, dtype=torch.float).view(1, -1).repeat(height, 1)
        coords = torch.stack([xs.view(num_pixels), ys.view(num_pixels)], dim=-1)
        positions = scatter_mean(coords, flat_segments, dim=0)

        data = Data(x=features, pos=positions)

        if self.add_seg:
            data.seg = segments.view(1, height, width)

        if self.add_img:
            data.img = hwc.permute(2, 0, 1).view(1, channels, height, width)

        return data
    def forward(self, x_1, x_2, edge_index_pos, edge_index_neg):
        """
        Aggregate neighbour features over positive and negative edges.

        Each edge set first has its self-loops normalised: existing ones are
        removed and exactly one per node is added. ``x_1`` is aggregated over
        the positive edges and ``x_2`` over the negative edges, by mean when
        ``self.norm`` is set and by sum otherwise. Both aggregates are
        concatenated with ``x_1`` and projected by ``self.weight``. The result
        is L2-normalised when ``self.norm_embed`` is set.

        :param x_1: Features for left hand side vertices.
        :param x_2: Features for right hand side vertices.
        :param edge_index_pos: Positive indices.
        :param edge_index_neg: Negative indices.
        :return out: Abstract convolved features.
        """
        # pos: [x_1:balance features; x_2:unbalance features]
        # neg: [x_1:unbalance features; x_2:balance features]
        reduce = scatter_mean if self.norm else scatter_add

        def _aggregate(feats, edge_index):
            # Replace any existing self-loops with exactly one loop per node.
            edge_index, _ = remove_self_loops(edge_index, None)
            edge_index, _ = add_self_loops(edge_index, num_nodes=feats.size(0))
            target, source = edge_index
            return reduce(feats[source], target, dim=0, dim_size=feats.size(0))

        out_1 = _aggregate(x_1, edge_index_pos)
        out_2 = _aggregate(x_2, edge_index_neg)

        out = torch.matmul(torch.cat((out_1, out_2, x_1), 1), self.weight)
        if self.bias is not None:
            out = out + self.bias

        return F.normalize(out, p=2, dim=-1) if self.norm_embed else out
    def forward(self, x, edge_index, edge_attr, h, u, batch):
        """ Global Update of Graph Net Layer

            Per-graph means of node and edge features are concatenated with
            ``u``, passed through ``fc1`` / ``in1`` / ReLU, advanced one GRU
            step from hidden state ``h`` and projected by ``fc2``.

            @param x: [N x n_outc], where N is the number of nodes.
            @param edge_index: [2 x E] with max entry N - 1.
            @param edge_attr: [E x e_outc]
            @param h: [B x hidden]
            @param u: [B x u_inc]
            @param batch: [N] with max entry B - 1.

            @return: a [B x u_outc] torch tensor and the new hidden state
                     [B x hidden]
        """
        num_graphs = u.size(0)
        row, _ = edge_index
        # edge_batch is same as batch in EdgeModel.forward(). Shape: [E]
        edge_batch = batch[row]

        # dim_size pins both aggregates to B rows; without it, a batch whose
        # last graph(s) have no edges yields fewer rows and torch.cat fails.
        per_batch_edge_aggregations = scatter_mean(
            edge_attr, edge_batch, dim=0, dim_size=num_graphs)  # [B x e_outc]
        per_batch_node_aggregations = scatter_mean(
            x, batch, dim=0, dim_size=num_graphs)  # [B x n_outc]

        x = torch.cat(
            [u, per_batch_node_aggregations, per_batch_edge_aggregations],
            dim=1)  # Shape: [B x (u_inc + n_outc + e_outc)]
        x = F.relu(self.in1(self.fc1(x)))
        # One GRU step: x as a length-1 sequence, h as the initial state.
        # NOTE(review): the shapes passed here only work for a single-layer,
        # batch_first GRU, for which the last output step equals h_n.
        x, h = self.gru(x.unsqueeze(1), h.unsqueeze(0))
        # Squeeze only the singleton sequence / layer axes so that B == 1
        # keeps its batch dimension.
        x = x.squeeze(1)
        h = h.squeeze(0)
        x = self.fc2(x)
        return x, h
Example #5
0
    def forward(self, x, edge_index, edge_attr):
        """Edge-conditioned convolution that averages edge messages onto nodes.

        For every edge (row, col) the two endpoint feature vectors are
        concatenated, multiplied by ``self.weight`` and scaled by the edge
        attribute. Each node receives the mean of the messages of all edges it
        appears in, either as ``row`` or as ``col`` endpoint.

        Args:
            x: node features indexed by the entries of ``edge_index``.
            edge_index: ``[2, E]`` tensor of (row, col) node indices.
            edge_attr: per-edge weights. A 1-D ``[E]`` tensor is turned into a
                column so it broadcasts over the output channels.

        Returns:
            ``[len(x), out_channels]`` tensor. Nodes without edges get 0
            (plus bias).
        """
        row, col = edge_index
        num_node = len(x)
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)

        # edge message: concatenated endpoint features, filtered by the weight
        alpha = torch.mm(torch.cat([x[row], x[col]], dim=-1), self.weight)

        # scale each edge message by the corresponding edge attribute
        alpha = edge_attr * alpha

        # Average every message onto both endpoints in ONE reduction.
        # Chaining scatter_mean(..., out=out) twice does not give a mean: the
        # second call adds the col-sums onto the row-means and divides by the
        # col-counts only.
        out = scatter_mean(torch.cat([alpha, alpha], dim=0),
                           torch.cat([row, col], dim=0),
                           dim=0, dim_size=num_node)

        # add the bias
        if self.bias is not None:
            out = out + self.bias

        return out
Example #6
0
    def forward(self, data):
        """Hierarchical graph network returning one scalar per graph.

        Three ELU-activated convolutions (using ``data.edge_attr``) run on the
        base graph, and their output is mean-pooled per graph. That same
        base-level output is pooled with ``avg_pool`` onto level 2
        (``data.assignment_2``) and level 3 (``data.assignment_3``). At each
        level it is concatenated with ``data.iso_type_<k>``, refined by two
        convolutions and mean-pooled per graph with ``data.batch_<k>``. The
        three readouts are concatenated and fed through ``fc1``-``fc3``.

        Side effect: ``data.x`` is overwritten and ends up holding the output
        of ``self.conv7``.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            data.x = F.elu(conv(data.x, data.edge_index, data.edge_attr))
        node_x = data.x
        readouts = [scatter_mean(node_x, data.batch, dim=0)]

        for level, (conv_a, conv_b) in ((2, (self.conv4, self.conv5)),
                                        (3, (self.conv6, self.conv7))):
            # Both levels pool the base-level features, not the previous level.
            pooled = avg_pool(node_x, getattr(data, f'assignment_{level}'))
            data.x = torch.cat([pooled, getattr(data, f'iso_type_{level}')],
                               dim=1)
            level_edges = getattr(data, f'edge_index_{level}')
            data.x = F.elu(conv_a(data.x, level_edges))
            data.x = F.elu(conv_b(data.x, level_edges))
            readouts.append(
                scatter_mean(data.x, getattr(data, f'batch_{level}'), dim=0))

        out = torch.cat(readouts, dim=1)
        for fc in (self.fc1, self.fc2):
            out = F.elu(fc(out))
        return self.fc3(out).view(-1)
    def forward(self, data):
        """Hierarchical graph network with a task-dependent output head.

        Three ELU-activated convolutions run on the base graph. Their output is
        sum-pooled per graph and also pooled with ``avg_pool`` onto level 2
        (``data.assignment_index_2``) and level 3
        (``data.assignment_index_3``). At each level it is concatenated with
        ``data.iso_type_<k>``, refined by two convolutions and mean-pooled per
        graph. The concatenated readouts go through ``fc1`` (with dropout),
        ``fc2`` and ``fc3``.

        Returns:
            ``self.args.type == "regression"``: flattened raw outputs;
            ``"TU"``: log-softmax over dim 1;
            otherwise: flattened ``self.sigmoid`` of the outputs.

        Side effect: ``data.x`` ends up holding the output of ``self.conv7``.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            data.x = F.elu(conv(data.x, data.edge_index))
        node_x = data.x
        # The base-level readout is a per-graph SUM; the coarser levels use means.
        readouts = [scatter_add(node_x, data.batch, dim=0)]

        for level, (conv_a, conv_b) in ((2, (self.conv4, self.conv5)),
                                        (3, (self.conv6, self.conv7))):
            pooled = avg_pool(node_x, getattr(data, f'assignment_index_{level}'))
            data.x = torch.cat([pooled, getattr(data, f'iso_type_{level}')],
                               dim=1)
            level_edges = getattr(data, f'edge_index_{level}')
            data.x = F.elu(conv_a(data.x, level_edges))
            data.x = F.elu(conv_b(data.x, level_edges))
            readouts.append(
                scatter_mean(data.x, getattr(data, f'batch_{level}'), dim=0))

        out = F.elu(self.fc1(torch.cat(readouts, dim=1)))
        out = F.dropout(out, p=0.5, training=self.training)
        out = self.fc3(F.elu(self.fc2(out)))

        task = self.args.type
        if task == "regression":
            return out.view(-1)
        if task == "TU":
            return F.log_softmax(out, dim=1)
        return self.sigmoid(out).view(-1)
    def forward(self, vertices):
        """Run fixed-depth message passing over the vertex graph and read out a value.

        ``vertices`` is reshaped to ``(n_data, 1, self.n_nodes, 11)`` before
        ``self.x_lin``, so it must contain ``n_data * n_nodes * 11`` elements.
        One initial message round is followed by five residual rounds; the
        result goes through ``self.vertex_output_lin`` and
        ``self.graph_output_lin``.
        """
        n_msg_passing = 5
        vertices = vertices.reshape((vertices.shape[0], 1, self.n_nodes, 11))
        n_data = len(vertices)

        def _propagate(features):
            # Compute edge messages, tile them twice along the last axis and
            # average them onto the vertices listed in self.dest_edges.
            messages = self.compute_msgs(features, n_data).repeat((1, 1, 2))
            aggregated = torch_scatter.scatter_mean(messages, self.dest_edges,
                                                    dim=-1)
            return aggregated[:, None, :, :]

        new_vertex = _propagate(self.x_lin(vertices).squeeze())
        for _ in range(n_msg_passing):
            refined = self.x_lin_after_first_round(new_vertex).squeeze()
            new_vertex = new_vertex + _propagate(refined)

        final_vertex_output = self.vertex_output_lin(new_vertex).squeeze()
        return self.graph_output_lin(final_vertex_output)
Example #9
0
def get_aggregation(aggregation):
    """
    Return the aggregation function selected by ``hparams["aggregation"]``.

    Every returned callable has the signature ``fn(e, end, x)``. It reduces
    the rows of ``e`` into the buckets given by the index tensor ``end``
    (along dim 0) and produces ``x.shape[0]`` rows.

    Supported names:
        ``"sum"``, ``"mean"``, ``"max"``  -- a single reduction;
        ``"sum_max"``  -- ``cat([max, sum], dim=-1)``;
        ``"mean_sum"`` -- ``cat([mean, sum], dim=-1)``;
        ``"mean_max"`` -- ``cat([max, mean], dim=-1)``.
    The max-first order of ``"sum_max"`` and ``"mean_max"`` is kept unchanged
    for backward compatibility.

    Raises:
        KeyError: if ``aggregation`` is not one of the names above (the
            message lists the valid names).
    """

    def _sum(e, end, x):
        return scatter_add(e, end, dim=0, dim_size=x.shape[0])

    def _mean(e, end, x):
        return scatter_mean(e, end, dim=0, dim_size=x.shape[0])

    def _max(e, end, x):
        # scatter_max returns (values, argmax); only the values are needed
        return scatter_max(e, end, dim=0, dim_size=x.shape[0])[0]

    def _concat(*reducers):
        # Feature-wise concatenation of several reductions, in the given order.
        return lambda e, end, x: torch.cat(
            [reduce(e, end, x) for reduce in reducers], dim=-1)

    aggregation_dict = {
        "sum": _sum,
        "mean": _mean,
        "max": _max,
        "sum_max": _concat(_max, _sum),
        "mean_sum": _concat(_mean, _sum),
        "mean_max": _concat(_max, _mean),
    }

    try:
        return aggregation_dict[aggregation]
    except KeyError:
        raise KeyError(
            f"Unknown aggregation {aggregation!r}; expected one of "
            f"{sorted(aggregation_dict)}"
        ) from None
    def forward(self, data):
        """Embed node ids, run the GIN stack and pool per-layer scores per graph.

        Every hidden representation, including the initial embedding, is
        passed through its own prediction head and ``self.dp``. The node
        scores are summed over layers and then pooled per graph
        (``data.batch``) according to ``self.aggr``: ``'mean'``, ``'max'`` or
        ``'sum'``. Any other value returns the unpooled node scores.
        """
        # scatter_mean is used from module scope; the other two are imported
        # here so this block does not depend on the file's import list.
        from torch_scatter import scatter_add, scatter_max

        edge_index, x_ids = data.edge_index, data.node_ids
        x = self.embed(x_ids)

        hidden_reps = [x]
        for i in range(self.num_layers - 1):
            x = self.gins[i](x, edge_index)
            x = self.batch_norms[i](x)
            x = self.relu(x)
            hidden_reps.append(x)

        score_over_layer = 0
        for layer, h in enumerate(hidden_reps):
            score_over_layer += self.dp(self.linears_prediction[layer](h))

        if self.aggr == 'mean':
            score_over_layer = scatter_mean(score_over_layer, data.batch, dim=0)
        elif self.aggr == 'max':
            # scatter_max returns (values, argmax); keep only the values
            score_over_layer, _ = scatter_max(score_over_layer, data.batch,
                                              dim=0)
        elif self.aggr == 'sum':
            score_over_layer = scatter_add(score_over_layer, data.batch, dim=0)

        return score_over_layer
Example #11
0
    def forward(self, x, edge_index, edge_attr):
        """Edge-conditioned convolution that averages edge messages onto nodes.

        For every edge (row, col) the two endpoint feature vectors are
        concatenated, multiplied by ``self.weight`` and scaled by the edge
        attribute. Each node then receives the mean of the messages of its
        incident edges:

        * ``self.undirected`` set: messages are averaged onto ``row`` only.
          When both (i, j) and (j, i) are listed, this already covers every
          incident edge, and averaging onto ``col`` too would count each edge
          twice.
        * otherwise: messages are averaged onto both endpoints in a single
          reduction.

        Args:
            x: node features indexed by the entries of ``edge_index``.
            edge_index: ``[2, E]`` tensor of (row, col) node indices.
            edge_attr: per-edge weights. A 1-D ``[E]`` tensor is turned into a
                column so it broadcasts over the output channels.

        Returns:
            ``[len(x), out_channels]`` tensor. Nodes without edges get 0
            (plus bias).
        """
        row, col = edge_index
        num_node = len(x)
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)

        # edge message: concatenated endpoint features, filtered by the weight
        alpha = torch.mm(torch.cat([x[row], x[col]], dim=-1), self.weight)

        # scale each edge message by the corresponding edge attribute
        alpha = edge_attr * alpha

        if self.undirected:
            index, messages = row, alpha
        else:
            # One reduction over both endpoints. Chaining two
            # scatter_mean(..., out=out) calls does not yield a mean: the
            # second call divides row-means + col-sums by the col-counts only.
            index = torch.cat([row, col], dim=0)
            messages = torch.cat([alpha, alpha], dim=0)
        out = scatter_mean(messages, index, dim=0, dim_size=num_node)

        # add the bias
        if self.bias is not None:
            out = out + self.bias

        return out
Example #12
0
    def forward(self, konf, pose, noise):
        """Predict a value from ``konf`` conditioned on the current robot pose.

        The last four entries of ``pose`` are tiled over 618 rows and
        concatenated with ``konf`` along dim 2 (so ``konf`` must have 618 rows
        in dim 1). The result goes through ``self.first_features``, two rounds
        of edge-based averaging, and a max-pooled value head. ``noise`` is not
        used.

        NOTE(review): the message-passing rounds update ``vertex_values`` but
        that result is never read -- the value head consumes ``features``
        (the ``self.first_features`` output). Confirm whether
        ``vertex_values`` was meant to feed ``self.value``.
        """
        n_rows = 618
        pose_tiled = pose[:, -4:].unsqueeze(1).repeat((1, n_rows, 1)).unsqueeze(-1)
        stacked = torch.cat([pose_tiled, konf], dim=2)
        stacked = stacked.reshape((stacked.shape[0], stacked.shape[-1],
                                   stacked.shape[1], stacked.shape[2]))

        features = self.first_features(stacked).squeeze()
        scatter_index, gather_index = self.edges[0], self.edges[1]
        vertex_values = torch_scatter.scatter_mean(
            features[:, :, gather_index], scatter_index, dim=-1)

        for _ in range(2):
            as_image = vertex_values.view(
                (vertex_values.shape[0], vertex_values.shape[1],
                 vertex_values.shape[2], 1))
            refined = self.features(as_image).squeeze()
            vertex_values = torch_scatter.scatter_mean(
                refined[:, :, gather_index], scatter_index, dim=-1)

        features = features.view(
            (features.shape[0], features.shape[1], features.shape[2], 1))
        halve_rows = torch.nn.MaxPool2d(kernel_size=(2, 1))
        features = halve_rows(halve_rows(features))
        features = features.view(
            (features.shape[0], features.shape[1] * features.shape[2]))
        return self.value(features)
    def get_vertex_activations(self, vertices):
        """Return per-vertex outputs after ``self.n_msg_passing`` residual rounds.

        The slice ``vertices[:, :, 6:]`` (collision features) is permuted to
        channel-first and concatenated onto the vertex features before every
        message computation. The output keeps a leading batch axis even when
        ``len(vertices) == 1``.
        """
        n_data = len(vertices)
        v_features = self.get_vertex_features(vertices)
        collisions = vertices[:, :, 6:].permute((0, 2, 1))
        v_features = torch.cat((v_features, collisions), 1)

        def _aggregate(features):
            # Edge messages, tiled twice along the last axis, averaged onto
            # the vertices listed in self.dest_edges.
            msgs = self.compute_msgs(features, n_data).repeat((1, 1, 2))
            pooled = torch_scatter.scatter_mean(msgs, self.dest_edges, dim=-1)
            return pooled[:, None, :, :]

        new_vertex = _aggregate(v_features)
        for _ in range(self.n_msg_passing):
            refined = self.x_lin_after_first_round(new_vertex).squeeze(dim=2)
            new_vertex = new_vertex + _aggregate(
                torch.cat((refined, collisions), 1))

        final_vertex_output = self.vertex_output_lin(new_vertex).squeeze()
        # squeeze() also removes the batch axis for a single sample; restore it
        if n_data == 1:
            final_vertex_output = final_vertex_output[None, :]
        return final_vertex_output
    def forward(self, data):
        """Hierarchical graph classifier returning per-graph log-probabilities.

        Three ELU-activated convolutions run on the base graph. Their output is
        mean-pooled per graph and also pooled with ``avg_pool`` onto level 2
        (``data.assignment_index_2``) and level 3
        (``data.assignment_index_3``). At each level it is concatenated with
        ``data.iso_type_<k>``, refined by two convolutions and mean-pooled per
        graph. The three readouts are concatenated (and detached when
        ``args.no_train`` is set) before the ``fc1``-``fc3`` head.

        Side effect: ``data.x`` ends up holding the output of ``self.conv7``.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            data.x = F.elu(conv(data.x, data.edge_index))
        node_x = data.x
        readouts = [scatter_mean(node_x, data.batch, dim=0)]

        for level, (conv_a, conv_b) in ((2, (self.conv4, self.conv5)),
                                        (3, (self.conv6, self.conv7))):
            pooled = avg_pool(node_x, getattr(data, f'assignment_index_{level}'))
            data.x = torch.cat([pooled, getattr(data, f'iso_type_{level}')],
                               dim=1)
            level_edges = getattr(data, f'edge_index_{level}')
            data.x = F.elu(conv_a(data.x, level_edges))
            data.x = F.elu(conv_b(data.x, level_edges))
            readouts.append(
                scatter_mean(data.x, getattr(data, f'batch_{level}'), dim=0))

        out = torch.cat(readouts, dim=1)

        # ``args`` is resolved from the enclosing module, not from ``self``;
        # detaching stops gradients from reaching the graph layers.
        if args.no_train:
            out = out.detach()

        out = F.dropout(F.elu(self.fc1(out)), p=0.5, training=self.training)
        out = self.fc3(F.elu(self.fc2(out)))
        return F.log_softmax(out, dim=1)
    def forward(self, x, edge_index, edge_attr, u, batch):
        """ Global Update of Graph Net Layer

            Per-graph means of node and edge features are concatenated after
            ``u`` and passed through ``self.global_mlp``.

            @param x: [N x n_outc], where N is the number of nodes.
            @param edge_index: [2 x E] with max entry N - 1.
            @param edge_attr: [E x e_outc]
            @param u: [B x u_inc]
            @param batch: [N] with max entry B - 1.

            @return: a [B x u_outc] torch tensor
        """
        num_graphs = u.size(0)
        row, _ = edge_index
        # edge_batch is same as batch in EdgeModel.forward(). Shape: [E]
        edge_batch = batch[row]

        # dim_size pins both aggregations to B rows; without it a batch whose
        # last graph(s) have no edges yields fewer rows and torch.cat fails.
        per_batch_edge_aggregations = scatter_mean(
            edge_attr, edge_batch, dim=0, dim_size=num_graphs)  # [B x e_outc]
        per_batch_node_aggregations = scatter_mean(
            x, batch, dim=0, dim_size=num_graphs)  # [B x n_outc]

        out = torch.cat(
            [u, per_batch_node_aggregations, per_batch_edge_aggregations],
            dim=1)  # Shape: [B x (u_inc + n_outc + e_outc)]
        return self.global_mlp(out)
Example #16
0
    def explain_single_layer(self, to_explain, index=None, name=None):
        """Back-propagate LRP relevance through a single layer of ``self.model``.

        The layer is looked up by ``index`` or ``name``; whichever is missing
        is resolved through ``self.model.index2name`` / ``name2index``.
        Relevance of the layer's output is read from
        ``to_explain["R"][index + 1]`` and the layer input from
        ``to_explain["A"][name]``. Layers listed in
        ``self.model.special_layers`` first have their relevance reshaped or
        split, with the 48-column slices parked in ``to_explain["R"]`` under
        ``"r_x"``, ``"r_x_row"``, ``"r_src"`` and ``"r_dest"``. The LRP rule
        is then applied: the z-rule when the model asks for ``"z"``, the
        epsilon rule otherwise. The result is stored in
        ``to_explain["R"][index]``.
        """
        # todo: deal with special case when previous layer has not been explained
        layer = self.model.get_layer(index=index, name=name)
        requested = self.model.get_rule(index=index, layer_name=name)
        # any rule name other than "z" (supported or not) uses the epsilon rule
        rule = LRP.z_rule if requested == "z" else LRP.eps_rule

        if name is None:
            name = self.model.index2name(index)
        if index is None:
            index = self.model.name2index(name)

        activation = to_explain['A'][name]
        relevance_store = to_explain["R"]
        relevance = relevance_store[index + 1]

        if name in self.model.special_layers:
            n_tracks = to_explain["inputs"]["x"].shape[0]
            row, col = to_explain["inputs"]["edge_index"]

            if "node_mlp_2.3" in name:
                relevance = relevance.repeat(n_tracks, 1) / n_tracks
            elif "node_mlp_1.3" in name:
                head, tail = relevance[:, :48], relevance[:, 48:]
                relevance = tail[col] / (n_tracks - 1)
                relevance_store["r_x"] = head
            elif "edge_mlp.3" in name:
                head, tail = relevance[:, :48], relevance[:, 48:]
                relevance = tail
                relevance_store["r_x_row"] = head
            elif "bn" in name:
                r_src, r_dest = relevance[:, :48], relevance[:, 48:]
                relevance_store['r_src'] = r_src
                relevance_store['r_dest'] = r_dest

                # aggregate edge-level relevance back onto the nodes
                from_src = scatter_mean(r_src, row, dim=0, dim_size=n_tracks)
                from_dest = scatter_mean(r_dest, col, dim=0, dim_size=n_tracks)
                relevance = (from_src + from_dest + relevance_store['r_x'] +
                             scatter_mean(relevance_store['r_x_row'], row,
                                          dim=0, dim_size=n_tracks) +
                             1e-10)

        # backward pass with the selected LRP rule, then store the result
        relevance_store[index] = rule(layer, activation, relevance)
    def forward(self, data):
        """Encode the bottom-up graph with two convolutions and mean-pool per graph.

        Args:
            data: batch providing ``x``, ``BU_edge_index``, ``batch`` and
                ``rootindex`` (node index of each graph's root, indexed by
                graph id).

        Returns:
            x: per-graph mean (one row per graph id in ``data.batch``) of
                ``relu(conv2(dropout(relu(conv1(data.x)))))``.
            root_extend: per-graph mean of the root node's input features
                (as float), i.e. ``data.x[data.rootindex[g]]`` for every graph
                ``g`` that has nodes.
        """
        x, edge_index = data.x, data.BU_edge_index

        # Give every node its graph's root features with a single gather:
        # rootindex[batch] is the root of each node's own graph. This replaces
        # a Python loop over graphs (and its reliance on a global ``device``).
        root_extend = x.float()[data.rootindex[data.batch]]

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = scatter_mean(x, data.batch, dim=0)  # GCN - mean

        root_extend = scatter_mean(root_extend, data.batch, dim=0)

        return x, root_extend
Example #18
0
 def forward(self, x_h, x_g, edge_index, edge_attr, u, batch_h, batch_g):
     """Global update from ``u`` and per-graph means of two node sets.

     ``x_h`` and ``x_g`` are mean-pooled per graph via ``batch_h`` and
     ``batch_g``, concatenated after ``u`` and passed through
     ``self.global_mlp``. ``edge_index`` and ``edge_attr`` are not used.
     """
     num_graphs = u.size(0)
     # dim_size keeps one row per graph in ``u`` even if the last graph(s)
     # have no h- or g-nodes; otherwise torch.cat fails on mismatched rows.
     out = torch.cat([
         u,
         scatter_mean(x_h, batch_h, dim=0, dim_size=num_graphs),
         scatter_mean(x_g, batch_g, dim=0, dim_size=num_graphs),
     ], dim=1)
     return self.global_mlp(out)
Example #19
0
    def multi_hop(self, triple_prob, distance, head, tail, concept_label, triple_label, gamma=0.8, iteration=3,
                       method="avg"):
        '''
        Propagate concept scores over triples for ``iteration`` hops.

        triple_prob: bsz x L x mem_t
        distance: bsz x mem
        head, tail: bsz x mem_t
        concept_label: bsz x mem
        triple_label: bsz x mem_t

        Scores start as a binary vector with source concepts (distance == 0)
        set to 1 and all others 0, expanded to bsz x L x mem. Each hop
        computes gamma * s(u) + R(u->v) for every triple u->v and reduces it
        onto the tail concept v:
            - avg: s(v) = Avg_{u in N(v)} gamma * s(u) + R(u->v)
            - max: s(v) = max_{u in N(v)} gamma * s(u) + R(u->v)
        The reduction writes into a zero-initialised buffer. Padded concepts
        (concept_label == -1) are zeroed after every hop, and padded triples
        (triple_label == -1) contribute a head score of 0.

        Returns: bsz x L x mem, the sum of the scores of all hops, with source
        concepts offset by -1e5.

        Raises: ValueError if ``method`` is neither "avg" nor "max".
        '''
        if method not in ("avg", "max"):
            raise ValueError(f'method must be "avg" or "max", got {method!r}')
        reduce = scatter_max if method == "max" else scatter_mean

        bsz, seq_len = triple_prob.size(0), triple_prob.size(1)
        # Allocate a real (non-expanded) float tensor. An expand()ed view
        # cannot take an in-place masked_fill_, and .float() does not copy
        # when ``distance`` is already float32.
        init_mask = torch.zeros(bsz, seq_len, distance.size(1),
                                dtype=torch.float, device=distance.device)
        init_mask.masked_fill_((distance == 0).unsqueeze(1), 1)
        final_mask = init_mask.clone()

        concept_pad = (concept_label == -1).unsqueeze(1)
        triple_pad = (triple_label == -1).unsqueeze(1)

        init_mask.masked_fill_(concept_pad, 0)
        concept_probs = [init_mask]

        head = head.unsqueeze(1).expand(bsz, seq_len, -1)
        tail = tail.unsqueeze(1).expand(bsz, seq_len, -1)

        for _ in range(iteration):
            # score of each triple's head concept
            node_score = concept_probs[-1]
            triple_head_score = node_score.gather(2, head)
            triple_head_score.masked_fill_(triple_pad, 0)

            update_value = triple_head_score * gamma + triple_prob
            out = torch.zeros_like(node_score)
            reduce(update_value, tail, dim=-1, out=out)  # fills ``out`` in place
            out.masked_fill_(concept_pad, 0)

            concept_probs.append(out)

        # Natural decay of concept that is multi-hop away from source
        total_concept_prob = final_mask * -1e5
        for prob in concept_probs[1:]:
            total_concept_prob += prob
        # bsz x L x mem

        return total_concept_prob
Example #20
0
 def forward(self, x, edge_index, edge_attr, u, batch):
     """Global update from per-graph means of nodes and edges plus ``u``.

     Edges are assigned to graphs through their first endpoint
     (``batch[edge_index[0]]``). The concatenation order is
     ``[x_agg, e_agg, u]``, and the result goes through ``self.phi_u``.
     """
     num_graphs = u.size(0)
     row, _ = edge_index
     # compute the batch index for all edges
     e_batch = batch[row]
     # dim_size keeps one row per graph in ``u`` even when the last graph(s)
     # have no edges; otherwise the concat below fails.
     e_agg = scatter_mean(edge_attr, e_batch, dim=0, dim_size=num_graphs)
     x_agg = scatter_mean(x, batch, dim=0, dim_size=num_graphs)
     out = torch.cat([x_agg, e_agg, u], 1)
     return self.phi_u(out)
Example #21
0
 def forward(self, x, edge_index, edge_attr, u, batch):
     """Global update from ``u`` and per-graph means of node and edge features.

     Both means are sized to ``u.size(0)`` graphs, so graphs without nodes or
     edges get zero rows instead of shrinking the result. Edges are assigned
     to graphs through their first endpoint.
     """
     src, _ = edge_index
     num_graphs = u.size(0)
     pooled = [
         u,
         scatter_mean(x, batch, dim=0, dim_size=num_graphs),
         scatter_mean(edge_attr, batch[src], dim=0, dim_size=num_graphs),
     ]
     return self.global_mlp(torch.cat(pooled, dim=1))
Example #22
0
    def forward(self, data):
        """Five-stage voxel-pooling graph network producing class log-probabilities.

        Each stage applies two ReLU-activated graph convolutions and records a
        per-graph readout: ``scatter_mean`` with the batch vector expanded to
        64 columns. Between stages the graph is coarsened by
        ``sparse_voxel_max_pool``, using ``self.pool_args`` for size/start and
        the first feature channel as weight. The five readouts are
        concatenated, reshaped to ``fc1``'s input width, passed through
        dropout and ``self.fc1``, and normalised with ``log_softmax``.

        NOTE(review): ``scatter_mean(batch, feats)`` passes the index first,
        which appears to target the legacy torch_scatter argument order --
        confirm before upgrading torch_scatter.
        """
        # (first conv, second conv, pool_args after this stage or None)
        stages = (
            (self.conv1, self.conv12, (32, 8)),
            (self.conv2, self.conv22, (16, 4)),
            (self.conv3, self.conv32, (8, 2)),
            (self.conv4, self.conv42, (4, 1)),
            (self.conv5, self.conv52, None),
        )

        readouts = []
        current = data
        for conv_a, conv_b, pool_spec in stages:
            feats = F.relu(conv_a(current.adj, current.input))
            feats = F.relu(conv_b(current.adj, feats))
            batch = Variable(
                current.batch.view(-1, 1).expand(current.batch.size(0), 64))
            readouts.append(scatter_mean(batch, feats))

            if pool_spec is not None:
                size, start = self.pool_args(*pool_spec)
                current, _ = sparse_voxel_max_pool(
                    current, size, start, transform, weight=feats[:, 0])

        x = torch.cat(readouts, dim=1)
        x = x.view(-1, self.fc1.weight.size(1))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc1(x), dim=1)
Example #23
0
 def global_model(self, x, edge_index, edge_attr, u, batch):
     """Global update from ``u`` and per-graph means of node and edge features.

     The result ``cat([u, node_mean, edge_mean])`` goes through
     ``self.global_msg``.
     """
     # x: [N, F_x], where N is the number of nodes.
     # edge_index: [2, E] with max entry N - 1.
     # edge_attr: [E, F_e]
     # u: [B, F_u]
     # batch: [N] with max entry B - 1.
     num_graphs = u.size(0)
     row, _ = edge_index
     # dim_size keeps B rows even when the last graph(s) have no edges;
     # otherwise the concat below fails.
     edge_mean = scatter_mean(edge_attr, batch[row], dim=0, dim_size=num_graphs)
     node_mean = scatter_mean(x, batch, dim=0, dim_size=num_graphs)
     out = torch.cat([u, node_mean, edge_mean], dim=1)
     out = self.global_msg(out)
     return out
Example #24
0
 def forward(self, x, edge_index, edge_attr, u, batch):
     """Global update from per-graph means of nodes and edges plus ``u``.

     The concatenation order is ``[x_agg, e_agg, u]``, and the result goes
     through ``self.phi_u``.
     """
     num_graphs = u.size(0)
     row, _ = edge_index
     # compute the batch index for all edges
     e_batch = batch[row]
     # aggregate all edges in the graph; dim_size keeps one row per graph in
     # ``u`` even when the last graph(s) have no edges
     e_agg = scatter_mean(edge_attr, e_batch, dim=0, dim_size=num_graphs)
     # aggregate all nodes in the graph
     x_agg = scatter_mean(x, batch, dim=0, dim_size=num_graphs)
     out = torch.cat([x_agg, e_agg, u], 1)
     return self.phi_u(out)
Example #25
0
 def global_model(x, edge_attr, u, v_indices, e_indices):
     """Global update; ``self`` is not a parameter, so it is captured from an enclosing scope.

     When ``self.independent`` is set, ``u`` is transformed on its own.
     Otherwise ``u`` is first concatenated with the per-graph means of ``x``
     (grouped by ``v_indices``) and ``edge_attr`` (grouped by ``e_indices``).
     """
     if self.independent:
         return self.global_mlp(u)
     num_graphs = u.size(0)
     # dim_size keeps one row per graph in ``u`` even when the last graph(s)
     # have no nodes or edges; otherwise the concat fails.
     out = torch.cat(
         [
             u,
             scatter_mean(x, v_indices, dim=0, dim_size=num_graphs),
             scatter_mean(edge_attr, e_indices, dim=0, dim_size=num_graphs),
         ],
         dim=1,
     )
     return self.global_mlp(out)
    def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
        """Normalise ``x`` separately within every graph of ``batch``.

        For each graph id, the per-feature mean scaled by ``self.mean_scale``
        is subtracted. The result is divided by ``sqrt(mean(centered**2) +
        self.eps)`` computed per graph, then multiplied by ``self.weight`` and
        shifted by ``self.bias``. With ``batch=None`` all rows form one graph.
        """
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)
        num_graphs = int(batch.max()) + 1

        def _per_graph_mean(t):
            # per-graph mean, broadcast back to one row per node
            means = scatter_mean(t, batch, dim=0, dim_size=num_graphs)
            return means.index_select(0, batch)

        centered = x - _per_graph_mean(x) * self.mean_scale
        std = (_per_graph_mean(centered.pow(2)) + self.eps).sqrt()
        return self.weight * centered / std + self.bias
Example #27
0
    def forward(self, data):
        """Given a :obj:`data` batch, computes the forward pass.

        Args:
            data (torch_geometric.data.Data): The input data, holding subject
                :obj:`sub`, relation :obj:`rel` and object :obj:`obj`
                information with shape :obj:`[batch_size]`.
                In addition, :obj:`data` needs to hold history information for
                subjects, given by a vector of node indices :obj:`h_sub` and
                their relative timestamps :obj:`h_sub_t` and batch assignments
                :obj:`h_sub_batch`.
                The same information must be given for objects (:obj:`h_obj`,
                :obj:`h_obj_t`, :obj:`h_obj_batch`).

        Returns:
            ``(log_prob_obj, log_prob_sub)``: log-softmax scores over dim 1,
            from ``self.sub_lin`` and ``self.obj_lin`` respectively.
        """

        assert 'h_sub_batch' in data and 'h_obj_batch' in data
        batch_size, seq_len = data.sub.size(0), self.seq_len

        def _history(node_idx, rel_t, hist_batch):
            # Flatten (example, timestep) into a single bucket id and average
            # the entity embeddings that fall into each bucket.
            bucket = rel_t + hist_batch * seq_len
            pooled = scatter_mean(self.ent[node_idx], bucket, dim=0,
                                  dim_size=batch_size * seq_len)
            return pooled.view(batch_size, seq_len, -1)

        h_sub = _history(data.h_sub, data.h_sub_t, data.h_sub_batch)
        h_obj = _history(data.h_obj, data.h_obj_t, data.h_obj_batch)

        ent_sub = self.ent[data.sub]
        ent_obj = self.ent[data.obj]
        emb_rel = self.rel[data.rel]

        def _tile(t):
            # repeat a [batch, dim] embedding over the sequence axis
            return t.unsqueeze(1).repeat(1, seq_len, 1)

        rel_seq = _tile(emb_rel)
        _, h_sub = self.sub_gru(torch.cat([_tile(ent_sub), h_sub, rel_seq], dim=-1))
        _, h_obj = self.obj_gru(torch.cat([_tile(ent_obj), h_obj, rel_seq], dim=-1))
        h_sub, h_obj = h_sub.squeeze(0), h_obj.squeeze(0)

        h_sub = torch.cat([ent_sub, h_sub, emb_rel], dim=-1)
        h_obj = torch.cat([ent_obj, h_obj, emb_rel], dim=-1)

        h_sub = F.dropout(h_sub, p=self.dropout, training=self.training)
        h_obj = F.dropout(h_obj, p=self.dropout, training=self.training)

        log_prob_obj = F.log_softmax(self.sub_lin(h_sub), dim=1)
        log_prob_sub = F.log_softmax(self.obj_lin(h_obj), dim=1)

        return log_prob_obj, log_prob_sub
Example #28
0
    def forward(self, batch_data):
        """Compute graph-level outputs for a batch of graphs.

        The initial global feature ``u`` of each graph is the mean of its raw
        edge features (each edge is assigned to the graph of its ``row``
        endpoint). Node, edge and global features are embedded by
        ``self.mlp_node`` / ``self.mlp_edge`` / ``self.mlp_global``,
        normalized, and refined by ``self.depth`` message-passing blocks
        ``self.gn``. When ``self.graph_level_feature`` is set,
        ``batch_data.graph_attr`` is concatenated to ``u`` before the readout
        ``self.mlp1``.

        Args:
            batch_data: Batched graph object exposing ``x``, ``edge_index``,
                ``edge_attr`` and ``batch`` (plus ``graph_attr`` when
                ``self.graph_level_feature`` is set).

        Returns:
            The output of ``self.mlp1`` applied to the per-graph features.
        """
        x, edge_index, edge_attr = batch_data.x, batch_data.edge_index, batch_data.edge_attr
        batch = batch_data.batch
        row = edge_index[0]
        # Graph id of every edge, taken from its `row` endpoint.
        edge_batch = batch[row]
        # Number of graphs (assumes graph ids are 0..B-1). A single tensor
        # reduction replaces the builtin max(), which walks the tensor element
        # by element in Python and yields a 0-d tensor instead of an int.
        num_graphs = int(batch.max()) + 1

        # Initial global features: mean of the raw edge features of each graph.
        u = scatter_mean(edge_attr, edge_batch, dim=0, dim_size=num_graphs)

        if self.graph_level_feature:
            # Extra graph-level inputs (rdkit_2d_normalized_features); a
            # non-2-D tensor is reshaped into rows of num_global_features.
            aug_feat = batch_data.graph_attr
            if aug_feat.dim() != 2:
                aug_feat = aug_feat.reshape(-1, self.num_global_features)

        x = self.mlp_node(x)
        edge_attr = self.mlp_edge(edge_attr)
        u = self.mlp_global(u)

        # Index -1 of each norm list is applied before message passing and
        # indices 0..depth-1 after each block (presumably the lists hold
        # depth + 1 modules -- TODO confirm against __init__).
        x = self.norm_node[-1](x, batch)
        edge_attr = self.norm_edge[-1](edge_attr, edge_batch)
        x = self.bn_node[-1](x)
        edge_attr = self.bn_edge[-1](edge_attr)
        u = self.bn_global[-1](u)

        for i in range(self.depth):
            x, edge_attr, u = self.gn[i](x, edge_index, edge_attr, u, batch)

            x = self.norm_node[i](x, batch)
            edge_attr = self.norm_edge[i](edge_attr, edge_batch)
            x = self.bn_node[i](x)
            edge_attr = self.bn_edge[i](edge_attr)
            u = self.bn_global[i](u)

        if self.graph_level_feature:
            u = torch.cat([u, aug_feat], dim=1)
        return self.mlp1(u)
Example #29
0
    def message_step(self, x, start, end, e):
        """Perform one residual message-passing update of nodes and edges.

        Edge features ``e`` are aggregated onto their ``end`` node with the
        reduction(s) named by ``self.hparams["aggregation"]``. The result is
        concatenated to ``x`` and fed to ``self.node_network``. Edge features
        are then recomputed by ``self.edge_network`` from the updated
        features of both endpoints plus the previous edge features. Both
        updates add a residual connection to their input.

        Args:
            x: Node features, one row per node.
            start: Index of the first endpoint of each edge.
            end: Index of the second endpoint of each edge; messages are
                aggregated onto these nodes.
            e: Edge features, one row per edge.

        Returns:
            Tuple ``(x_out, e_out)`` of updated node and edge features.

        Raises:
            ValueError: If the configured aggregation is not one of
                ``"sum"``, ``"mean"``, ``"max"``, ``"sum_max"``,
                ``"mean_sum"`` or ``"mean_max"``.
        """
        num_nodes = x.shape[0]

        def _sum():
            return scatter_add(e, end, dim=0, dim_size=num_nodes)

        def _mean():
            return scatter_mean(e, end, dim=0, dim_size=num_nodes)

        def _max():
            # scatter_max returns (values, argmax); keep only the values.
            return scatter_max(e, end, dim=0, dim_size=num_nodes)[0]

        # Reductions to compute, concatenated in this order on the last axis.
        reducers = {
            "sum": (_sum,),
            "mean": (_mean,),
            "max": (_max,),
            "sum_max": (_max, _sum),
            "mean_sum": (_mean, _sum),
            "mean_max": (_max, _mean),
        }
        aggregation = self.hparams["aggregation"]
        if aggregation not in reducers:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                f"Unsupported aggregation {aggregation!r}; expected one of "
                f"{sorted(reducers)}"
            )
        parts = [fn() for fn in reducers[aggregation]]
        edge_messages = parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1)

        # Compute new node features (out-of-place residual: an in-place `+=`
        # on the network output breaks backward when the last layer saves its
        # output for the gradient, e.g. ReLU/Tanh/Sigmoid).
        node_inputs = torch.cat([x, edge_messages], dim=-1)
        x_out = self.node_network(node_inputs) + x

        # Compute new edge features from the updated endpoints.
        edge_inputs = torch.cat([x_out[start], x_out[end], e], dim=-1)
        e_out = self.edge_network(edge_inputs) + e

        return x_out, e_out
Example #30
0
 def forward(self, x, edge_index):
     """Standardize gathered edge features with per-node scalar statistics.

     The features ``x[col]`` are grouped by ``row``. For every row node, a
     scalar mean is obtained by averaging per feature and then across
     features. A scalar variance is obtained the same way from the squared
     deviations from that mean. Each gathered row ``x[col]`` is standardized
     with the statistics of its ``row`` node and finally scaled and shifted
     by ``self.gamma`` and ``self.beta``.

     Args:
         x: Node feature matrix, indexed by node along dim 0.
         edge_index: Pair ``(row, col)`` of index tensors.

     Returns:
         One normalized feature row per entry of ``row``/``col``.
     """
     row, col = edge_index
     num_nodes = x.size(0)
     # Gather once; reused for both moments and the output (previously
     # gathered three times).
     x_col = x[col]

     # Scalar mean per row node ([num_nodes, 1] assuming x is 2-D).
     mean = scatter_mean(x_col, row, dim=0, dim_size=num_nodes)
     mean = torch.mean(mean, dim=-1, keepdim=True)

     centered = x_col - mean[row]
     # Scalar variance per row node, computed around the scalar mean above.
     var = scatter_mean(centered**2, row, dim=0, dim_size=num_nodes)
     var = torch.mean(var, dim=-1, keepdim=True)

     out = centered / (var[row] + self.eps).sqrt()
     return self.gamma * out + self.beta