예제 #1
0
def test():
    """Smoke-test ``Convolution`` on a tiny random point cloud.

    Five input points carry ``0e + 1e`` features and two output points
    receive the convolved result. Input/output pairs closer than 2.0 are
    connected. Only checks that the forward pass runs without error.
    """
    from torch_cluster import radius
    from e3nn.math import soft_one_hot_linspace

    cutoff_radius = 2.0
    num_basis = 4

    conv = Convolution(
        irreps_node_input='0e + 1e',
        irreps_node_output='0e + 1e',
        irreps_node_attr_input='2x0e',
        irreps_node_attr_output='3x0e',
        irreps_edge_attr='0e + 1e',
        num_edge_scalar_attr=num_basis,
        radial_layers=1,
        radial_neurons=50,
        num_neighbors=3.0,
    )

    # Random tensors are drawn in a fixed order so seeded runs are reproducible.
    pos_in = torch.randn(5, 3)
    pos_out = torch.randn(2, 3)
    node_input = torch.randn(5, 4)  # 0e + 1e -> 1 + 3 components
    node_attr_input = torch.randn(5, 2)  # 2x0e
    node_attr_output = torch.randn(2, 3)  # 3x0e

    # The first returned index addresses pos_in, the second pos_out.
    src, dst = radius(pos_out, pos_in, r=cutoff_radius)
    vec = pos_in[src] - pos_out[dst]

    sh = o3.spherical_harmonics([0, 1], vec, True)
    radial = soft_one_hot_linspace(
        x=vec.norm(dim=1),
        start=0.0,
        end=cutoff_radius,
        number=num_basis,
        basis='smooth_finite',
        cutoff=True,
    )

    conv(node_input, node_attr_input, node_attr_output, src, dst, sh, radial)
예제 #2
0
    def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the message-passing network on a batch of graphs.

        ``self.preprocess`` extracts the batch vector, node inputs and edge
        geometry from ``data``. Edges are described by spherical harmonics
        up to ``self.lmax`` plus a radial embedding of their length.

        Returns the per-node outputs, or, when ``self.pool_nodes`` is set,
        the outputs pooled per graph with ``scatter`` over ``batch``. Pooled
        outputs are divided by ``sqrt(self.num_nodes)``.
        """
        batch, node_inputs, edge_src, edge_dst, edge_vec = self.preprocess(data)
        del data

        sh_degrees = range(self.lmax + 1)
        edge_attr = o3.spherical_harmonics(sh_degrees, edge_vec, True, normalization='component')

        # 'smooth_finite' with cutoff=True already goes to zero at max_radius,
        # so no separate smooth envelope is applied.
        n_basis = self.number_of_basis
        radial = soft_one_hot_linspace(
            edge_vec.norm(dim=1),
            0.0,
            self.max_radius,
            n_basis,
            basis='smooth_finite',
            cutoff=True,
        )
        edge_length_embedding = radial.mul(n_basis**0.5)

        # This model ignores node attributes; feed a constant 1 for every node.
        node_attr = node_inputs.new_ones(node_inputs.shape[0], 1)

        node_outputs = self.mp(node_inputs, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedding)

        if not self.pool_nodes:
            return node_outputs

        # NOTE(review): the third positional argument is the number of graphs,
        # which matches a `scatter(src, index, dim_size)` helper rather than
        # torch_scatter.scatter (whose third positional is `dim`) — confirm.
        num_graphs = int(batch.max()) + 1
        return scatter(node_outputs, batch, num_graphs).div(self.num_nodes**0.5)
예제 #3
0
    def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Evaluate the network on one graph or a batch of graphs.

        Parameters
        ----------
        data : `torch_geometric.data.Data` or dict
            Must contain ``pos`` (node/atom positions). Optionally contains
            ``x`` (input node features), ``z`` (node attributes such as the
            atom type) and ``batch`` (graph index of every node).

        Returns
        -------
        torch.Tensor
            Per-node features, or per-graph features pooled with ``scatter``
            and divided by ``sqrt(self.num_nodes)`` when ``self.reduce_output``.
        """
        if 'batch' in data:
            batch = data['batch']
        else:
            # Without a batch vector every node belongs to graph 0.
            batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)

        pos = data['pos']
        num_nodes = pos.shape[0]

        # Graph: every pair of nodes closer than max_radius within one graph.
        edge_src, edge_dst = radius_graph(pos, self.max_radius, batch)
        edge_vec = pos[edge_src] - pos[edge_dst]
        edge_length = edge_vec.norm(dim=1)

        edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization='component')
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_length,
            start=0.0,
            end=self.max_radius,
            number=self.number_of_basis,
            basis='gaussian',
            cutoff=False,
        ).mul(self.number_of_basis**0.5)

        # With cutoff=False the radial embedding is not forced to zero at
        # max_radius; the smooth_cutoff envelope on the edge attributes
        # presumably provides the decay instead.
        envelope = smooth_cutoff(edge_length / self.max_radius)
        edge_attr = envelope[:, None] * edge_sh

        if self.input_has_node_in and 'x' in data:
            assert self.irreps_in is not None
            node_features = data['x']
        else:
            # No input features: every node starts from a constant scalar.
            assert self.irreps_in is None
            node_features = pos.new_ones((num_nodes, 1))

        if self.input_has_node_attr and 'z' in data:
            node_attrs = data['z']
        else:
            assert self.irreps_node_attr == o3.Irreps("0e")
            node_attrs = pos.new_ones((num_nodes, 1))

        for layer in self.layers:
            node_features = layer(node_features, node_attrs, edge_src, edge_dst, edge_attr, edge_length_embedded)

        if not self.reduce_output:
            return node_features
        return scatter(node_features, batch, dim=0).div(self.num_nodes**0.5)
