Example #1
0
 def forward(self,
             inputs: ElementsToSummaryRepresentationInput) -> torch.Tensor:
     """Reduce element embeddings into one row per sample.

     ``inputs.element_embeddings`` is scattered along dimension 0 using
     ``inputs.element_to_sample_map`` as the group index. The output is
     sized to ``inputs.num_samples`` along that dimension and the
     reduction is the summarization type stored on this module.
     """
     embeddings = inputs.element_embeddings
     sample_ids = inputs.element_to_sample_map
     num_samples = inputs.num_samples
     return scatter(src=embeddings,
                    index=sample_ids,
                    dim=0,
                    dim_size=num_samples,
                    reduce=self.__summarization_type)
Example #2
0
def find_rate(edge_index):
    """Return the ratio between the smallest and largest scattered degree.

    One unit is summed per edge into the slot given by ``edge_index[1]``.
    If the smallest resulting count is zero, ``1`` is used as the numerator
    instead, so the result is ``1 / max`` rather than ``0``.
    """
    num_edges = edge_index.size(1)
    ones = torch.ones(num_edges).to(edge_index.device)
    degrees = scatter(ones, edge_index[1], reduce='sum')
    smallest = degrees.min()
    numerator = 1 if smallest == 0 else smallest
    return numerator / degrees.max()
Example #3
0
    def aggregate(self, inputs: Tensor, index: Tensor) -> Tensor:
        """Reduce ``inputs`` along dimension 0, grouped by ``index``.

        The output holds ``index.max() + 1`` groups and the reduction is the
        one named by ``self.aggr``.
        """
        # Step 4: Sum by vertex or by edge
        num_groups = 1 + int(index.max())
        return scatter(inputs, index, dim=0, dim_size=num_groups,
                       reduce=self.aggr)
