Exemplo n.º 1
0
    def forward(self, data):
        """Down-sample ``data`` and apply ``self.conv`` once per neighbourhood scale.

        When ``self._precompute_multi_scale`` is set, the sampled indices and the
        per-scale edge indices are read from ``data`` (attributes
        ``idx_{self._index}`` and ``edge_index_{self._index}_{scale}``);
        otherwise they are computed here with ``self.sampler`` and
        ``self.neighbour_finder``.

        Args:
            data: batch object exposing ``x``, ``pos`` and ``batch``.

        Returns:
            A new ``Batch`` whose ``x`` is the per-scale convolution outputs
            concatenated along the last dimension, with ``pos`` / ``batch``
            restricted to the sampled points. ``idx`` is set only when the
            sampling was computed here.

        Raises:
            ValueError: if precomputed multi-scale data is expected but a
                required ``idx_*`` / ``edge_index_*_*`` attribute is missing.
        """
        batch_obj = Batch()
        x, pos, batch = data.x, data.pos, data.batch
        if self._precompute_multi_scale:
            idx_key = "idx_{}".format(self._index)
            idx = getattr(data, idx_key, None)
            # A missing index would otherwise make ``pos[idx]`` equal to
            # ``pos[None]`` (an extra leading dim) and fail far from here.
            if idx is None:
                raise ValueError(
                    "Precomputed multi-scale data is missing '{}'".format(idx_key))
        else:
            idx = self.sampler(pos, batch)
            batch_obj.idx = idx

        # The sampled positions / batch ids do not depend on the scale: index once.
        pos_sampled = pos[idx]
        batch_sampled = batch[idx]

        ms_x = []
        for scale_idx in range(self.neighbour_finder.num_scales):
            if self._precompute_multi_scale:
                edge_key = "edge_index_{}_{}".format(self._index, scale_idx)
                edge_index = getattr(data, edge_key, None)
                if edge_index is None:
                    raise ValueError(
                        "Precomputed multi-scale data is missing '{}'".format(edge_key))
            else:
                row, col = self.neighbour_finder(
                    pos,
                    pos_sampled,
                    batch_x=batch,
                    batch_y=batch_sampled,
                    scale_idx=scale_idx,
                )
                edge_index = torch.stack([col, row], dim=0)

            ms_x.append(self.conv(x, (pos, pos_sampled), edge_index, batch))

        batch_obj.x = torch.cat(ms_x, -1)
        batch_obj.pos = pos_sampled
        batch_obj.batch = batch_sampled
        copy_from_to(data, batch_obj)
        return batch_obj
    def forward(self, data):
        """Sample points, gather neighbourhoods and apply ``self.conv``.

        One extra "shadow" row is appended to both ``x`` (filled with
        ``self.shadow_features_fill``) and ``pos`` (filled with
        ``self.shadow_points_fill_``). Any neighbour index equal to the number
        of real points therefore gathers these fill values. Presumably the
        neighbour finder uses that index as padding; TODO confirm.

        Args:
            data: batch object exposing ``x``, ``pos`` and ``batch``.

        Returns:
            A new ``Batch`` with the convolved features ``x`` and the ``pos`` /
            ``batch`` of the points selected by ``self.sampler``.
        """
        batch_obj = Batch()
        x, pos, batch = data.x, data.pos, data.batch
        idx_sampler = self.sampler(pos=pos, x=x, batch=batch)

        idx_neighbour, _ = self.neighbour_finder(pos,
                                                 pos,
                                                 batch_x=batch,
                                                 batch_y=batch)

        # Allocate each shadow row directly on its tensor's device and dtype.
        # Previously both rows were created on the CPU with the default dtype
        # and then moved, and the position row followed x's device rather than
        # pos's. That could make torch.cat fail or silently promote the dtype.
        shadow_x = torch.full((1, ) + tuple(x.shape[1:]),
                              self.shadow_features_fill,
                              dtype=x.dtype,
                              device=x.device)
        shadow_points = torch.full((1, ) + tuple(pos.shape[1:]),
                                   self.shadow_points_fill_,
                                   dtype=pos.dtype,
                                   device=pos.device)

        x = torch.cat([x, shadow_x], dim=0)
        pos = torch.cat([pos, shadow_points], dim=0)

        x_neighbour = x[idx_neighbour]
        # Centre each neighbourhood on its query point; pos[:-1] excludes the
        # shadow row (assumes idx_neighbour has one row per real point).
        pos_centered_neighbour = pos[idx_neighbour] - pos[:-1].unsqueeze(1)

        batch_obj.x = self.conv(x, pos, x_neighbour, pos_centered_neighbour,
                                idx_neighbour, idx_sampler)

        batch_obj.pos = pos[idx_sampler]
        batch_obj.batch = batch[idx_sampler]
        copy_from_to(data, batch_obj)
        return batch_obj
 def forward(self, data, **kwargs):
     """Apply ``self.nn`` to ``[x, pos]`` and pool the result over ``batch``.

     Args:
         data: batch object exposing ``x``, ``pos`` and ``batch``.
         **kwargs: accepted for interface compatibility; unused.

     Returns:
         A new ``Batch`` holding the pooled features as ``x``. ``pos`` is a
         zero tensor with 3 columns (one row per pooled row), and ``batch``
         is ``arange(number of pooled rows)``.
     """
     features, positions, batch_ids = data.x, data.pos, data.batch
     point_features = self.nn(torch.cat([features, positions], dim=1))
     pooled = self.pool(point_features, batch_ids)
     num_pooled = pooled.size(0)

     out = Batch()
     out.x = pooled
     out.pos = positions.new_zeros((num_pooled, 3))
     out.batch = torch.arange(num_pooled, device=batch_ids.device)
     copy_from_to(data, out)
     return out