예제 #4
0
def test_zero_out(basis):
    """With ``cutoff=True`` the embedding must vanish outside ``[start, end]``.

    Samples lie on both sides of the interval ``[-1, 2]``, away from its
    edges. The gaussian basis is only required to stay small (< 0.22); every
    other basis must be exactly zero there.
    """
    below = torch.linspace(-2.0, -1.1, 20)
    above = torch.linspace(2.1, 3.0, 20)
    samples = torch.cat([below, above])

    emb = soft_one_hot_linspace(samples, -1.0, 2.0, 5, basis, cutoff=True)
    peak = emb.abs().max()

    if basis != 'gaussian':
        assert peak == 0.0
    else:
        assert peak < 0.22
예제 #5
0
    def forward(self, data) -> torch.Tensor:
        """Predict one scalar per molecule.

        ``data`` must provide ``z`` (per-atom values used to index a 10-entry
        species table), ``pos`` (positions) and ``batch`` (molecule index of
        each atom). Atoms within ``self.max_radius`` of each other exchange
        messages through ``self.mp``. Two output channels are combined into
        one value per atom, which is then pooled per molecule.
        """
        node_atom = data['z']
        node_pos = data['pos']
        batch = data['batch']

        # Graph: atoms closer than max_radius within the same molecule.
        edge_src, edge_dst = radius_graph(node_pos, r=self.max_radius, batch=batch, max_num_neighbors=1000)
        edge_vec = node_pos[edge_src] - node_pos[edge_dst]

        edge_sh = o3.spherical_harmonics(
            l=range(self.sh_lmax + 1), x=edge_vec, normalize=True, normalization='component',
        )

        n_basis = self.num_basis
        radial = soft_one_hot_linspace(
            edge_vec.norm(dim=1), 0.0, self.max_radius, n_basis,
            basis='smooth_finite', cutoff=True,
        )
        edge_length_embedding = radial.mul(n_basis**0.5)

        # Constant scalar input; the chemistry enters through node_attr.
        node_input = node_pos.new_ones(node_pos.shape[0], 1)

        # Maps z values 1, 6, 7, 8, 9 (presumably H, C, N, O, F) to classes
        # 0..4. Any other value maps to -1, which one_hot rejects.
        species_table = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])
        node_attr = torch.nn.functional.one_hot(species_table[node_atom], 5).mul(5**0.5)

        raw = self.mp(
            node_features=node_input,
            node_attr=node_attr,
            edge_src=edge_src,
            edge_dst=edge_dst,
            edge_attr=edge_sh,
            edge_scalars=edge_length_embedding,
        )

        # Channel 0 is used as-is; channel 1 enters squared and halved.
        per_atom = (raw[:, 0] + raw[:, 1].pow(2).mul(0.5)).view(-1, 1)
        per_atom = per_atom.div(self.num_nodes**0.5)

        if self.atomref is not None:
            per_atom = per_atom + self.atomref[node_atom]
        # Original author's note: for target=7, MAE of 75eV.

        return scatter(per_atom, batch, dim=0)