Example #4
0
    def forward(self, z, pos, batch=None):
        """Run the interaction stack and pool node outputs per batch entry.

        Args:
            z: 1-D ``torch.long`` tensor (asserted below); it indexes the
                embedding and, in dipole mode, ``self.atomic_mass``.
            pos: Positions used to build the radius graph and edge lengths.
            batch: Optional group index per node; all zeros when omitted.

        Returns:
            The node outputs reduced with ``self.readout`` over ``batch``,
            optionally turned into a norm (dipole mode) and scaled.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        if batch is None:
            batch = torch.zeros_like(z)

        h = self.embedding(z)
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        src, dst = edge_index
        distances = (pos[src] - pos[dst]).norm(dim=-1)
        e = self.distance_expansion(distances)
        # Snapshot of the initial embedding, only consumed by the
        # hypernet-style state transition.
        h_init = h.clone()

        for layer in range(self.num_interactions):
            e = self.edge_updates[layer](h, edge_index, e)
            msg = self.msg_passes[layer](h, edge_index, e)
            if self.hypernet_update:
                delta = self.state_transitions[layer](h_init, h, msg)
            else:
                delta = self.state_transitions[layer](msg)
            h = h + delta

        h = self.fc2(self.act(self.fc1(h)))

        if self.dipole:
            # Get center of mass.
            mass = self.atomic_mass[z].view(-1, 1)
            weighted_pos = scatter(mass * pos, batch, dim=0)
            total_mass = scatter(mass, batch, dim=0)
            center = weighted_pos / total_mass
            h = h * (pos - center[batch])
        else:
            if self.mean is not None and self.std is not None:
                h = h * self.std + self.mean
            if self.atomref is not None:
                h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)
        if self.scale is not None:
            out = self.scale * out
        return out
Example #5
0
def correctness(dataset):
    """Check that scatter, segment_coo and segment_csr agree on one matrix.

    Args:
        dataset: A ``(group, name)`` pair; only ``name`` is used, to load
            ``f'{name}.mat'``.

    For every entry of the module-level ``sizes`` a random input is reduced
    with 'add', 'mean', 'min' and 'max' through all three operators and the
    results are compared with ``torch.allclose`` (``atol=1e-4``).

    Raises:
        AssertionError: If two operators disagree.
        RuntimeError: Any runtime error other than an out-of-memory error,
            which is swallowed after emptying the CUDA cache.
    """
    _, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            for reduce_op in ('add', 'mean', 'min', 'max'):
                out1 = scatter(x, row, dim=0, dim_size=dim_size,
                               reduce=reduce_op)
                # Pass dim_size for every reduction so the COO result has
                # the same number of rows as the other two even when the
                # trailing rows of the matrix are empty.
                out2 = segment_coo(x, row, dim_size=dim_size,
                                   reduce=reduce_op)
                out3 = segment_csr(x, rowptr, reduce=reduce_op)

                assert torch.allclose(out1, out2, atol=1e-4)
                assert torch.allclose(out1, out3, atol=1e-4)

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # Re-raise the original exception untouched so its concrete
                # type and traceback are preserved.
                raise
            torch.cuda.empty_cache()
Example #6
0
        def get_energy(batch, atomref):
            """Return a copy of the target energy, minus reference energies.

            Raises ``MissingEnergyException`` when ``batch.y`` is ``None``.
            Without ``atomref`` the target is returned as a clone; otherwise
            the ``atomref`` rows selected by ``batch.z`` are scattered over
            ``batch.batch`` and subtracted from the squeezed target.
            """
            target = batch.y
            if target is None:
                raise MissingEnergyException()
            if atomref is None:
                return target.clone()

            # remove atomref energies from the target energy
            reference = scatter(atomref[batch.z], batch.batch, dim=0)
            residual = target.squeeze() - reference.squeeze()
            return residual.clone()
Example #7
0
def extract_node_feature(data, reduce='add'):
    """Set ``data.x`` by reducing edge attributes grouped by ``edge_index[0]``.

    Args:
        data: Object exposing ``edge_attr``, ``edge_index`` and
            ``num_nodes``; it is modified in place and returned.
        reduce: One of ``'mean'``, ``'max'`` or ``'add'``.

    Raises:
        Exception: If ``reduce`` is not one of the supported names.
    """
    if reduce not in ('mean', 'max', 'add'):
        raise Exception('Unknown Aggregation Type')

    data.x = scatter(data.edge_attr, data.edge_index[0], dim=0,
                     dim_size=data.num_nodes, reduce=reduce)
    return data
Example #8
0
    def forward_pyg(self, x, adj):
        """Propagate ``x`` over the edges of ``adj`` with similarity weights.

        The coalesced indices of ``adj`` give ``row`` and ``col``. Each edge
        gets ``self.beta * cosine_similarity(x[row], x[col])``, normalized
        with ``softmax`` grouped by ``row``. The weighted ``x[row]`` values
        are then summed into the slots given by ``col``.
        """
        row, col = adj.coalesce().indices()
        x_row = x[row]
        x_col = x[col]

        scores = self.beta * cosine_similarity(x_row, x_col)
        weights = softmax(scores, row)
        weighted = x_row * weights.view(-1, 1)
        return scatter(weighted, col, dim=0, reduce="add")
Example #9
0
    def aggregate(self, inputs, index, dim_size):  # pragma: no cover
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.

        The call is forwarded to ``scatter``: ``inputs`` is reduced along
        :obj:`self.node_dim` according to ``index`` into ``dim_size`` slots,
        using the reduction stored in :obj:`self.aggr` (the original
        documentation lists "add", "mean" and "max" as supported values).
        """
        return scatter(
            inputs,
            index,
            dim=self.node_dim,
            dim_size=dim_size,
            reduce=self.aggr,
        )
Example #10
0
 def _mix_into_atoms(self, x: torch.Tensor, x_paths: PathActivationCollection, atom_to_path_map) -> torch.Tensor:
     """Add path activations back onto ``x``, one contribution per key.

     For every key in ``self.path_keys`` the path activations selected by
     the second map entry are mean-reduced into ``x.size(0)`` rows indexed
     by the first map entry. The result goes through the key's linear layer
     and a ReLU and is added to the running result, which starts as ``x``.
     """
     num_rows = x.size(0)
     mixed = x
     for key in self.path_keys:
         # Gather data for that path length
         row, col = atom_to_path_map[key]
         linear = self.path_to_atom_linear[key]
         gathered = x_paths[key][col]
         pooled = scatter(gathered, row, dim=0, dim_size=num_rows,
                          reduce='mean')
         mixed = mixed + F.relu(linear(pooled))
     return mixed
Example #11
0
 def aggregate(self,
               inputs: Tensor,
               index: Tensor,
               dim_size: Optional[int] = None) -> Tensor:
     """Sum ``inputs`` grouped by ``index`` along ``self.node_dim``.

     ``dim_size`` is forwarded to ``scatter`` as the output size along that
     dimension.
     """
     return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                    reduce="sum")