Exemplo n.º 4
0
    def forward(self, data, **kwargs):
        """Sample points, find their neighbours and convolve onto the samples.

        Args:
            data: batch object exposing ``x``, ``pos`` and ``batch``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            A new ``Batch`` containing:
            - the convolved features ``x``;
            - the sampled ``pos`` / ``batch``;
            - the ``idx`` and ``edge_index`` used to build it.
        """
        batch_obj = Batch()
        x, pos, batch = data.x, data.pos, data.batch
        idx = self.sampler(pos, batch)
        pos_sampled = pos[idx]
        batch_sampled = batch[idx]
        row, col = self.neighbour_finder(pos, pos_sampled, batch_x=batch, batch_y=batch_sampled)
        edge_index = torch.stack([col, row], dim=0)
        batch_obj.idx = idx
        batch_obj.edge_index = edge_index

        # Pass positions in the same (all points, sampled points) order used
        # for the neighbour search above. The previous (pos[idx], pos) order
        # was reversed relative to edge_index. NOTE(review): assumes self.conv
        # follows the PyG (source, target) tuple convention; confirm.
        batch_obj.x = self.conv(x, (pos, pos_sampled), edge_index, batch)

        batch_obj.pos = pos_sampled
        batch_obj.batch = batch_sampled
        copy_from_to(data, batch_obj)
        return batch_obj
Exemplo n.º 5
0
    def forward(self, data, **kwargs):
        """Down-sample ``data`` and apply ``self.conv`` once per neighbourhood scale.

        Args:
            data: batch object exposing ``x``, ``pos`` and ``batch``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            A new ``Batch`` whose ``x`` is the per-scale convolution outputs
            concatenated along the last dimension, with ``pos`` / ``batch``
            restricted to the sampled points and ``idx`` set to the sampling
            indices.
        """
        batch_obj = Batch()
        x, pos, batch = data.x, data.pos, data.batch
        idx = self.sampler(pos, batch)
        batch_obj.idx = idx

        # Loop-invariant: the sampled positions / batch ids are the same for
        # every scale, so index them once instead of on every iteration.
        pos_sampled = pos[idx]
        batch_sampled = batch[idx]

        ms_x = []
        for scale_idx in range(self.neighbour_finder.num_scales):
            row, col = self.neighbour_finder(
                pos,
                pos_sampled,
                batch_x=batch,
                batch_y=batch_sampled,
                scale_idx=scale_idx,
            )
            edge_index = torch.stack([col, row], dim=0)
            ms_x.append(self.conv(x, (pos, pos_sampled), edge_index, batch))

        batch_obj.x = torch.cat(ms_x, -1)
        batch_obj.pos = pos_sampled
        batch_obj.batch = batch_sampled
        copy_from_to(data, batch_obj)
        return batch_obj