예제 #6
0
    def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
        """Build a convolution whose kernel is sampled on a 3-D lattice.

        Parameters
        ----------
        irreps_in, irreps_out : irreps of the input and output features.
        irreps_sh : irreps of the spherical harmonics evaluated on the lattice.
        diameter : kernel diameter, in the same units as ``steps``.
        num_radial_basis : number of radial basis functions of the kernel.
        steps : lattice spacing along the first three axes.
        **kwargs : kept in ``self.kwargs``. ``padding`` defaults to half the
            lattice size along each axis.
        """
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)
        self.num_radial_basis = num_radial_basis

        # self-connection
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # Kernel support: a symmetric grid of points within `radius` on each axis.
        radius = diameter / 2
        axes = []
        for step in (steps[0], steps[1], steps[2]):
            half_count = math.floor(radius / step)
            axes.append(torch.arange(-half_count, half_count + 1.0) * step)

        lattice = torch.stack(torch.meshgrid(*axes), dim=-1)  # [x, y, z, R^3]
        self.register_buffer('lattice', lattice)

        kwargs.setdefault('padding', tuple(n // 2 for n in lattice.shape[:3]))
        self.kwargs = kwargs

        # Radial embedding of each lattice point; smooth_finite with cutoff=True
        # goes to zero at `radius`.
        radial_emb = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=radius,
            number=self.num_radial_basis,
            basis='smooth_finite',
            cutoff=True,
        )
        self.register_buffer('emb', radial_emb)

        lattice_sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=lattice,
            normalize=True,
            normalization='component',
        )  # [x, y, z, irreps_sh.dim]
        self.register_buffer('sh', lattice_sh)

        self.tp = FullyConnectedTensorProduct(self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False)

        # One tensor-product weight vector per radial basis function.
        self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))