Example #12
0
 def forward(self, x, rbf, i, num_nodes=None):
     """Gate ``x`` with the radial basis, pool by ``i`` and apply the head.

     ``x`` is multiplied by ``self.lin_rbf(rbf)`` and scattered along
     dimension 0 with index ``i`` into ``num_nodes`` slots. The result is
     projected by ``self.xupproj``, passed through each layer of
     ``self.lins`` followed by ``self.act``, and finally ``self.final_lin``.
     """
     gated = self.lin_rbf(rbf) * x
     # TensorFlow reference noted by the original author:
     # x = tf.math.unsorted_segment_sum(x, idnb_i, n_atoms)
     pooled = scatter(gated, i, dim=0, dim_size=num_nodes)
     hidden = self.xupproj(pooled)
     # Hidden layers, each followed by the activation.
     for layer in self.lins:
         hidden = self.act(layer(hidden))
     # Final dense layer.
     return self.final_lin(hidden)
Example #13
0
    def forward(self, data):
        """Pool ``data`` over a voxel grid and return the modified object.

        Voxels of size ``self.pool_rad`` are laid over the positions, with
        the grid padded by half a voxel below the minimum and above the
        maximum coordinate. Features are reduced per voxel with
        ``self.aggr`` and positions with 'mean'. ``data.batch`` is
        re-indexed with the permutation from ``consecutive_cluster``, and
        ``edge_attr`` / ``edge_index`` are cleared.
        """
        half_voxel = self.pool_rad * 0.5
        grid_start = data.pos.min(dim=0)[0] - half_voxel
        grid_end = data.pos.max(dim=0)[0] + half_voxel
        voxel_ids = nn_geometric.voxel_grid(data.pos,
                                            data.batch,
                                            self.pool_rad,
                                            start=grid_start,
                                            end=grid_end)

        voxel_ids, perm = consecutive_cluster(voxel_ids)

        data.x = scatter(data.x, voxel_ids, dim=0, reduce=self.aggr)
        data.pos = scatter(data.pos, voxel_ids, dim=0, reduce='mean')
        data.batch = data.batch[perm]

        # Edge information is dropped after pooling.
        data.edge_attr = None
        data.edge_index = None

        return data
    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Concatenate the mean and max aggregation of ``inputs``.

        With ``ptr`` given, both reductions use ``segment_csr`` after
        ``ptr`` has been unsqueezed ``self.node_dim`` times at the front.
        Otherwise ``scatter`` is used along ``self.node_dim`` with
        ``index`` and ``dim_size``. The two results are joined along the
        last dimension.
        """
        if ptr is None:
            shared = dict(dim=self.node_dim, dim_size=dim_size)
            mean_part = scatter(inputs, index, reduce='mean', **shared)
            max_part = scatter(inputs, index, reduce='max', **shared)
        else:
            for _ in range(self.node_dim):
                ptr = ptr.unsqueeze(0)
            mean_part = segment_csr(inputs, ptr, reduce='mean')
            max_part = segment_csr(inputs, ptr, reduce='max')

        return torch.cat([mean_part, max_part], dim=-1)
Example #15
0
    def forward(self,
                x: Tensor,
                batch: Optional[Tensor] = None,
                dim_size: Optional[int] = None) -> Tensor:
        """Subtract the mean over dimension 0 from ``x``.

        Without ``batch`` the mean is taken over all rows. With ``batch``
        the mean is computed per group via ``scatter`` (sized by
        ``dim_size``) and each row has its own group's mean removed.
        """
        if batch is not None:
            group_mean = scatter(x, batch, dim=0, dim_size=dim_size,
                                 reduce='mean')
            return x - group_mean[batch]

        return x - x.mean(dim=0, keepdim=True)
Example #16
0
    def forward(self, z, pos, batch=None):
        """Run the interaction blocks and pool node outputs per batch entry.

        Args:
            z: 1-D ``torch.long`` tensor (asserted below); it indexes the
                embedding and, in dipole mode, ``self.atomic_mass``.
            pos: Positions used to build the radius graph and edge lengths.
            batch: Optional group index per node; all zeros when omitted.

        Returns:
            The node outputs reduced with ``self.readout`` over ``batch``,
            optionally turned into a norm (dipole mode) and scaled.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        if batch is None:
            batch = torch.zeros_like(z)

        h = self.embedding(z)

        edge_index = radius_graph(pos,
                                  r=self.cutoff,
                                  batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)
        src, dst = edge_index
        edge_weight = (pos[src] - pos[dst]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        # Residual update from every interaction block.
        for block in self.interactions:
            h = h + block(h, edge_index, edge_weight, edge_attr)

        h = self.lin2(self.act(self.lin1(h)))

        if self.dipole:
            # Get center of mass.
            mass = self.atomic_mass[z].view(-1, 1)
            weighted_pos = scatter(mass * pos, batch, dim=0)
            total_mass = scatter(mass, batch, dim=0)
            center = weighted_pos / total_mass
            h = h * (pos - center[batch])
        else:
            if self.mean is not None and self.std is not None:
                h = h * self.std + self.mean
            if self.atomref is not None:
                h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)
        if self.scale is not None:
            out = self.scale * out
        return out