Exemplo n.º 6
0
 def forward(self, data, **kwargs):
     """Residual block: down-project, convolve, up-project, add the shortcut.

     Args:
         data: batch object exposing ``x`` (and whatever ``self.convs`` needs).
             Its ``x`` attribute is overwritten with the down-projected
             features before being passed to ``self.convs``.
         **kwargs: accepted for interface compatibility; unused.

     Returns:
         A new ``Batch`` with ``x = shortcut + upsampled conv features``. Its
         ``pos`` / ``batch`` are taken from the output of ``self.convs``.
     """
     batch_obj = Batch()
     x = data.x  # (N, indim)
     shortcut = x  # (N, indim)
     x = self.features_downsample_nn(x)  # (N, outdim//4)
     # Feed the down-projected features to the convolutions. Previously this
     # result was computed and then discarded, so self.convs received the raw
     # input features.
     data.x = x
     # if this is an identity resnet block, idx will be None
     data = self.convs(data)  # (N', convdim)
     x = data.x
     idx = data.idx
     x = self.features_upsample_nn(x)  # (N', outdim)
     if idx is not None:
         shortcut = shortcut[idx]  # (N', indim)
     shortcut = self.shortcut_feature_resize_nn(shortcut)  # (N', outdim)
     x = shortcut + x
     batch_obj.x = x
     batch_obj.pos = data.pos
     batch_obj.batch = data.batch
     copy_from_to(data, batch_obj)
     return batch_obj
Exemplo n.º 7
0
    def forward(self, data):
        """Down-sample ``data`` and apply ``self.conv`` at a single scale.

        When ``self._precompute_multi_scale`` is set, the sampled indices and
        edge indices are read from ``data`` (attributes ``index_{self._index}``
        and ``edge_index_{self._index}``). Otherwise they are computed here
        and also stored on the returned batch.

        Args:
            data: batch object exposing ``x``, ``pos`` and ``batch``.

        Returns:
            A new ``Batch`` with the convolved features ``x`` and the sampled
            ``pos`` / ``batch``.

        Raises:
            ValueError: if precomputed data is expected but a required
                attribute is missing.
        """
        batch_obj = Batch()
        x, pos, batch = data.x, data.pos, data.batch
        if self._precompute_multi_scale:
            idx_key = "index_{}".format(self._index)
            edge_key = "edge_index_{}".format(self._index)
            idx = getattr(data, idx_key, None)
            edge_index = getattr(data, edge_key, None)
            # Fail early with a clear message. A missing index would otherwise
            # make ``pos[idx]`` equal to ``pos[None]``, and the conv would get
            # ``edge_index=None``.
            missing = [
                key
                for key, value in ((idx_key, idx), (edge_key, edge_index))
                if value is None
            ]
            if missing:
                raise ValueError(
                    "Precomputed data is missing: {}".format(", ".join(missing)))
        else:
            idx = self.sampler(pos, batch)
            row, col = self.neighbour_finder(pos,
                                             pos[idx],
                                             batch_x=batch,
                                             batch_y=batch[idx])
            edge_index = torch.stack([col, row], dim=0)
            batch_obj.idx = idx
            batch_obj.edge_index = edge_index

        batch_obj.x = self.conv(x, (pos, pos[idx]), edge_index, batch)

        batch_obj.pos = pos[idx]
        batch_obj.batch = batch[idx]
        copy_from_to(data, batch_obj)
        return batch_obj