Example #1
def test_empty_inputs():
    tp = FullyConnectedTensorProduct('0e + 1e', '0e + 1e', '0e + 1e')
    out = tp(torch.randn(2, 1, 0, 1, 4), torch.randn(1, 2, 0, 3, 4))
    assert out.shape == (2, 2, 0, 3, 4)

    out = tp.right(torch.randn(1, 2, 0, 3, 4))
    assert out.shape == (1, 2, 0, 3, 4, 4)
Example #2
def test_input_weights_python():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    # - shared_weights = False -
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=False)
    bdim = random.randint(1, 3)
    x1 = irreps_in1.randn(bdim, -1)
    x2 = irreps_in2.randn(bdim, -1)
    w = [
        torch.randn((bdim, ) + ins.path_shape) for ins in m.instructions
        if ins.has_weight
    ]
    m(x1, x2, w)
    # - shared_weights = True -
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=True)
    bdim = random.randint(1, 3)
    x1 = irreps_in1.randn(bdim, -1)
    x2 = irreps_in2.randn(bdim, -1)
    w = [
        torch.randn(ins.path_shape) for ins in m.instructions if ins.has_weight
    ]
    m(x1, x2, w)
Example #3
    def __init__(self, irreps_out, num_z, lmax) -> None:
        super().__init__()
        self.num_z = num_z

        self.irreps_sh = o3.Irreps.spherical_harmonics(lmax)

        # to multiply the edge type one-hot with the spherical harmonics to get the edge attributes
        self.mul = TensorProduct(
            [(num_z**2, "0e")],
            self.irreps_sh,
            [(num_z**2, ir) for _, ir in self.irreps_sh],
            [
                (0, l, l, "uvu", False)
                for l in range(lmax + 1)
            ]
        )
        irreps_attr = self.mul.irreps_out

        irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
        irreps_out = o3.Irreps(irreps_out)

        self.tp1 = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_sh,
            irreps_in2=irreps_attr,
            irreps_out=irreps_mid,
        )
        self.tp2 = FullyConnectedTensorProduct(
            irreps_in1=irreps_mid,
            irreps_in2=irreps_attr,
            irreps_out=irreps_out,
        )
Example #4
    def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr,
                 irreps_node_output, fc_neurons, num_neighbors) -> None:
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        self.sc = FullyConnectedTensorProduct(self.irreps_node_input,
                                              self.irreps_node_attr,
                                              self.irreps_node_output)

        self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input,
                                                self.irreps_node_attr,
                                                self.irreps_node_input)

        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_node_input):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1):
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, 'uvu', True))
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        assert irreps_mid.dim > 0, f"irreps_node_input={self.irreps_node_input} times irreps_edge_attr={self.irreps_edge_attr} produces nothing in irreps_node_output={self.irreps_node_output}"

        instructions = [(i_1, i_2, p[i_out], mode, train)
                        for i_1, i_2, i_out, mode, train in instructions]

        tp = TensorProduct(
            self.irreps_node_input,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel],
                                    torch.nn.functional.silu)
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid,
                                                self.irreps_node_attr,
                                                self.irreps_node_output)

        # inspired by https://arxiv.org/pdf/2002.10444.pdf
        self.alpha = FullyConnectedTensorProduct(irreps_mid,
                                                 self.irreps_node_attr, "0e")
        with torch.no_grad():
            self.alpha.weight.zero_()
        assert self.alpha.output_mask[0] == 1.0, f"irreps_mid={irreps_mid} and irreps_node_attr={self.irreps_node_attr} are not able to generate scalars"
Example #5
def test_input_weights_jit():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    # - shared_weights = False -
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=False)
    traced = assert_auto_jitable(m)
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    w = torch.randn(2, m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, w[0])  # it should reject insufficient weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))

    # Confirm that weird batch dimensions give the same results
    for f in (m, traced):
        x1 = irreps_in1.randn(2, 1, 4, -1)
        x2 = irreps_in2.randn(2, 3, 1, -1)
        w = torch.randn(3, 4, f.weight_numel)
        assert torch.allclose(
            f(x1, x2, w).reshape(24, -1),
            f(
                x1.expand(2, 3, 4, -1).reshape(24, -1),
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)))
        assert torch.allclose(
            f.right(x2, w).reshape(24, -1),
            f.right(
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)).reshape(24, -1))

    # - shared_weights = True -
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=True)
    traced = assert_auto_jitable(m)
    w = torch.randn(m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, torch.randn(
            2, m.weight_numel))  # it should reject too many weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))
Example #6
    def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)

        self.num_radial_basis = num_radial_basis

        # self-connection
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # connection with neighbors
        r = diameter / 2

        s = math.floor(r / steps[0])
        x = torch.arange(-s, s + 1.0) * steps[0]

        s = math.floor(r / steps[1])
        y = torch.arange(-s, s + 1.0) * steps[1]

        s = math.floor(r / steps[2])
        z = torch.arange(-s, s + 1.0) * steps[2]

        lattice = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1)  # [x, y, z, R^3]
        self.register_buffer('lattice', lattice)

        if 'padding' not in kwargs:
            kwargs['padding'] = tuple(s // 2 for s in lattice.shape[:3])
        self.kwargs = kwargs

        emb = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=r,
            number=self.num_radial_basis,
            basis='smooth_finite',
            cutoff=True,
        )
        self.register_buffer('emb', emb)

        sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=lattice,
            normalize=True,
            normalization='component'
        )  # [x, y, z, irreps_sh.dim]
        self.register_buffer('sh', sh)

        self.tp = FullyConnectedTensorProduct(self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False)

        self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))
Example #7
    def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr,
                 irreps_node_output, fc_neurons, num_neighbors) -> None:
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        self.sc = FullyConnectedTensorProduct(self.irreps_node_input,
                                              self.irreps_node_attr,
                                              self.irreps_node_output)

        self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input,
                                                self.irreps_node_attr,
                                                self.irreps_node_input)

        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_node_input):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1):
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, 'uvu', True))
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        instructions = [(i_1, i_2, p[i_out], mode, train)
                        for i_1, i_2, i_out, mode, train in instructions]

        tp = TensorProduct(
            self.irreps_node_input,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel],
                                    torch.nn.functional.silu)
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid,
                                                self.irreps_node_attr,
                                                self.irreps_node_output)
        self.lin3 = FullyConnectedTensorProduct(irreps_mid,
                                                self.irreps_node_attr, "0e")
Example #8
    def __init__(self) -> None:
        super().__init__()
        self.irreps_sh = o3.Irreps.spherical_harmonics(3)
        irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
        irreps_out = o3.Irreps("0o + 6x0e")

        self.tp1 = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_sh,
            irreps_in2=self.irreps_sh,
            irreps_out=irreps_mid,
        )
        self.tp2 = FullyConnectedTensorProduct(
            irreps_in1=irreps_mid,
            irreps_in2=self.irreps_sh,
            irreps_out=irreps_out,
        )
        self.irreps_out = self.tp2.irreps_out
Example #9
def test_weight_view_for_instruction():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)

    # Find all paths to the first output
    ins_indexes = [i for i, ins in enumerate(m.instructions) if ins.i_out == 0]
    with torch.no_grad():
        for i in ins_indexes:
            m.weight_view_for_instruction(i).zero_()

    out = m(x1, x2)
    assert torch.all(out[:, :1] == 0.0)
    assert torch.any(out[:, 1:] > 0.0)
Example #10
    def __init__(self, irreps_in, irreps_out, irreps_sh, dim_key):
        super().__init__()
        self.irreps_in = irreps_in.simplify()
        self.irreps_out = irreps_out.simplify()
        self.irreps_sh = irreps_sh.simplify()

        # self.si = Linear(self.irreps_in, self.irreps_out, internal_weights=True, shared_weights=True)
        self.si = FullyConnectedTensorProduct(self.irreps_in,
                                              o3.Irreps("5x0e"),
                                              self.irreps_out)

        # self.lin1 = Linear(self.irreps_in, self.irreps_in, internal_weights=True, shared_weights=True)
        self.lin1 = FullyConnectedTensorProduct(self.irreps_in,
                                                o3.Irreps("5x0e"),
                                                self.irreps_in)

        instr = []
        irreps = []
        for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in):
            for i_2, (_, (l_2, p_2)) in enumerate(self.irreps_sh):
                for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                    p_out = p_1 * p_2
                    if (l_out, p_out) in [(l, p)
                                          for _, (l, p) in self.irreps_out]:
                        r = (mul_1, l_out, p_out)
                        if r in irreps:
                            i_out = irreps.index(r)
                        else:
                            i_out = len(irreps)
                            irreps.append(r)
                        instr += [(i_1, i_2, i_out, 'uvu', True)]
        irreps = o3.Irreps(irreps)
        self.tp = TensorProduct(self.irreps_in,
                                self.irreps_sh,
                                irreps,
                                instr,
                                internal_weights=False,
                                shared_weights=False)

        self.tp_weight = torch.nn.Parameter(
            torch.randn(dim_key, self.tp.weight_numel))

        # self.lin2 = Linear(irreps, self.irreps_out, internal_weights=True, shared_weights=True)
        self.lin2 = FullyConnectedTensorProduct(irreps, o3.Irreps("5x0e"),
                                                self.irreps_out)
Example #11
    def __init__(self, irreps_in1, irreps_out, irreps_in2=None, tp_rescale=True) -> None:
        super().__init__()

        self.irreps_in1 = irreps_in1
        self.irreps_out = irreps_out
        # Init irreps_in2
        if irreps_in2 is None:
            self.irreps_in2_provided = False
            self.irreps_in2 = Irreps("1x0e")
        else:
            self.irreps_in2_provided = True
            self.irreps_in2 = irreps_in2
        self.tp_rescale = tp_rescale

        # Build the layers
        self.tp = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_in1,
            irreps_in2=self.irreps_in2,
            irreps_out=self.irreps_out, shared_weights=True, normalization='component')

        # Each zeroth-order (scalar) output irrep needs a bias,
        # so first determine the order and dimension of each output slice
        self.irreps_out_orders = [int(irrep_str[-2]) for irrep_str in str(irreps_out).split('+')]
        self.irreps_out_dims = [int(irrep_str.split('x')[0]) for irrep_str in str(irreps_out).split('+')]
        self.irreps_out_slices = irreps_out.slices()
        # Store tuples of slices and corresponding biases in a list
        self.biases = []
        self.biases_slices = []
        self.biases_slice_idx = []
        for slice_idx in range(len(self.irreps_out_orders)):
            if self.irreps_out_orders[slice_idx] == 0:
                out_slice = irreps_out.slices()[slice_idx]
                out_bias = torch.nn.Parameter(
                    torch.zeros(self.irreps_out_dims[slice_idx], dtype=self.tp.weight.dtype))
                self.biases += [out_bias]
                self.biases_slices += [out_slice]
                self.biases_slice_idx += [slice_idx]
        self.biases = torch.nn.ParameterList(self.biases)

        # Initialize the correction factors
        self.slices_sqrt_k = {}

        # Initialize similar to the torch.nn.Linear
        self.tensor_product_init()
Example #12
def test_fully_connected():
    irreps_in1 = o3.Irreps("1e + 2e + 3x3o")
    irreps_in2 = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(m)
    m(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
Example #13
def WeightBalancedIrreps(irreps_in1_scalar, irreps_in2, sh=True):
    """
    Determines an irreps_in1 of order irreps_in2.lmax such that the tensor product
    irreps_in1 x irreps_in2 -> irreps_in1
    has at least as many weights as a standard linear layer, i.e. the tensor product
    irreps_in1_scalar x "1x0e" -> irreps_in1_scalar
    """
    n = 1
    lmax = irreps_in2.lmax
    irreps_in1 = (Irreps.spherical_harmonics(lmax) * n).sort().irreps.simplify() if sh else BalancedIrreps(lmax, n)
    weight_numel1 = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_in1).weight_numel
    weight_numel_scalar = FullyConnectedTensorProduct(irreps_in1_scalar, Irreps("1x0e"), irreps_in1_scalar).weight_numel
    while weight_numel1 < weight_numel_scalar:  # TODO: somewhat suboptimal implementation...
        n += 1
        irreps_in1 = (Irreps.spherical_harmonics(lmax) * n).sort().irreps.simplify() if sh else BalancedIrreps(lmax, n)
        weight_numel1 = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_in1).weight_numel
    print('Determined irrep type:', irreps_in1)

    return Irreps(irreps_in1)
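
A minimal usage sketch of WeightBalancedIrreps (the scalar irreps and lmax are illustrative assumptions; only the sh=True path is exercised, so the external BalancedIrreps helper is not needed):

from e3nn.o3 import Irreps

irreps_attr = Irreps.spherical_harmonics(2)  # 1x0e + 1x1o + 1x2e
irreps_hidden = WeightBalancedIrreps(Irreps("128x0e"), irreps_attr, sh=True)
# irreps_hidden has the form "nx0e + nx1o + nx2e", with n chosen so that the
# tensor product with irreps_attr has at least 128 * 128 = 16384 weights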
Example #14
    def __init__(self, irreps_in, irreps_sh, irreps_out,
                 num_neighbors) -> None:
        super().__init__()

        self.num_neighbors = num_neighbors

        tp = FullyConnectedTensorProduct(
            irreps_in1=irreps_in,
            irreps_in2=irreps_sh,
            irreps_out=irreps_out,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet([3, 256, tp.weight_numel], torch.relu)
        self.tp = tp
        self.irreps_out = self.tp.irreps_out
Example #15
def test_weight_views():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    batchdim = 3
    x1 = irreps_in1.randn(batchdim, -1)
    x2 = irreps_in2.randn(batchdim, -1)
    # shared weights
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    with torch.no_grad():
        for w in m.weight_views():
            w.zero_()
    assert torch.all(m(x1, x2) == 0.0)

    # unshared weights
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    shared_weights=False)
    weights = torch.randn(batchdim, m.weight_numel)
    with torch.no_grad():
        for w in m.weight_views(weights):
            w.zero_()
    assert torch.all(m(x1, x2, weights) == 0.0)
Example #16
class Convolution(torch.nn.Module):
    r"""convolution on voxels

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        input irreps

    irreps_out : `e3nn.o3.Irreps`
        output irreps

    irreps_sh : `e3nn.o3.Irreps`
        set typically to ``o3.Irreps.spherical_harmonics(lmax)``

    diameter : float
        diameter of the filter in physical units

    num_radial_basis : int
        number of radial basis functions

    steps : tuple of float
        size of the pixel in physical units
    """
    def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)

        self.num_radial_basis = num_radial_basis

        # self-connection
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # connection with neighbors
        r = diameter / 2

        s = math.floor(r / steps[0])
        x = torch.arange(-s, s + 1.0) * steps[0]

        s = math.floor(r / steps[1])
        y = torch.arange(-s, s + 1.0) * steps[1]

        s = math.floor(r / steps[2])
        z = torch.arange(-s, s + 1.0) * steps[2]

        lattice = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1)  # [x, y, z, R^3]
        self.register_buffer('lattice', lattice)

        if 'padding' not in kwargs:
            kwargs['padding'] = tuple(s // 2 for s in lattice.shape[:3])
        self.kwargs = kwargs

        emb = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=r,
            number=self.num_radial_basis,
            basis='smooth_finite',
            cutoff=True,
        )
        self.register_buffer('emb', emb)

        sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=lattice,
            normalize=True,
            normalization='component'
        )  # [x, y, z, irreps_sh.dim]
        self.register_buffer('sh', sh)

        self.tp = FullyConnectedTensorProduct(self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False)

        self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))

    def kernel(self):
        weight = self.emb @ self.weight
        weight = weight / (self.sh.shape[0] * self.sh.shape[1] * self.sh.shape[2])
        kernel = self.tp.right(self.sh, weight)  # [x, y, z, irreps_in.dim, irreps_out.dim]
        kernel = torch.einsum('xyzio->oixyz', kernel)
        return kernel

    def forward(self, x):
        r"""
        Parameters
        ----------
        x : `torch.Tensor`
            tensor of shape ``(batch, irreps_in.dim, x, y, z)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(batch, irreps_out.dim, x, y, z)``
        """
        sc = self.sc(x.transpose(1, 4)).transpose(1, 4)

        return sc + torch.nn.functional.conv3d(x, self.kernel(), **self.kwargs)
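
A minimal usage sketch of this voxel convolution (the irreps, diameter, and grid size are illustrative assumptions; imports are the same as in the snippet above):

irreps = o3.Irreps("0e + 1e")  # dim = 4
conv = Convolution(irreps, irreps, irreps_sh=o3.Irreps.spherical_harmonics(2),
                   diameter=5.0, num_radial_basis=3)
x = torch.randn(1, irreps.dim, 16, 16, 16)  # (batch, irreps_in.dim, x, y, z)
out = conv(x)  # the default padding keeps the spatial shape
assert out.shape == (1, irreps.dim, 16, 16, 16)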
Example #17
def main():
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3o")
    parser.add_argument("--irreps-in1", type=str, default=None)
    parser.add_argument("--irreps-in2", type=str, default=None)
    parser.add_argument("--irreps-out", type=str, default=None)
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("--elementwise", action='store_true')
    parser.add_argument("-n", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1 if args.irreps_in1 else args.irreps)
    irreps_in2 = Irreps(args.irreps_in2 if args.irreps_in2 else args.irreps)
    irreps_out = Irreps(args.irreps_out if args.irreps_out else args.irreps)

    if args.elementwise:
        tp = ElementwiseTensorProduct(irreps_in1,
                                      irreps_in2,
                                      _specialized_code=args.specialized_code,
                                      _optimize_einsums=args.opt_ein)
        if args.backward:
            print(
                "Elementwise TP has no weights, cannot backward. Setting --backward False."
            )
            args.backward = False
    else:
        tp = FullyConnectedTensorProduct(
            irreps_in1,
            irreps_in2,
            irreps_out,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein)
    tp = tp.to(device=device)
    assert len(tp.instructions) > 0, "Bad irreps, no instructions"
    print(f"Tensor product: {tp}")
    print("Instructions:")
    for ins in tp.instructions:
        print(f"  {ins}")

    # from https://pytorch.org/docs/master/_modules/torch/utils/benchmark/utils/timer.html#Timer.timeit
    warmup = max(int(args.n // 100), 1)

    inputs = iter([(irreps_in1.randn(args.batch, -1).to(device=device),
                    irreps_in2.randn(args.batch, -1).to(device=device))
                   for _ in range(args.n + warmup)])

    # compile
    if args.jit:
        tp = compile(tp)

    print("starting...")

    # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
    t = Timer(
        stmt=("tp.zero_grad()\n"
              "out = tp(*next(inputs))\n" +
              ("out.tanh().sum().backward()\n" if args.backward else '')),
        globals={
            'tp': tp,
            'inputs': inputs
        })

    perloop = t.timeit(args.n)

    print()
    print(perloop)
Example #18
class O3TensorProduct(torch.nn.Module):
    def __init__(self, irreps_in1, irreps_out, irreps_in2=None, tp_rescale=True) -> None:
        super().__init__()

        self.irreps_in1 = irreps_in1
        self.irreps_out = irreps_out
        # Init irreps_in2
        if irreps_in2 is None:
            self.irreps_in2_provided = False
            self.irreps_in2 = Irreps("1x0e")
        else:
            self.irreps_in2_provided = True
            self.irreps_in2 = irreps_in2
        self.tp_rescale = tp_rescale

        # Build the layers
        self.tp = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_in1,
            irreps_in2=self.irreps_in2,
            irreps_out=self.irreps_out, shared_weights=True, normalization='component')

        # Each zeroth-order (scalar) output irrep needs a bias,
        # so first determine the order and dimension of each output slice
        self.irreps_out_orders = [int(irrep_str[-2]) for irrep_str in str(irreps_out).split('+')]
        self.irreps_out_dims = [int(irrep_str.split('x')[0]) for irrep_str in str(irreps_out).split('+')]
        self.irreps_out_slices = irreps_out.slices()
        # Store tuples of slices and corresponding biases in a list
        self.biases = []
        self.biases_slices = []
        self.biases_slice_idx = []
        for slice_idx in range(len(self.irreps_out_orders)):
            if self.irreps_out_orders[slice_idx] == 0:
                out_slice = irreps_out.slices()[slice_idx]
                out_bias = torch.nn.Parameter(
                    torch.zeros(self.irreps_out_dims[slice_idx], dtype=self.tp.weight.dtype))
                self.biases += [out_bias]
                self.biases_slices += [out_slice]
                self.biases_slice_idx += [slice_idx]
        self.biases = torch.nn.ParameterList(self.biases)

        # Initialize the correction factors
        self.slices_sqrt_k = {}

        # Initialize similar to the torch.nn.Linear
        self.tensor_product_init()

    def tensor_product_init(self) -> None:
        with torch.no_grad():
            # Determine the fan_in for each output slice; a slice may be updated by several instructions
            slices_fan_in = {}  # fan_in per output slice
            for weight, instr in zip(self.tp.weight_views(), self.tp.instructions):
                slice_idx = instr[2]
                mul_1, mul_2, mul_out = weight.shape
                fan_in = mul_1 * mul_2
                slices_fan_in[slice_idx] = slices_fan_in.get(slice_idx, 0) + fan_in

            # Initialize the weights of each instruction
            for weight, instr in zip(self.tp.weight_views(), self.tp.instructions):
                slice_idx = instr[2]
                # The tensor product in e3nn already normalizes proportional to 1 / sqrt(fan_in), and the weights are by
                # default initialized with unif(-1, 1). However, we want to be consistent with torch.nn.Linear and
                # initialize the weights with unif(-sqrt(k), sqrt(k)), with k = 1 / fan_in
                if self.tp_rescale:
                    sqrt_k = 1 / slices_fan_in[slice_idx] ** 0.5
                else:
                    sqrt_k = 1.
                weight.data.uniform_(-sqrt_k, sqrt_k)
                self.slices_sqrt_k[slice_idx] = (self.irreps_out_slices[slice_idx], sqrt_k)

            # Initialize the biases
            for (out_slice_idx, out_slice, out_bias) in zip(self.biases_slice_idx, self.biases_slices, self.biases):
                sqrt_k = 1 / slices_fan_in[out_slice_idx] ** 0.5
                out_bias.uniform_(-sqrt_k, sqrt_k)

    def forward_tp_rescale_bias(self, data_in1, data_in2=None) -> torch.Tensor:
        if data_in2 is None:
            data_in2 = torch.ones_like(data_in1[:, 0:1])

        data_out = self.tp(data_in1, data_in2)
        # Apply the rescaling corrections
        if self.tp_rescale:
            for (out_slice, out_sqrt_k) in self.slices_sqrt_k.values():
                data_out[:, out_slice] /= out_sqrt_k
        # Add the biases
        for (_, out_slice, out_bias) in zip(self.biases_slice_idx, self.biases_slices, self.biases):
            data_out[:, out_slice] += out_bias
        return data_out

    def forward(self, data_in1, data_in2=None) -> torch.Tensor:
        # Apply the tensor product, the rescaling and the bias
        data_out = self.forward_tp_rescale_bias(data_in1, data_in2)
        return data_out
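
A minimal usage sketch of O3TensorProduct (the irreps and batch size are illustrative assumptions):

irreps = Irreps("8x0e + 4x1o")
layer = O3TensorProduct(irreps, irreps)  # irreps_in2 defaults to "1x0e"
x = irreps.randn(10, -1)  # batch of 10 feature vectors
out = layer(x)  # data_in2 defaults to all-ones scalar attributes
assert out.shape == (10, irreps.dim)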
Example #19
def main():
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps-in1",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-in2",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-out",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("-w", type=int, default=10)
    parser.add_argument("-n", type=int, default=3)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    if args.cuda:
        # Workaround for CUDA driver issues
        # See https://github.com/pytorch/pytorch/issues/60158#issuecomment-866294291
        with torch.profiler.profile() as _:
            pass

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1)
    irreps_in2 = Irreps(args.irreps_in2)
    irreps_out = Irreps(args.irreps_out)
    tp = FullyConnectedTensorProduct(irreps_in1,
                                     irreps_in2,
                                     irreps_out,
                                     _specialized_code=args.specialized_code,
                                     _optimize_einsums=args.opt_ein)
    tp = tp.to(device=device)

    inputs = [(irreps_in1.randn(args.batch, -1).to(device=device),
               irreps_in2.randn(args.batch, -1).to(device=device))
              for _ in range(1 + args.w + args.n)]
    if args.backward:
        for tmp in inputs:
            for t in tmp:
                t.requires_grad_(True)
    inputs = iter(inputs)

    # compile
    if args.jit:
        print("JITing...")
        tp = compile(tp)

    print("starting...")

    called_num = [0]

    def trace_handler(p):
        print(p.key_averages().table(sort_by="self_cuda_time_total",
                                     row_limit=-1))
        p.export_chrome_trace("test_trace_" + str(called_num[0]) + ".json")
        called_num[0] += 1

    with torch.profiler.profile(activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA
    ],
                                schedule=torch.profiler.schedule(
                                    wait=1, warmup=args.w, active=args.n),
                                on_trace_ready=trace_handler) as p:
        for _ in range(1 + args.w + args.n):
            out = tp(*next(inputs))
            if args.backward:
                # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
                out.tanh().sum().backward()
            p.step()
Example #20
    def __init__(self,
                 muls=(256, 16, 0),
                 lmax=1,
                 num_layers=3,
                 cutoff=10.0,
                 rad_gaussians=50,
                 rad_hs=(128, 128),
                 num_neighbors=20,
                 num_atoms=20,
                 mean=None,
                 std=None,
                 scale=None,
                 atomref=None):
        super().__init__()

        self.cutoff = cutoff
        self.mean = mean
        self.std = std
        self.scale = scale
        self.num_neighbors = num_neighbors
        self.num_atoms = num_atoms
        self.rad_gaussians = rad_gaussians
        self.cutoff = cutoff

        self.radial = FullyConnectedNet((rad_gaussians, ) + rad_hs,
                                        swish,
                                        variance_in=1 / rad_gaussians,
                                        out_act=True)
        self.irreps_sh = o3.Irreps.spherical_harmonics(lmax)  # spherical harmonics representation
        # self.irreps_edge = o3.Irreps([(25, l, (-1)**l) for l in range(lmax + 1)])
        self.irreps_edge = self.irreps_sh

        # self.mul = TensorProduct(
        #     [(25, "0e", 1.0)],
        #     [(1, ir, 1.0) for _, ir in self.irreps_sh],
        #     [(25, ir, 1.0) for _, ir in self.irreps_sh],
        #     [
        #         (0, l, l, "uvu", False, 1.0)
        #         for l in range(lmax + 1)
        #     ]
        # )
        irreps = o3.Irreps([(muls[0], (0, 1)), (muls[1], (1, -1)),
                            (muls[2], (2, 1))])
        self.mul_node = FullyConnectedTensorProduct([(5, "0e")],
                                                    self.irreps_sh, irreps)

        modules = []
        for _ in range(num_layers):
            act = make_gated_block(irreps, muls, self.irreps_sh)
            conv = Conv(irreps, act.irreps_in, self.irreps_edge, rad_hs[-1])
            irreps = act.irreps_out.simplify()

            modules += [torch.nn.ModuleList([conv, act])]

        self.layers = torch.nn.ModuleList(modules)

        self.irreps_out = o3.Irreps("0e + 0o")
        self.layers.append(
            Conv(irreps, self.irreps_out, self.irreps_edge, rad_hs[-1]))

        self.register_buffer('atomref', atomref)
Example #21
def test_empty_irreps():
    tp = FullyConnectedTensorProduct('0e + 1e', Irreps([]), '0e + 1e')
    out = tp(torch.randn(1, 2, 4), torch.randn(2, 1, 0))
    assert out.shape == (2, 2, 4)