Example #17
0
def softmax(
    src: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    dim: int = 0,
) -> Tensor:
    r"""Computes a sparsely evaluated softmax.
    The values of :attr:`src` are grouped along dimension :attr:`dim`,
    either by :attr:`index` or by the CSR pointer :attr:`ptr`, and the
    softmax is computed separately inside every group. The per-group
    maximum is subtracted before exponentiation.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor, optional): The indices of elements for applying the
            softmax. (default: :obj:`None`)
        ptr (LongTensor, optional): If given, computes the softmax based on
            sorted inputs in CSR representation. Takes precedence over
            :attr:`index`. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dim (int, optional): The dimension in which to normalize.
            (default: :obj:`0`)

    Raises:
        NotImplementedError: If neither :attr:`index` nor :attr:`ptr` is
            given.

    :rtype: :class:`Tensor`
    """
    if ptr is None and index is None:
        raise NotImplementedError

    if ptr is not None:
        if dim < 0:
            dim = dim + src.dim()
        # Reshape ptr so that it lines up with dimension `dim` of src.
        ptr = ptr.view([1] * dim + [-1])
        group_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
        exp = (src - group_max).exp()
        denom = gather_csr(segment_csr(exp, ptr, reduce='sum'), ptr)
    else:
        size = maybe_num_nodes(index, num_nodes)
        group_max = scatter(src, index, dim, dim_size=size, reduce='max')
        group_max = group_max.index_select(dim, index)
        exp = (src - group_max).exp()
        denom = scatter(exp, index, dim, dim_size=size, reduce='sum')
        denom = denom.index_select(dim, index)

    return exp / (denom + 1e-16)
Example #18
0
 def forward(self, x_h, x_g, edge_index, edge_attr, u, batch_g):
     """Update ``x_g`` from messages sent along ``edge_index``.

     For every edge, ``x_h`` at the first index row is concatenated with
     ``edge_attr`` and passed through ``self.node_mlp_1``. The messages are
     summed into ``x_g.size(0)`` slots by the second index row. The sums are
     concatenated with ``x_g`` and ``u[batch_g]`` and passed through
     ``self.node_mlp_2``.

     Returns:
         The output of ``self.node_mlp_2``.
     """
     src, tgt = edge_index
     messages = torch.cat([x_h[src], edge_attr], dim=1)
     messages = self.node_mlp_1(messages)
     # Sum of incoming messages per receiving row. No auxiliary counter
     # tensor is needed here, which also keeps this usable without CUDA.
     pooled = scatter(messages, tgt, dim=0, dim_size=x_g.size(0),
                      reduce='sum')
     out = torch.cat([x_g, pooled, u[batch_g]], dim=1)
     return self.node_mlp_2(out)
Example #19
0
def test_scatter(reduce):
    """Check ``scatter`` against ``torch_scatter.scatter`` for ``reduce``.

    The function under test runs inside torch_geometric's
    ``'scatter_reduce'`` experimental mode, on a fixed-seed random input
    reduced along dimension 1.
    """
    torch.manual_seed(12345)

    values = torch.randn(8, 100, 32)
    groups = torch.randint(0, 10, (100, ), dtype=torch.long)

    expected = torch_scatter.scatter(values, groups, dim=1, reduce=reduce)
    with torch_geometric.experimental_mode('scatter_reduce'):
        actual = scatter(values, groups, dim=1, reduce=reduce)

    assert torch.allclose(actual, expected, atol=1e-6)