예제 #7
0
    def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Run message passing on a graph with precomputed connectivity.

        ``data`` must contain ``pos``, ``edge_index``, ``edge_attr``,
        ``node_input`` and ``node_attr``; ``batch`` is optional and defaults
        to a single graph. The given ``edge_attr`` is concatenated with the
        spherical harmonics of each edge vector.

        Returns per-node outputs, or outputs pooled per graph and divided by
        ``sqrt(self.num_nodes)`` when ``self.pool_nodes`` is set.
        """
        if 'batch' in data:
            batch = data['batch']
        else:
            batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)

        # Connectivity is supplied by the caller rather than built here.
        edge_index = data['edge_index']
        edge_src, edge_dst = edge_index[0], edge_index[1]

        pos = data['pos']
        edge_vec = pos[edge_src] - pos[edge_dst]
        sh = o3.spherical_harmonics(range(self.lmax + 1), edge_vec, True, normalization='component')
        # Caller-provided edge features first, spherical harmonics after.
        edge_attr = torch.cat([data['edge_attr'], sh], dim=1)

        # The cosine basis with cutoff=True goes to zero at max_radius,
        # so no additional smooth cutoff is needed.
        n_basis = self.number_of_basis
        radial = soft_one_hot_linspace(
            edge_vec.norm(dim=1), 0.0, self.max_radius, n_basis,
            basis='cosine', cutoff=True,
        )
        edge_length_embedding = radial.mul(n_basis**0.5)

        out = self.mp(data['node_input'], data['node_attr'], edge_src, edge_dst, edge_attr, edge_length_embedding)

        if not self.pool_nodes:
            return out
        return scatter(out, batch, dim=0).div(self.num_nodes**0.5)
예제 #8
0
    def forward(self, data) -> torch.Tensor:
        """Compute one output row per graph for a batch of small point clouds.

        Nodes closer than 2.5 are connected. Initial node features are the
        spherical harmonics of incoming edges, scattered onto the destination
        node and divided by ``sqrt(self.num_neighbors)``. They then pass through
        ``self.conv``, ``self.gate`` and ``self.final`` and are pooled per graph.

        Parameters
        ----------
        data : object with a ``pos`` attribute (node positions) and optionally
            a ``batch`` attribute (graph index of every node). A missing or
            ``None`` batch means all nodes belong to a single graph.

        Returns
        -------
        torch.Tensor
            Per-graph outputs, pooled over nodes and divided by ``sqrt(num_nodes)``.
        """
        num_nodes = 4  # typical number of nodes; used only for output normalization

        pos = data.pos
        # Without a batch vector every node belongs to graph 0; the final
        # per-graph scatter needs a real index tensor.
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = pos.new_zeros(pos.shape[0], dtype=torch.long)

        edge_src, edge_dst = radius_graph(x=pos, r=2.5, batch=batch)
        edge_vec = pos[edge_src] - pos[edge_dst]
        edge_attr = o3.spherical_harmonics(l=self.irreps_sh,
                                           x=edge_vec,
                                           normalize=True,
                                           normalization='component')
        edge_length_embedded = soft_one_hot_linspace(x=edge_vec.norm(dim=1),
                                                     start=0.5,
                                                     end=2.5,
                                                     number=3,
                                                     basis='smooth_finite',
                                                     cutoff=True) * 3**0.5

        # dim_size keeps exactly one row per node. Without it the output length
        # is max(edge_dst) + 1, so trailing nodes with no incoming edge would be
        # dropped and the rows would no longer line up with `batch`.
        x = scatter(edge_attr, edge_dst, dim=0, dim_size=pos.shape[0]).div(self.num_neighbors**0.5)

        x = self.conv(x, edge_src, edge_dst, edge_attr, edge_length_embedded)
        x = self.gate(x)
        x = self.final(x, edge_src, edge_dst, edge_attr, edge_length_embedded)

        return scatter(x, batch, dim=0).div(num_nodes**0.5)
예제 #9
0
def test_normalized(basis, cutoff):
    """Each embedding row must have a squared norm strictly within (0.4, 2.0).

    Samples are taken well inside ``[-20, 120]``, where every basis and
    cutoff combination is expected to be roughly unit-normalized.
    """
    samples = torch.linspace(-14.0, 105.0, 50)
    emb = soft_one_hot_linspace(samples, -20.0, 120.0, 12, basis, cutoff)
    squared_norms = emb.pow(2).sum(1)

    assert squared_norms.min() > 0.4
    assert squared_norms.max() < 2.0
예제 #10
0
    def forward(self, node_atom, node_pos, batch) -> torch.Tensor:
        """Predict one scalar per molecule from per-atom values and positions.

        Atoms closer than ``self.max_radius`` within the same molecule (as
        given by ``batch``) are connected. Per-atom outputs can be
        de-standardized (``self.std``/``self.mean``) and shifted by
        ``self.atomref``. They are pooled per molecule, and the result is
        optionally multiplied by ``self.scale``.
        """
        edge_src, edge_dst = radius_graph(node_pos, r=self.max_radius, batch=batch, max_num_neighbors=1000)
        edge_vec = node_pos[edge_src] - node_pos[edge_dst]

        edge_sh = o3.spherical_harmonics(
            l=range(self.lmax + 1), x=edge_vec, normalize=True, normalization='component',
        )

        # The cosine basis with cutoff=True goes to zero at max_radius,
        # so no additional smooth cutoff is needed.
        n_basis = self.number_of_basis
        radial = soft_one_hot_linspace(
            edge_vec.norm(dim=1), 0.0, self.max_radius, n_basis,
            basis='cosine', cutoff=True,
        )
        edge_length_embedding = radial.mul(n_basis**0.5)

        # Constant scalar input; the chemistry enters through node_attr.
        node_input = node_pos.new_ones(node_pos.shape[0], 1)

        # Maps values 1, 6, 7, 8, 9 (presumably H, C, N, O, F) to classes
        # 0..4. Any other value maps to -1, which one_hot rejects.
        species_table = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])
        node_attr = torch.nn.functional.one_hot(species_table[node_atom], 5).mul(5**0.5)

        raw = self.mp(
            node_features=node_input,
            node_attr=node_attr,
            edge_src=edge_src,
            edge_dst=edge_dst,
            edge_attr=edge_sh,
            edge_scalars=edge_length_embedding,
        )

        # Channel 0 is used as-is; channel 1 enters squared and halved.
        per_atom = (raw[:, 0] + raw[:, 1].pow(2).mul(0.5)).view(-1, 1)
        per_atom = per_atom.div(self.num_nodes**0.5)

        if self.mean is not None and self.std is not None:
            per_atom = per_atom * self.std + self.mean

        if self.atomref is not None:
            per_atom = per_atom + self.atomref[node_atom]
        # Original author's note: for target=7, MAE of 75eV.

        pooled = scatter(per_atom, batch, dim=0)
        return pooled if self.scale is None else self.scale * pooled
예제 #11
0
    def forward(self, z, pos, batch=None):
        """Predict one scalar per molecule from atom values ``z`` and positions.

        Parameters
        ----------
        z : 1-D ``torch.long`` tensor, one entry per atom.
        pos : ``[num_atoms, 3]`` tensor of positions.
        batch : molecule index of each atom; ``None`` means a single molecule.

        Standard deviations of intermediate activations are reported through
        ``print_std`` along the way.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        assert pos.dim() == 2 and pos.shape[1] == 3
        if batch is None:
            batch = torch.zeros_like(z)

        # --- graph and edge features ---
        edge_src, edge_dst = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=1000)
        edge_vec = pos[edge_src] - pos[edge_dst]
        raw_sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, True, 'component')
        edge_len = edge_vec.norm(dim=1)
        # NOTE(review): called without `basis`/`cutoff`; confirm the installed
        # e3nn version accepts this form.
        edge_len_emb = self.radial(soft_one_hot_linspace(edge_len, 0.0, self.cutoff, self.rad_gaussians))
        # Cosine envelope: 1 at zero length, 0 at the cutoff distance.
        envelope = (pi * edge_len / self.cutoff).cos().add(1).div(2)
        edge_attr = envelope[:, None] * raw_sh / self.num_neighbors**0.5

        # z values [1, 6, 7, 8, 9] -> classes [0, 1, 2, 3, 4]; others map to -1.
        species = z.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[z]
        node_z = torch.nn.functional.one_hot(species, 5).mul(5**0.5)
        # Disabled alternative: one-hot of species pairs per edge
        # (5 * src + dst, 25 classes) mixed into edge_attr via self.mul.

        # Initial node features: enveloped harmonics summed over each node's
        # edges, with the first component forced to 1.
        h = scatter(edge_attr, edge_src, dim=0, dim_size=len(pos))
        h[:, 0] = 1
        h = self.mul_node(node_z, h)
        print_std('h', h)

        # --- convolution stack ---
        for conv, gate in self.layers[:-1]:
            with torch.autograd.profiler.record_function("Layer"):
                h = conv(h, node_z, edge_src, edge_dst, edge_len_emb, edge_attr)
                print_std('post conv', h)
                h = gate(h)
                print_std('post gate', h)

        with torch.autograd.profiler.record_function("Layer"):
            h = self.layers[-1](h, node_z, edge_src, edge_dst, edge_len_emb, edge_attr)
        print_std('h out', h)

        # --- readout: one scalar per atom ---
        # Even scalars are added directly; odd scalars are squared
        # (odd^2 = even) and halved.
        total = 0
        for column, (mul, (l, parity)) in enumerate(self.irreps_out):
            assert mul == 1 and l == 0
            if parity == 1:
                total += h[:, column]
            if parity == -1:
                total += h[:, column].pow(2).mul(0.5)
        h = total.view(-1, 1)
        print_std('h out+', h)

        # Normalize before summing atoms into molecules.
        h = h / self.num_atoms**0.5

        if self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if self.atomref is not None:
            h = h + self.atomref[z]
        # Original author's note: for target=7, MAE of 75eV.

        out = scatter(h, batch, dim=0)
        if self.scale is not None:
            out = self.scale * out
        return out