Example #20
0
    def forward(self, z, pos, batch=None):
        """Run the convolution stack and pool the scalar output per group.

        Args:
            z: 1-D ``torch.long`` tensor (asserted below) fed to the
                embedding and to ``self.atomref`` when present.
            pos: Positions used for the radius graph and edge vectors.
            batch: Optional group index per node; all zeros when omitted.

        Returns:
            Node outputs reduced with ``self.readout`` over ``batch`` and
            multiplied by ``self.scale`` when it is set.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        if batch is None:
            batch = torch.zeros_like(z)

        h = self.embedding(z)

        edge_index = radius_graph(pos,
                                  r=self.cutoff,
                                  batch=batch,
                                  max_num_neighbors=1000)
        src, dst = edge_index
        edge_vec = pos[src] - pos[dst]
        edge_sh = o3.spherical_harmonics(self.Rs_sh, edge_vec, 'component')
        edge_sh = edge_sh / self.num_neighbors**0.5
        edge_len = edge_vec.norm(dim=1)
        edge_weight = self.radial(edge_len)
        # (cos(pi * d / cutoff) + 1) / 2 evaluated per edge.
        edge_c = (pi * edge_len / self.cutoff).cos().add(1).div(2)

        for conv, act, shortcut in self.layers[:-1]:
            with torch.autograd.profiler.record_function("Layer"):
                if shortcut:
                    skip = shortcut(h)

                # convolution followed by the gate non linearity
                h = act(conv(h, edge_index, edge_weight, edge_c, edge_sh))

                if shortcut:
                    mask = shortcut.output_mask
                    h = 0.5**0.5 * skip + (1 + (0.5**0.5 - 1) * mask) * h

        with torch.autograd.profiler.record_function("Layer"):
            h = self.layers[-1](h, edge_index, edge_weight, edge_c, edge_sh)

        total = 0
        for column, (mul, ell, parity) in enumerate(self.Rs_out):
            assert mul == 1 and ell == 0
            if parity == 1:
                total += h[:, column]
            if parity == -1:
                total += h[:, column].pow(2).mul(0.5)  # odd^2 = even
        h = total.view(-1, 1)

        if self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if self.atomref is not None:
            h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.scale is not None:
            out = self.scale * out

        return out
Example #21
0
    def forward(self, data):
        """Compute the energy of ``data`` and, if configured, the forces.

        With ``self.otf_graph`` the edge index, cell offsets and neighbor
        data are rebuilt with ``radius_graph_pbc`` and stored on ``data``.
        With ``self.use_pbc`` the energy comes from the interaction stack
        applied to the distances returned by ``get_pbc_distances``;
        otherwise the parent class ``forward`` is used.

        Returns:
            ``energy``, or ``(energy, forces)`` when ``self.regress_forces``
            is set, where forces are the negative gradient of the energy
            with respect to the positions.
        """
        z = data.atomic_numbers.long()
        pos = data.pos
        if self.regress_forces:
            pos = pos.requires_grad_(True)
        batch = data.batch

        if self.otf_graph:
            (data.edge_index, data.cell_offsets,
             data.neighbors) = radius_graph_pbc(data, self.cutoff, 50,
                                                data.pos.device)

        # TODO return distance computation in radius_graph_pbc to remove need
        # for get_pbc_distances call
        if not self.use_pbc:
            energy = super(SchNetWrap, self).forward(z, pos, batch)
        else:
            assert z.dim() == 1 and z.dtype == torch.long

            pbc = get_pbc_distances(
                pos,
                data.edge_index,
                data.cell,
                data.cell_offsets,
                data.neighbors,
            )

            edge_index = pbc["edge_index"]
            edge_weight = pbc["distances"]
            edge_attr = self.distance_expansion(edge_weight)

            h = self.embedding(z)
            for block in self.interactions:
                h = h + block(h, edge_index, edge_weight, edge_attr)

            h = self.lin2(self.act(self.lin1(h)))

            if batch is None:
                batch = torch.zeros_like(z)
            energy = scatter(h, batch, dim=0, reduce=self.readout)

        if not self.regress_forces:
            return energy

        energy_grad = torch.autograd.grad(
            energy,
            pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0]
        forces = -1 * energy_grad
        return energy, forces
Example #22
0
    def aggregate(self, inputs, index, dim_size=None):
        """Average ``inputs`` grouped by ``index`` along ``self.node_dim``.

        ``dim_size`` is forwarded to ``torch_scatter.scatter`` as the output
        size along that dimension.
        """
        return torch_scatter.scatter(inputs,
                                     index=index,
                                     dim=self.node_dim,  # node axis
                                     reduce='mean',
                                     dim_size=dim_size)
Example #23
0
 def forward(self, x_h, x_g, edge_index, edge_attr, u, batch_h):
     """Update ``x_h`` from per-row statistics of its edge messages.

     For every edge, ``x_g`` at the second index row is concatenated with
     ``edge_attr`` and passed through ``self.node_mlp_1``. Grouped by the
     first index row, the messages are summarized by their count, mean,
     standard deviation, skewness and kurtosis. These are concatenated
     with ``x_h`` and ``u[batch_h]`` and passed through ``self.node_mlp_2``.

     Returns:
         The output of ``self.node_mlp_2``.
     """
     src, tgt = edge_index
     out = torch.cat([x_g[tgt], edge_attr], dim=1)
     out = self.node_mlp_1(out)
     num_rows = x_h.size(0)

     # Counter tensor created on the device of the messages rather than a
     # hard-coded CUDA device, so CPU and multi-GPU runs work as well.
     ns = torch.ones(len(out), 1, dtype=torch.float, device=out.device)
     n = scatter(ns, src, dim=0, dim_size=num_rows, reduce='sum')  # num
     a = scatter(out, src, dim=0, dim_size=num_rows, reduce='mean')  # mu
     second_moment = scatter(out**2, src, dim=0, dim_size=num_rows,
                             reduce='mean')
     # The ReLU clamps small negative variances caused by rounding and
     # the 1e-6 keeps the square root away from zero.
     b = torch.sqrt(1e-6 + F.relu(second_moment - a**2))  # sigma
     centered = out - a[src]
     c = scatter(centered**3, src, dim=0, dim_size=num_rows,
                 reduce='mean') / b**3  # skewness
     d = scatter(centered**4, src, dim=0, dim_size=num_rows,
                 reduce='mean') / b**4  # kurtosis
     out = torch.cat([x_h, n, a, b, c, d, u[batch_h]], dim=1)
     out = self.node_mlp_2(out)
     return out
Example #24
0
    def forward(self, data) -> torch.Tensor:
        """Predict one value per batch entry from ``data``.

        Reads ``data['z']``, ``data['pos']`` and ``data['batch']``, builds a
        radius graph, runs the message passing network ``self.mp`` on
        constant node inputs and sums the per-node outputs over ``batch``.
        """
        atom_types = data['z']
        positions = data['pos']
        batch = data['batch']

        # The graph
        edge_src, edge_dst = radius_graph(positions,
                                          r=self.max_radius,
                                          batch=batch,
                                          max_num_neighbors=1000)

        # Edge attributes
        edge_vec = positions[edge_src] - positions[edge_dst]
        edge_sh = o3.spherical_harmonics(l=range(self.sh_lmax + 1),
                                         x=edge_vec,
                                         normalize=True,
                                         normalization='component')

        # Edge length embedding
        edge_length = edge_vec.norm(dim=1)
        radial_basis = soft_one_hot_linspace(
            edge_length,
            0.0,
            self.max_radius,
            self.num_basis,
            basis='smooth_finite',
            cutoff=True,
        )
        edge_scalars = radial_basis.mul(self.num_basis**0.5)

        # Every node starts from the same constant feature.
        node_input = positions.new_ones(positions.shape[0], 1)

        # Lookup table from the value in `z` to a class index in [0, 5);
        # entries of -1 mark values without a class.
        class_lookup = atom_types.new_tensor(
            [-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])
        class_index = class_lookup[atom_types]
        node_attr = torch.nn.functional.one_hot(class_index, 5).mul(5**0.5)

        node_outputs = self.mp(node_features=node_input,
                               node_attr=node_attr,
                               edge_src=edge_src,
                               edge_dst=edge_dst,
                               edge_attr=edge_sh,
                               edge_scalars=edge_scalars)

        first = node_outputs[:, 0]
        second_squared = node_outputs[:, 1].pow(2).mul(0.5)
        node_outputs = (first + second_squared).view(-1, 1)
        node_outputs = node_outputs.div(self.num_nodes**0.5)

        if self.atomref is not None:
            node_outputs = node_outputs + self.atomref[atom_types]
        # for target=7, MAE of 75eV

        return scatter(node_outputs, batch, dim=0)
Example #25
0
 def determine_step(dr):
     """Rescale ``dr`` so no group's longest step exceeds ``self.maxstep``.

     The row norms of ``dr`` are max-reduced per entry of
     ``self.atoms.batch``. Each row is multiplied by
     ``min(longest, maxstep) / (longest + 1e-7)`` for its group. ``dr`` is
     modified in place; the returned tensor is ``dr * self.damping``.
     """
     batch = self.atoms.batch
     step_norms = torch.norm(dr, dim=1)
     group_longest = scatter(step_norms, batch, reduce="max")
     longest = group_longest[batch]
     cap = longest.new_tensor(self.maxstep)
     capped = torch.min(longest, cap)
     scale = (longest + 1e-7).reciprocal() * capped
     dr *= scale.unsqueeze(1)
     return dr * self.damping
Example #26
0
def get_node(x, segment, mode='mean'):
    """Reduce the last-axis vectors of ``x`` per label in ``segment``.

    Args:
        x: 3-D array or tensor; the first two axes are flattened.
        segment: 2-D array or tensor of labels, flattened the same way.
            NumPy input is converted to ``torch.long``.
        mode: Reduction name forwarded to ``scatter``.

    Returns:
        The scattered result cast to ``torch.float32``.
    """
    assert x.ndim == 3 and segment.ndim == 2
    x = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    if isinstance(segment, np.ndarray):
        segment = torch.from_numpy(segment).to(torch.long)

    channels = x.shape[2]
    flat_values = x.reshape((-1, channels))
    flat_labels = segment.flatten()
    pooled = scatter(flat_values, flat_labels, dim=0, reduce=mode)
    return pooled.to(torch.float32)
Example #27
0
    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Aggregate ``inputs`` per ``index`` according to ``self.aggr``.

        * 'add', 'mean', 'max' or ``None``: handled by the parent class.
        * 'softmax_sg', 'softmax', 'softmax_sum': inputs are weighted by a
          scatter-softmax of ``inputs * self.t`` and summed. The weights
          are computed without gradient tracking unless ``self.learn_t``
          is set.
        * 'power', 'power_sum': inputs are clamped in place to
          ``[1e-7, 1e1]``, raised to ``self.p``, averaged, clamped again
          and raised to ``1 / self.p``.

        The '*_sum' variants additionally multiply the result by
        ``degree ** sigmoid(self.y)`` and store the sigmoid on
        ``self.sigmoid_y``.

        Raises:
            NotImplementedError: For any other value of ``self.aggr``.
        """
        mode = self.aggr

        def scale_by_degree(pooled):
            # Shared tail of the '*_sum' variants.
            self.sigmoid_y = torch.sigmoid(self.y)
            degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
            return torch.pow(degrees, self.sigmoid_y) * pooled

        if mode in ['add', 'mean', 'max', None]:
            return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size)

        if mode in ['softmax_sg', 'softmax', 'softmax_sum']:
            if self.learn_t:
                weights = scatter_softmax(inputs*self.t, index, dim=self.node_dim)
            else:
                with torch.no_grad():
                    weights = scatter_softmax(inputs*self.t, index, dim=self.node_dim)

            out = scatter(inputs*weights, index, dim=self.node_dim,
                          dim_size=dim_size, reduce='sum')
            if mode == 'softmax_sum':
                out = scale_by_degree(out)
            return out

        if mode in ['power', 'power_sum']:
            lower, upper = 1e-7, 1e1
            torch.clamp_(inputs, lower, upper)
            out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
                          dim_size=dim_size, reduce='mean')
            torch.clamp_(out, lower, upper)
            out = torch.pow(out, 1/self.p)
            if mode == 'power_sum':
                out = scale_by_degree(out)
            return out

        raise NotImplementedError('To be implemented')
Example #28
0
 def _compute_score(self, all_clusters, backbone_features, semantic_logits):
     """Score the clusters.

     Without an active scorer, the semantic logits of each cluster are
     averaged and the maximum of that average is the score (computed
     without gradient tracking). With an active scorer, the cluster
     features and coordinates are batched on the CPU, passed through
     ``self.Scorer`` (max-pooled per cluster unless the scorer is an
     encoder) and ``self.ScorerHead`` yields the score.
     """
     if not self._activate_scorer:
         # Use semantic certainty as cluster confidence
         with torch.no_grad():
             logits_parts = []
             id_parts = []
             for cluster_id, cluster in enumerate(all_clusters):
                 logits_parts.append(semantic_logits[cluster, :])
                 id_parts.append(cluster_id * torch.ones(cluster.shape[0]))
             cluster_ids = torch.cat(id_parts).long().to(self.device)
             mean_logits = scatter(torch.cat(logits_parts),
                                   cluster_ids,
                                   dim=0,
                                   reduce="mean")
             return torch.max(mean_logits, 1)[0]

     feature_parts = []
     coord_parts = []
     id_parts = []
     for cluster_id, cluster in enumerate(all_clusters):
         feature_parts.append(backbone_features[cluster])
         coord_parts.append(self.input.coords[cluster])
         id_parts.append(cluster_id * torch.ones(cluster.shape[0]))
     batch_cluster = Data(
         x=torch.cat(feature_parts).cpu(),
         coords=torch.cat(coord_parts).cpu(),
         batch=torch.cat(id_parts).cpu(),
     )
     scorer_out = self.Scorer(batch_cluster)
     if self._scorer_is_encoder:
         cluster_feats = scorer_out.x
     else:
         cluster_feats = scatter(scorer_out.x,
                                 scorer_out.batch.long().to(self.device),
                                 dim=0,
                                 reduce="max")
     return self.ScorerHead(cluster_feats).squeeze(-1)
Example #29
0
    def forward(self,
                ps,
                JsorShapes,
                ws,
                poses,
                batch,
                check_rotation=True,
                is_Rotation=False):
        """Transform points ``ps`` with per-point blended 4x4 transforms.

        Args:
            ps: Points, one row of three coordinates each.
            JsorShapes: Joints, or shape parameters when it holds exactly
                10 values per batch entry (converted with
                ``self.smpl.skeleton``).
            ws: Per-point weights over the 24 joints; split per batch entry
                using the counts derived from ``batch``.
            poses: Per batch entry either 24*3 values (converted with
                ``batch_rodrigues``), 24*9 values (3x3 matrices) or 24*16
                values (ready 4x4 transforms).
            batch: Batch entry index of every point.
            check_rotation: Accepted but not used by this method.
            is_Rotation: For 24*9 input, skip the Gram-Schmidt
                orthonormalization and use the matrices as given.

        Returns:
            ``(transformed points, T, Rs, Js_transformed)``; the last two
            are ``None`` for 24*16 input.

        Raises:
            ValueError: If ``poses`` has none of the supported sizes.
        """
        batch_num = poses.shape[0]
        assert (batch_num == JsorShapes.shape[0])
        if JsorShapes.shape.numel() == batch_num * 10:  # is shapes
            Js = self.smpl.skeleton(JsorShapes)
        else:
            Js = JsorShapes

        if poses.numel() == batch_num * 24 * 3:
            Rs = batch_rodrigues(poses.view(-1, 3)).view(-1, 24, 3, 3)
            Js_transformed, A = batch_global_rigid_transformation(
                Rs, Js, self.smpl.parents, rotate_base=False)
        elif poses.numel() == batch_num * 24 * 9:
            # input poses are general matrices
            if not is_Rotation:
                ms = poses.reshape(-1, 3, 3)
                # Gram-Schmidt orthonormalization of the first two columns;
                # the third column is their cross product.
                b1 = F.normalize(ms[:, :, 0], dim=1)
                dot_prod = torch.sum(b1 * ms[:, :, 1], dim=1, keepdim=True)
                b2 = F.normalize(ms[:, :, 1] - dot_prod * b1, dim=-1)
                b3 = torch.cross(b1, b2, dim=1)
                Rs = torch.stack([b1, b2, b3],
                                 dim=-1).reshape(batch_num, 24, 3, 3)
            else:
                Rs = poses.reshape(batch_num, 24, 3, 3)
            Js_transformed, A = batch_global_rigid_transformation(
                Rs, Js, self.smpl.parents, rotate_base=False)
        elif poses.numel() == batch_num * 24 * 16:
            A = poses.reshape(batch_num, 24, 4, 4)
            Js_transformed = None
            Rs = None
        else:
            # Fail with a clear message instead of an unbound-variable
            # error further down.
            raise ValueError(
                'poses must hold 24*3, 24*9 or 24*16 values per batch '
                f'entry, got shape {tuple(poses.shape)}')

        # Number of points belonging to every batch entry.
        splitl = torch_scatter.scatter(batch.new_ones(batch.numel(), 1),
                                       batch,
                                       dim=0).cpu().numpy().reshape(-1).astype(
                                           np.int32).tolist()
        ws = ws.split(splitl, 0)
        # Blend the 24 flattened 4x4 transforms with the point weights.
        T = torch.cat(
            [weight.matmul(a.reshape(24, 16)) for weight, a in zip(ws, A)],
            dim=0)
        T = T.reshape(-1, 4, 4)
        # Homogeneous coordinates, transformed and cut back to 3-D.
        ps = torch.cat((ps, ps.new_ones(ps.shape[0], 1)), dim=-1).unsqueeze(-1)
        ps = torch.matmul(T, ps).squeeze(-1)
        return ps[:, 0:3], T, Rs, Js_transformed
Example #30
0
    def rand_prop(self, x, edge_index, edge_weight):
        """Average ``self.order + 1`` propagation steps of the dropped ``x``.

        Edge weights are first normalized with ``self.normalize_adj`` and
        ``x`` is passed through ``self.dropNode``. Each step sums the
        weighted ``x[col]`` values into the rows given by ``row``. All
        steps are detached and accumulated in place into the output of
        ``self.dropNode``, which is then divided by ``self.order + 1``.
        """
        edge_weight = self.normalize_adj(edge_index, edge_weight, x.shape[0])
        row, col = edge_index[0], edge_index[1]
        x = self.dropNode(x)

        weight_column = edge_weight[:, None]
        target_rows = row[:, None]
        accumulated = x
        for _ in range(self.order):
            messages = x[col] * weight_column
            x = scatter(messages, target_rows, dim=0, dim_size=x.shape[0],
                        reduce='sum').detach_()
            # Sparse-matmul equivalent noted by the original author:
            # x = torch.spmm(adj, x).detach_()
            accumulated.add_(x)
        return accumulated.div_(self.order + 1.0).detach_()