Example #1
0
    def __init__(self, node_in_irreps, node_hidden_irreps, node_out_irreps, attr_irreps, update_pos=False,
                 recurrent=True, infer_edges=False, edge_weight=False):
        """Set up the message, update and (optional) edge-inference networks of one SEGNN layer.

        Args:
            node_in_irreps: irreps of the node features entering the layer.
            node_hidden_irreps: irreps used inside the message and update networks.
            node_out_irreps: irreps produced by the final update layer.
            attr_irreps: irreps of the steering attribute given to every tensor product.
            update_pos: stored; when set, the position-update layers are only created as
                ``None`` placeholders (positions are not updated yet).
            recurrent: stored flag; not used in this constructor.
            infer_edges: if set, build ``inf_net_1`` (hidden irreps -> one scalar) followed by
                ``inf_net_2`` (a sigmoid).
            edge_weight: stored flag; not used in this constructor.
        """
        super(SEGNN, self).__init__(node_dim=-2, aggr="add")

        # Configuration flags (read later by the rest of the layer).
        self.update_pos, self.recurrent = update_pos, recurrent
        self.infer_edges, self.edge_weight = infer_edges, edge_weight

        one_scalar = Irreps("1x0e")

        # Message network: two copies of the node irreps plus one extra scalar
        # (presumably x_i, x_j and an edge scalar such as a distance -- confirm in forward).
        message_in = (node_in_irreps + node_in_irreps + one_scalar).simplify()
        self.message_layer_1 = O3TensorProductSwishGate(message_in, node_hidden_irreps, attr_irreps)
        self.message_layer_2 = O3TensorProductSwishGate(node_hidden_irreps, node_hidden_irreps, attr_irreps)

        # Update network: node irreps concatenated with hidden-width messages
        # (presumably the aggregated messages -- confirm in forward).
        update_in = (node_in_irreps + node_hidden_irreps).simplify()
        self.update_layer_1 = O3TensorProductSwishGate(update_in, node_hidden_irreps, attr_irreps)
        self.update_layer_2 = O3TensorProduct(node_hidden_irreps, node_out_irreps, attr_irreps)

        # Position update network: placeholders only.
        if self.update_pos:  # TODO: currently not updated...
            self.pos_update_layer_1 = None  # O3TensorProductSwishGate
            self.pos_update_layer_2 = None  # O3TensorProduct

        # Optional edge inference head: hidden irreps -> one scalar -> sigmoid.
        if self.infer_edges:
            self.inf_net_1 = O3TensorProduct(node_hidden_irreps, one_scalar, attr_irreps)
            self.inf_net_2 = nn.Sigmoid()
Example #2
0
def test_input_weights_python():
    """Call FullyConnectedTensorProduct eagerly with externally supplied per-path weight lists.

    Runs once with unshared weights (each weight tensor carries a leading batch axis) and once
    with shared weights (no batch axis). The batch size is drawn with ``random.randint(1, 3)``.
    """
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")

    for shared in (False, True):
        tp = FullyConnectedTensorProduct(irreps_in1,
                                         irreps_in2,
                                         irreps_out,
                                         internal_weights=False,
                                         shared_weights=shared)
        batch = random.randint(1, 3)
        x1 = irreps_in1.randn(batch, -1)
        x2 = irreps_in2.randn(batch, -1)
        # One tensor per weighted instruction; only unshared weights get the batch axis.
        weights = [
            torch.randn(ins.path_shape if shared else (batch, ) + ins.path_shape)
            for ins in tp.instructions
            if ins.has_weight
        ]
        tp(x1, x2, weights)
Example #3
0
def test_equivariant():
    """A compiled (script) ``Linear`` module must still be equivariant."""
    irreps_in = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    compiled = compile(Linear(irreps_in, irreps_out))
    # Pass the irreps explicitly: inferring them from a script module is not reliable.
    assert_equivariant(compiled, irreps_in=irreps_in, irreps_out=irreps_out)
Example #4
0
def test_input_weights_jit():
    """Check external-weight handling of FullyConnectedTensorProduct, eager vs. auto-jitted.

    With ``internal_weights=False`` the module (and the object returned by
    ``assert_auto_jitable``, presumably a traced/scripted copy) must refuse calls without
    weights or with mis-shaped weights. It must give matching results on valid input and
    broadcast batch dimensions consistently. Covers both unshared and shared weights.
    """
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    # - shared_weights = False -
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=False)
    traced = assert_auto_jitable(m)
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    # Unshared weights: one flat weight vector per batch element.
    w = torch.randn(2, m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, w[0])  # it should reject insufficient weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))

    # Confirm that weird batch dimensions give the same results:
    # x1 (2, 1, 4), x2 (2, 3, 1) and w (3, 4) broadcast to a (2, 3, 4) batch, and must match
    # the same computation on explicitly expanded inputs flattened to 24 rows.
    for f in (m, traced):
        x1 = irreps_in1.randn(2, 1, 4, -1)
        x2 = irreps_in2.randn(2, 3, 1, -1)
        w = torch.randn(3, 4, f.weight_numel)
        assert torch.allclose(
            f(x1, x2, w).reshape(24, -1),
            f(
                x1.expand(2, 3, 4, -1).reshape(24, -1),
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)))
        assert torch.allclose(
            f.right(x2, w).reshape(24, -1),
            f.right(
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)).reshape(24, -1))

    # - shared_weights = True -
    # NOTE: x1 and x2 used below are the tensors left over from the last loop iteration
    # (batch shapes (2, 1, 4) and (2, 3, 1)), not the (2, -1) inputs from the first section.
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=True)
    traced = assert_auto_jitable(m)
    # Shared weights: a single flat weight vector with no batch axis.
    w = torch.randn(m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, torch.randn(
            2, m.weight_numel))  # it should reject too many weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))
Example #5
0
def test_gate():
    """Check the ``_Sortcut`` helper and a ``Gate`` built on odd (``0o``) scalars."""
    irreps_scalars = Irreps("16x0o")
    act_scalars = [torch.tanh]
    irreps_gates = Irreps("32x0o")
    act_gates = [torch.tanh]
    irreps_gated = Irreps("16x1e+16x1o")

    sortcut = _Sortcut(irreps_scalars, irreps_gates)
    assert_auto_jitable(sortcut)

    gate = Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)
    # Same checks, same order: equivariance, jit-compatibility, output normalization.
    for check in (assert_equivariant, assert_auto_jitable, assert_normalized):
        check(gate)
Example #6
0
    def __init__(self, irreps_in, irreps_out, irreps_rel_pos, irreps_hidden, dim=3, update_pos=False,
                 recurrent=False):
        """Set up a SEGNN layer whose linear maps are steered by ``irreps_rel_pos`` attributes.

        Args:
            irreps_in: irreps of the node features entering the layer.
            irreps_out: irreps produced by the final update layer.
            irreps_rel_pos: irreps of the steering attribute given to every steered layer.
            irreps_hidden: irreps of the messages and of the update network's hidden layer.
            dim: output size of the (placeholder) position network ``pos_net``.
            update_pos: if set, also build ``pos_net`` (positions are not updated yet, see TODO).
            recurrent: stored flag; not used in this constructor.
        """
        super(SEGNN, self).__init__(node_dim=-2, aggr="mean")  # <---- mean aggregation is important for node steering
        self.update_pos = update_pos
        self.dim = dim
        self.recurrent = recurrent
        self.irreps_rel_pos = irreps_rel_pos

        # Each layer within the message net is steered via the rel_pos.
        irreps_message_in = (irreps_in + irreps_in + Irreps("1x0e")).simplify()  # xi + xj + dist
        self.message_layer_1 = O3LinearSwishGate(irreps_message_in, irreps_hidden, irreps_rel_pos)
        # Messages must leave the network as irreps_hidden: the update net below is built for
        # (irreps_in + irreps_hidden) inputs. Emitting irreps_out here produced a dimension
        # mismatch whenever irreps_out != irreps_hidden.
        self.message_layer_2 = O3LinearSwishGate(irreps_hidden, irreps_hidden, irreps_rel_pos)

        # Each layer within the update net is also steered, via a distribution on the sphere obtained
        # by averaging the rel_pos of all neighbours of the node being updated.
        irreps_update_in = (irreps_in + irreps_hidden).simplify()
        self.update_layer_1 = O3LinearSwishGate(irreps_update_in, irreps_hidden, irreps_rel_pos)
        self.update_layer_2 = O3Linear(irreps_hidden, irreps_out, irreps_rel_pos)

        if self.update_pos:  # TODO: currently not updated...
            hidden_features = 128
            self.pos_net = nn.Sequential(nn.Linear(hidden_features, hidden_features),
                                         Swish(),
                                         nn.Linear(hidden_features, dim))
Example #7
0
def balanced_irreps_spec(lmax, vec_dim, sh_type=True):
    """Return the irreps string that :func:`BalancedIrreps` turns into an ``Irreps``.

    Candidate irreps are ``0e`` plus, for each order ``l`` in ``1..lmax``, either the single
    spherical-harmonic irrep ``l`` with parity ``(-1)^l`` (``sh_type=True``) or both ``le`` and
    ``lo`` (``sh_type=False``). With ``K`` candidate irreps, an irrep of dimension ``2l+1``
    gets ``int(vec_dim / (2l+1) / K)`` copies. Whatever dimension is left over is filled with
    extra ``0e`` copies, so the total dimension is exactly ``vec_dim``.

    Args:
        lmax: highest order included.
        vec_dim: target total dimension.
        sh_type: choose spherical-harmonic parities (True) or both parities per order (False).

    Returns:
        A string such as ``"7x0e + 1x1o"``. Multiplicities may be 0.
    """
    if sh_type:
        specs = ["0e"] + ["{}{}".format(l, "e" if l % 2 == 0 else "o") for l in range(1, lmax + 1)]
    else:
        specs = ["0e"]
        for l in range(1, lmax + 1):
            specs += ["{}e".format(l), "{}o".format(l)]

    # Parse the whole order (everything except the trailing parity letter); reading only the
    # first character mis-sized every irrep with l >= 10.
    dims = [2 * int(spec[:-1]) + 1 for spec in specs]
    copies = [int(vec_dim * (1 / dim) / len(dims)) for dim in dims]
    # Pad with trivial irreps until the desired size is reached (never negative: each term
    # above contributes at most vec_dim / len(dims)).
    copies[0] += vec_dim - sum(n * dim for n, dim in zip(copies, dims))
    return " + ".join("{}x{}".format(n, spec) for n, spec in zip(copies, specs))


def BalancedIrreps(lmax, vec_dim, sh_type=True):
    """Build irreps of total dimension ``vec_dim`` spread roughly evenly over orders ``0..lmax``.

    See :func:`balanced_irreps_spec` for the exact allocation rule.
    """
    return Irreps(balanced_irreps_spec(lmax, vec_dim, sh_type))
Example #8
0
    def __init__(self, irreps_in1, irreps_out, irreps_in2=None) -> None:
        """Tensor product whose output passes through a Swish-based gate.

        The first entry of ``irreps_out`` is treated as the scalar channel that goes through
        Swish directly (assumed to be scalar; not checked here). Every remaining irrep is
        multiplied by a sigmoid gate. The underlying tensor product therefore has to produce
        scalars, one ``0e`` gate per remaining irrep copy, and the gated irreps.

        Args:
            irreps_in1: irreps of the first input.
            irreps_out: irreps after gating.
            irreps_in2: irreps of the second input, forwarded to the parent constructor.
        """
        scalars = Irreps(str(irreps_out[0]))
        gated = Irreps(str(irreps_out[1:]))
        # One scalar gate for every irrep copy beyond the leading scalars.
        gates = Irreps("{}x0e".format(irreps_out.num_irreps - scalars.num_irreps))

        # The tensor product outputs scalars, gates and gated irreps together.
        tp_out = (scalars + gates + gated).simplify()
        super(O3TensorProductSwishGate, self).__init__(irreps_in1, tp_out, irreps_in2)

        # Without anything to gate, plain Swish on the scalars is enough.
        self.gate = (
            Gate(scalars, [Swish()], gates, [torch.sigmoid], gated)
            if gated.num_irreps > 0
            else Swish()
        )
Example #9
0
def _get_io_irreps(func, irreps_in=None, irreps_out=None):
    """Preprocess or, if not given, try to infer the I/O irreps for ``func``.

    Missing values are read from ``func.irreps_in`` (or ``[func.irreps_in1, func.irreps_in2]``)
    and ``func.irreps_out``. Both sides are then normalized to lists whose entries are
    ``Irreps`` or one of the special values ``'cartesian_points'`` / ``None``.

    Returns:
        ``(irreps_in, irreps_out)``, each a list.

    Raises:
        ValueError: if a side is not given and cannot be read from ``func``.
    """
    if (irreps_in is None or irreps_out is None) and isinstance(func, torch.jit.ScriptModule):
        warnings.warn(
            "Asking to infer irreps in/out of a compiled TorchScript module. This is unreliable, please provide `irreps_in` and `irreps_out` explicitly."
        )

    if irreps_in is None:
        if hasattr(func, 'irreps_in'):
            irreps_in = func.irreps_in  # gets checked for type later
        elif hasattr(func, 'irreps_in1'):
            irreps_in = [func.irreps_in1, func.irreps_in2]
        else:
            raise ValueError("Cannot infer irreps_in for %r; provide them explicitly" % func)
    if irreps_out is None:
        if hasattr(func, 'irreps_out'):
            irreps_out = func.irreps_out  # gets checked for type later
        else:
            raise ValueError("Cannot infer irreps_out for %r; provide them explicitly" % func)

    # Same normalization rules for both sides; inputs first so warnings keep their order.
    irreps_in = _normalize_io_irreps(func, irreps_in, 'irreps_in')
    irreps_out = _normalize_io_irreps(func, irreps_out, 'irreps_out')
    return irreps_in, irreps_out


def _normalize_io_irreps(func, irreps, name):
    """Wrap one side's irreps spec into a list of ``Irreps`` / special values.

    ``name`` (``'irreps_in'`` or ``'irreps_out'``) only appears in the ambiguity warning.
    Sharing this helper between both sides fixes a copy-paste bug: the ``irreps_out`` branch
    used to test ``irreps_in`` (already a list by then), so the warning for a tuple-valued
    ``irreps_out`` never fired.
    """
    SPECIAL_VALS = ['cartesian_points', None]

    if isinstance(irreps, Irreps) or irreps in SPECIAL_VALS:
        return [irreps]
    if isinstance(irreps, list):
        return [i if i in SPECIAL_VALS else Irreps(i) for i in irreps]
    if isinstance(irreps, tuple) and not isinstance(irreps, Irreps):
        warnings.warn(
            f"Module {func} had {name} of type tuple but not Irreps; ambiguous whether the tuple should be interpreted as a tuple representing a single Irreps or a tuple of objects each to be converted to Irreps. Assuming the former. If the latter, use a list."
        )
    return [Irreps(irreps)]
Example #10
0
def WeightBalancedIrreps(irreps_in1_scalar, irreps_in2, sh=True):
    """Pick an ``irreps_in1`` of order ``irreps_in2.lmax`` with a linear-layer-sized weight budget.

    The multiplier ``n`` grows from 1 until a fully connected tensor product
    ``irreps_in1 x irreps_in2 -> irreps_in1`` has at least as many weights as the scalar
    reference ``irreps_in1_scalar x "1x0e" -> irreps_in1_scalar``.

    Args:
        irreps_in1_scalar: scalar irreps that define the reference weight count.
        irreps_in2: second tensor-product input; its ``lmax`` sets the order of the result.
        sh: if True, candidates are ``n`` sorted, simplified copies of the spherical harmonics up
            to ``lmax``; otherwise ``BalancedIrreps(lmax, n)`` (``n`` is passed as ``vec_dim``).

    Returns:
        The selected ``Irreps``. The choice is also printed.
    """
    lmax = irreps_in2.lmax

    def candidate(copies):
        if sh:
            return (Irreps.spherical_harmonics(lmax) * copies).sort().irreps.simplify()
        return BalancedIrreps(lmax, copies)

    def numel(irreps):
        return FullyConnectedTensorProduct(irreps, irreps_in2, irreps).weight_numel

    n = 1
    irreps_in1 = candidate(n)
    weight_numel1 = numel(irreps_in1)
    target = FullyConnectedTensorProduct(irreps_in1_scalar, Irreps("1x0e"), irreps_in1_scalar).weight_numel
    # Linear search over the multiplier. TODO: somewhat suboptimal implementation...
    while weight_numel1 < target:
        n += 1
        irreps_in1 = candidate(n)
        weight_numel1 = numel(irreps_in1)
    print('Determined irrep type:', irreps_in1)

    return Irreps(irreps_in1)
Example #11
0
def test_equivariance(lmax, res_b, res_a):
    """ToS2Grid -> pointwise exp -> FromS2Grid must be an equivariant map."""
    from_grid = FromS2Grid((res_b, res_a), lmax)
    to_grid = ToS2Grid(lmax, (res_b, res_a))

    def roundtrip(x):
        return from_grid(to_grid(x).exp())

    # assert_equivariant reads the irreps from these attributes.
    roundtrip.irreps_in = roundtrip.irreps_out = Irreps.spherical_harmonics(lmax)
    assert_equivariant(roundtrip)
Example #12
0
def test_specialized_code(normalization, mode, weighted, float_tolerance):
    """The specialized and generic TensorProduct code paths must agree for a given mode."""
    irreps_in1 = Irreps('4x0e + 4x1e + 4x2e')
    irreps_in2 = Irreps('5x0e + 5x1e + 5x2e')
    irreps_out = Irreps('6x0e + 6x1e + 6x2e')

    # Connection modes that tie multiplicities reuse the matching operand's irreps.
    if mode == 'uvu':
        irreps_out = irreps_in1
    elif mode == 'uvv':
        irreps_out = irreps_in2
    elif mode in ('uuu', 'uuw'):
        irreps_in2 = irreps_in1
        if mode == 'uuu':
            irreps_out = irreps_in1
        elif not weighted:
            # When unweighted, uuw is a plain sum over u and requires an output mul of 1
            irreps_out = Irreps([(1, ir) for _, ir in irreps_out])

    paths = [
        (0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0), (1, 1, 1),
        (0, 2, 2), (2, 0, 2), (2, 2, 0), (2, 1, 1),
    ]
    ins = [(i1, i2, i_out, mode, weighted, 1.0) for i1, i2, i_out in paths]

    # Generic module first, then specialized, so weight initialization order is unchanged.
    tp_generic, tp_special = [
        TensorProduct(irreps_in1,
                      irreps_in2,
                      irreps_out,
                      ins,
                      normalization=normalization,
                      _specialized_code=specialized)
        for specialized in (False, True)
    ]
    with torch.no_grad():
        tp_special.weight[:] = tp_generic.weight

    x = irreps_in1.randn(3, -1)
    y = irreps_in2.randn(3, -1)
    assert (tp_generic(x, y) - tp_special(x, y)).abs().max() < float_tolerance
    assert (tp_generic.right(y) - tp_special.right(y)).abs().max() < float_tolerance
Example #13
0
    def __init__(self, irreps_in1, irreps_out, irreps_in2=None, tp_rescale=True) -> None:
        """Fully connected tensor product with a learnable bias on every even-scalar output block.

        Args:
            irreps_in1: irreps of the first input.
            irreps_out: output ``Irreps`` (its entries and ``slices()`` are used here).
            irreps_in2: irreps of the second input. ``None`` means a single ``1x0e`` scalar;
                ``irreps_in2_provided`` records which case applies.
            tp_rescale: stored flag; not used in this constructor.
        """
        super().__init__()

        self.irreps_in1 = irreps_in1
        self.irreps_out = irreps_out
        # Without a second input, fall back to one trivial scalar.
        self.irreps_in2_provided = irreps_in2 is not None
        self.irreps_in2 = irreps_in2 if self.irreps_in2_provided else Irreps("1x0e")
        self.tp_rescale = tp_rescale

        # Build the layers
        self.tp = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_in1,
            irreps_in2=self.irreps_in2,
            irreps_out=self.irreps_out, shared_weights=True, normalization='component')

        # Order and multiplicity of each output block, read from the Irreps entries directly.
        # (Parsing str(irreps_out) only handled single-digit orders and failed on empty irreps.)
        self.irreps_out_orders = [mul_ir.ir.l for mul_ir in irreps_out]
        self.irreps_out_dims = [mul_ir.mul for mul_ir in irreps_out]
        self.irreps_out_slices = irreps_out.slices()

        # Biases only on even scalars (0e): a constant offset on an odd scalar (0o) does not flip
        # sign under inversion and would break parity equivariance.
        self.biases = []
        self.biases_slices = []
        self.biases_slice_idx = []
        for slice_idx, mul_ir in enumerate(irreps_out):
            if mul_ir.ir.l == 0 and mul_ir.ir.p == 1:
                out_bias = torch.nn.Parameter(torch.zeros(mul_ir.mul, dtype=self.tp.weight.dtype))
                self.biases.append(out_bias)
                self.biases_slices.append(self.irreps_out_slices[slice_idx])
                self.biases_slice_idx.append(slice_idx)
        # ParameterList registers the biases with the module.
        self.biases = torch.nn.ParameterList(self.biases)

        # Initialize the correction factors
        self.slices_sqrt_k = {}

        # Initialize similar to the torch.nn.Linear
        self.tensor_product_init()
Example #14
0
def test_weight_view_for_instruction():
    """Zeroing every weight feeding output block 0 must zero exactly that block."""
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)

    # Find all paths to the first output
    ins_indices = [i for i, ins in enumerate(m.instructions) if ins.i_out == 0]
    with torch.no_grad():
        for i in ins_indices:
            m.weight_view_for_instruction(i).zero_()

    out = m(x1, x2)
    # The first output irrep (1x1e) spans several components: check the whole block, not
    # only its first column, and require the remaining outputs to be non-zero somewhere.
    first = irreps_out.slices()[0]
    assert torch.all(out[:, first] == 0.0)
    assert torch.any(out[:, first.stop:] != 0.0)
Example #15
0
def test_weight_views():
    """Zeroing every view from ``weight_views`` must zero the tensor-product output."""
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    batch = 3
    x1 = irreps_in1.randn(batch, -1)
    x2 = irreps_in2.randn(batch, -1)

    def zero_views(tp, *weights):
        with torch.no_grad():
            for view in tp.weight_views(*weights):
                view.zero_()

    # Shared weights: views into the module's own weight.
    tp = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    zero_views(tp)
    assert torch.all(tp(x1, x2) == 0.0)

    # Unshared weights: views into the external weight tensor passed in.
    tp = FullyConnectedTensorProduct(irreps_in1,
                                     irreps_in2,
                                     irreps_out,
                                     shared_weights=False)
    weights = torch.randn(batch, tp.weight_numel)
    zero_views(tp, weights)
    assert torch.all(tp(x1, x2, weights) == 0.0)
Example #16
0
    def __init__(
        self,
        num_atoms,  # not used
        bond_feat_dim,  # not used
        num_targets,  # not used
        in_features=9,
        out_features=1,
        hidden_features=256,
        N=7,
        dim=3,
        lmax_h=2,
        lmax_pos=2,
        update_pos=False,
        recurrent=True,
        regress_forces=False,
        use_pbc=True,
        otf_graph=False
    ):
        """Build the SEGNN model: embedding layers, ``N`` SEGNN layers and a scalar output head.

        Args:
            num_atoms, bond_feat_dim, num_targets: accepted but not used.
            in_features: number of scalar (``0e``) input features per node.
            out_features: number of scalar (``0e``) outputs.
            hidden_features: width of the scalar head; also the scalar weight budget passed to
                ``WeightBalancedIrreps`` to size the steerable hidden irreps.
            N: number of stacked SEGNN layers.
            dim: stored; not otherwise used in this constructor.
            lmax_h: stored; the hidden irreps are currently chosen by ``WeightBalancedIrreps``
                instead (the ``BalancedIrreps(self.lmax_h, ...)`` line is commented out).
            lmax_pos: maximum order of the spherical-harmonic attribute irreps.
            update_pos, recurrent: forwarded to every SEGNN layer.
            regress_forces, use_pbc, otf_graph: stored flags; not used in this constructor.
        """

        super(SEGNNModel, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.N = N
        self.regress_forces = regress_forces
        self.otf_graph = otf_graph
        self.use_pbc = use_pbc
        self.update_pos = update_pos
        self.recurrent = recurrent
        self.dim = dim
        self.lmax_h = lmax_h
        self.lmax_pos = lmax_pos

        # Irreps for the node features
        node_in_irreps_scalar = Irreps("{0}x0e".format(self.in_features))  # This is the type of the input
        #node_hidden_irreps = BalancedIrreps(self.lmax_h, self.hidden_features)  # This is the type on the hidden reps
        node_hidden_irreps_scalar = Irreps("{0}x0e".format(self.hidden_features))  # For the output layers
        node_out_irreps_scalar = Irreps("{0}x0e".format(self.out_features))  # This is the type on the output

        # Irreps for the edge and node attributes
        attr_irreps = Irreps.spherical_harmonics(self.lmax_pos)
        self.attr_irreps = attr_irreps

        # Steerable hidden irreps whose weight count matches a hidden_features-wide scalar layer.
        # The third argument is WeightBalancedIrreps' ``sh`` flag.
        node_hidden_irreps = WeightBalancedIrreps(node_hidden_irreps_scalar, attr_irreps, False)  # True: copies of sh

        # Network for computing the node attributes
        self.node_attribute_net = NodeAttributeNetwork()

        # The embedding layer (acts point-wise, no orientation information so only use trivial/scalar irreps)
        self.embedding_layer_1 = O3TensorProductSwishGate(node_in_irreps_scalar,  # in
                                                          node_hidden_irreps,  # out
                                                          attr_irreps)  # steerable attribute
        self.embedding_layer_2 = O3TensorProductSwishGate(node_hidden_irreps,  # in
                                                          node_hidden_irreps,  # out
                                                          attr_irreps)  # steerable attribute
        self.embedding_layer_3 = O3TensorProduct(node_hidden_irreps,  # in
                                                 node_hidden_irreps,  # out
                                                 attr_irreps)  # steerable attribute

        # The main layers (collected in a plain list, then wrapped in nn.ModuleList so they are registered)
        self.layers = []
        for i in range(self.N):
            self.layers.append(SEGNN(node_hidden_irreps,  # in
                                     node_hidden_irreps,  # hidden
                                     node_hidden_irreps,  # out
                                     attr_irreps,  # steerable attribute
                                     update_pos=self.update_pos, recurrent=self.recurrent))
        self.layers = nn.ModuleList(self.layers)

        # The output network (again via point-wise operation via scalar irreps)
        # Only the first head layer is steered; the others get no attribute irreps.
        self.head_pre_pool_layer_1 = O3TensorProductSwishGate(node_hidden_irreps,  # in
                                                              node_hidden_irreps_scalar,  # out
                                                              attr_irreps)  # steerable attribute
        self.head_pre_pool_layer_2 = O3TensorProduct(node_hidden_irreps_scalar,  # in
                                                     node_hidden_irreps_scalar)  # out
        self.head_post_pool_layer_1 = O3TensorProductSwishGate(node_hidden_irreps_scalar,  # in
                                                               node_hidden_irreps_scalar)  # out
        self.head_post_pool_layer_2 = O3TensorProduct(node_hidden_irreps_scalar,  # in
                                                      node_out_irreps_scalar)  # out


        # read atom map: a 101 x 9 table filled row by row from CONTINUOUS_EMBEDDINGS
        # (rows presumably indexed by atomic number -- confirm against CONTINUOUS_EMBEDDINGS)
        atom_map = torch.zeros(101, 9)
        for i in range(101):
            atom_map[i] = torch.tensor(CONTINUOUS_EMBEDDINGS[i])

        # normalize along each dimension
        # Row 0 is overwritten with NaN, excluded from the min/max below, and stays NaN afterwards.
        atom_map[0] = np.nan
        # x == x is False only for NaN, so this keeps the rows whose first column is not NaN.
        atom_map_notnan = atom_map[atom_map[:, 0] == atom_map[:, 0]]
        atom_map_min = torch.min(atom_map_notnan, dim=0)[0]
        atom_map_max = torch.max(atom_map_notnan, dim=0)[0]
        atom_map_gap = atom_map_max - atom_map_min
        # squash to [0,1] (per-column min-max scaling)
        # NOTE(review): a constant column gives atom_map_gap == 0 and divides by zero.
        atom_map = (atom_map - atom_map_min.view(1, -1)) / atom_map_gap.view(1, -1)
        # Frozen (requires_grad=False) Parameter: registered on the module, not trained.
        self.atom_map = torch.nn.Parameter(atom_map, requires_grad=False)

        # read atom radii
        atom_radii = torch.zeros(101)
        for i in range(101):
            atom_radii[i] = ATOMIC_RADII[i]
        # Scaled by 1/100 (presumably a unit conversion, e.g. pm -> Angstrom -- TODO confirm units of ATOMIC_RADII).
        atom_radii = atom_radii / 100
        self.atom_radii = nn.Parameter(atom_radii, requires_grad=False)
Example #17
0
def codegen_tensor_product(
    irreps_in1: o3.Irreps,
    in1_var: List[float],
    irreps_in2: o3.Irreps,
    in2_var: List[float],
    irreps_out: o3.Irreps,
    out_var: List[float],
    instructions: List[Instruction],
    normalization: str = 'component',
    shared_weights: bool = False,
    specialized_code: bool = True,
    optimize_einsums: bool = True,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    graph_out = fx.Graph()
    graph_right = fx.Graph()

    # = Function definitions =
    x1s_out = fx.Proxy(graph_out.placeholder('x1', torch.Tensor))
    x2s_out = fx.Proxy(graph_out.placeholder('x2', torch.Tensor))
    ws_out = fx.Proxy(graph_out.placeholder('w', torch.Tensor))

    x2s_right = fx.Proxy(graph_right.placeholder('x2', torch.Tensor))
    ws_right = fx.Proxy(graph_right.placeholder('w', torch.Tensor))

    empty_out = fx.Proxy(
        graph_out.call_function(torch.empty, ((), ), dict(device='cpu')))
    empty_right = fx.Proxy(
        graph_right.call_function(torch.empty, ((), ), dict(device='cpu')))
    if shared_weights:
        size_out = torch.broadcast_tensors(
            empty_out.expand(x1s_out.shape[:-1]),
            empty_out.expand(x2s_out.shape[:-1]))[0].shape
        size_right = x2s_right.shape[:-1]
    else:
        size_out = torch.broadcast_tensors(
            empty_out.expand(x1s_out.shape[:-1]),
            empty_out.expand(x2s_out.shape[:-1]),
            empty_out.expand(ws_out.shape[:-1]))[0].shape
        size_right = torch.broadcast_tensors(
            empty_right.expand(x2s_right.shape[:-1]),
            empty_right.expand(ws_right.shape[:-1]))[0].shape

    # = Short-circut for zero dimensional =
    # We produce no code for empty instructions
    instructions = [ins for ins in instructions if 0 not in ins.path_shape]

    if len(instructions) == 0:
        out_out = x1s_out.new_zeros(size_out + (irreps_out.dim, ))
        out_right = x2s_right.new_zeros(size_right + (
            irreps_in1.dim,
            irreps_out.dim,
        ))

        graph_out.output(out_out.node, torch.Tensor)
        graph_right.output(out_right.node, torch.Tensor)
        # Short circut
        return (fx.GraphModule({}, graph_out, "tp_forward"),
                fx.GraphModule({}, graph_right, "tp_right"))

    # = Broadcast inputs =
    if shared_weights:
        x1s_out, x2s_out = x1s_out.broadcast_to(
            size_out + (-1, )), x2s_out.broadcast_to(size_out + (-1, ))
    else:
        x1s_out, x2s_out, ws_out = x1s_out.broadcast_to(
            size_out + (-1, )), x2s_out.broadcast_to(
                size_out + (-1, )), ws_out.broadcast_to(size_out + (-1, ))
        x2s_right, ws_right = x2s_right.broadcast_to(
            size_right + (-1, )), ws_right.broadcast_to(size_right + (-1, ))

    outsize_out = size_out + (irreps_out.dim, )
    outsize_right = size_right + (
        irreps_in1.dim,
        irreps_out.dim,
    )

    x1s_out = x1s_out.reshape(-1, irreps_in1.dim)
    x2s_out = x2s_out.reshape(-1, irreps_in2.dim)
    x2s_right = x2s_right.reshape(-1, irreps_in2.dim)

    batch_out = x1s_out.shape[0]
    batch_right = x2s_right.shape[0]

    # = Determine number of weights and reshape weights ==
    weight_numel = sum(
        prod(ins.path_shape) for ins in instructions if ins.has_weight)
    if weight_numel > 0:
        ws_out = ws_out.reshape(-1, weight_numel)
        ws_right = ws_right.reshape(-1, weight_numel)
    del weight_numel

    # = book-keeping for wigners =
    w3j = []
    w3j_dict_out = dict()
    w3j_dict_right = dict()

    # = extract individual input irreps =
    # If only one input irrep, can avoid creating a view
    if len(irreps_in1) == 1:
        x1_list_out = [
            x1s_out.reshape(batch_out, irreps_in1[0].mul, irreps_in1[0].ir.dim)
        ]
    else:
        x1_list_out = [
            x1s_out[:, i].reshape(batch_out, mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(irreps_in1.slices(), irreps_in1)
        ]

    x2_list_out = []
    x2_list_right = []
    # If only one input irrep, can avoid creating a view
    if len(irreps_in2) == 1:
        x2_list_out.append(
            x2s_out.reshape(batch_out, irreps_in2[0].mul,
                            irreps_in2[0].ir.dim))
        x2_list_right.append(
            x2s_right.reshape(batch_right, irreps_in2[0].mul,
                              irreps_in2[0].ir.dim))
    else:
        for i, mul_ir in zip(irreps_in2.slices(), irreps_in2):
            x2_list_out.append(x2s_out[:, i].reshape(batch_out, mul_ir.mul,
                                                     mul_ir.ir.dim))
            x2_list_right.append(x2s_right[:,
                                           i].reshape(batch_right, mul_ir.mul,
                                                      mul_ir.ir.dim))

    # The einsum string index to prepend to the weights if the weights are not shared and have a batch dimension
    z = '' if shared_weights else 'z'

    # Cache of input irrep pairs whose outer products (xx) have already been computed
    xx_dict = dict()

    # Current index in the flat weight tensor
    flat_weight_index = 0

    out_list_out = []
    out_list_right = []

    for ins in instructions:
        mul_ir_in1 = irreps_in1[ins.i_in1]
        mul_ir_in2 = irreps_in2[ins.i_in2]
        mul_ir_out = irreps_out[ins.i_out]

        assert mul_ir_in1.ir.p * mul_ir_in2.ir.p == mul_ir_out.ir.p
        assert abs(mul_ir_in1.ir.l - mul_ir_in2.ir.l
                   ) <= mul_ir_out.ir.l <= mul_ir_in1.ir.l + mul_ir_in2.ir.l

        if mul_ir_in1.dim == 0 or mul_ir_in2.dim == 0 or mul_ir_out.dim == 0:
            continue

        alpha = ins.path_weight * out_var[ins.i_out] / sum(
            in1_var[i.i_in1] * in2_var[i.i_in2]
            for i in instructions if i.i_out == ins.i_out)

        # Open the profiler block
        name = f"{mul_ir_in1} x {mul_ir_in2} = {mul_ir_out} {ins.connection_mode} {ins.has_weight}"
        handle_out = graph_out.call_function(
            torch.ops.profiler._record_function_enter, (name, ))
        handle_right = graph_right.call_function(
            torch.ops.profiler._record_function_enter, (name, ))

        x1_out = x1_list_out[ins.i_in1]
        x2_out = x2_list_out[ins.i_in2]
        x2_right = x2_list_right[ins.i_in2]

        e1_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in1.mul, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))
        e2_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in2.mul, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))
        i1_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in1.ir.dim, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))

        assert ins.connection_mode in [
            'uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'
        ]

        alpha = sqrt(
            alpha / {
                'uvw': (mul_ir_in1.mul * mul_ir_in2.mul),
                'uvu': mul_ir_in2.mul,
                'uvv': mul_ir_in1.mul,
                'uuw': mul_ir_in1.mul,
                'uuu': 1,
                'uvuv': 1,
            }[ins.connection_mode])

        if ins.has_weight:
            # Extract the weight from the flattened weight tensor
            w_out = ws_out[:, flat_weight_index:flat_weight_index +
                           prod(ins.path_shape)].reshape((
                               () if shared_weights else (-1, )) +
                                                         tuple(ins.path_shape))
            w_right = ws_right[:, flat_weight_index:flat_weight_index +
                               prod(ins.path_shape)].reshape(
                                   (() if shared_weights else (-1, )) +
                                   tuple(ins.path_shape))
            flat_weight_index += prod(ins.path_shape)

        # Construct the general xx in case this instruction isn't specialized
        # If this isn't used, the dead code will get removed
        key = (ins.i_in1, ins.i_in2, ins.connection_mode[:2])
        if key not in xx_dict:
            if ins.connection_mode[:2] == 'uv':
                xx_dict[key] = torch.einsum('zui,zvj->zuvij', x1_out, x2_out)
            if ins.connection_mode[:2] == 'uu':
                xx_dict[key] = torch.einsum('zui,zuj->zuij', x1_out, x2_out)
        xx = xx_dict[key]

        # Create a proxy & request for the relevant wigner w3j
        # If not used (because of specialized code), will get removed later.
        key = (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l)
        if key not in w3j:
            w3j_dict_out[key] = fx.Proxy(
                graph_out.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}"))
            w3j_dict_right[key] = fx.Proxy(
                graph_right.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}"))
            w3j.append(key)
        w3j_out = w3j_dict_out[key]
        w3j_right = w3j_dict_right[key]

        exp = {'component': 1, 'norm': -1}[normalization]

        if ins.connection_mode == 'uvw':
            assert ins.has_weight
            if specialized_code and key == (0, 0, 0):
                ein_out = torch.einsum(
                    f"{z}uvw,zu,zv->zw", w_out,
                    x1_out.reshape(batch_out, mul_ir_in1.dim),
                    x2_out.reshape(batch_out, mul_ir_in2.dim))
                ein_right = torch.einsum(
                    f"{z}uvw,zv->zuw", w_right,
                    x2_right.reshape(batch_right, mul_ir_in2.dim))
            elif specialized_code and mul_ir_in1.ir.l == 0:
                ein_out = torch.einsum(
                    f"{z}uvw,zu,zvj->zwj", w_out,
                    x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                ein_right = torch.einsum(f"{z}uvw,zvi->zuwi", w_right,
                                         x2_right)
            elif specialized_code and mul_ir_in2.ir.l == 0:
                ein_out = torch.einsum(
                    f"{z}uvw,zui,zv->zwi", w_out, x1_out,
                    x2_out.reshape(batch_out, mul_ir_in2.dim))
                ein_right = torch.einsum(
                    f"{z}uvw,ij,zv->zuiwj", w_right, i1_right,
                    x2_right.reshape(batch_right, mul_ir_in2.dim))
            elif specialized_code and mul_ir_out.ir.l == 0:
                ein_out = torch.einsum(f"{z}uvw,zui,zvi->zw", w_out, x1_out,
                                       x2_out) / sqrt(mul_ir_in1.ir.dim)**exp
                ein_right = torch.einsum(f"{z}uvw,zvi->zuiw", w_right,
                                         x2_right) / sqrt(
                                             mul_ir_in1.ir.dim)**exp
            else:
                ein_out = torch.einsum(f"{z}uvw,ijk,zuvij->zwk", w_out,
                                       w3j_out, xx)
                ein_right = torch.einsum(f"{z}uvw,ijk,zvj->zuiwk", w_right,
                                         w3j_right, x2_right)
        if ins.connection_mode == 'uvu':
            assert mul_ir_in1.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zv->zu", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,uw,zv->zuw", w_right, e1_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zvj->zuj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}uv,uw,zvi->zuwi", w_right,
                                             e1_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zui,zv->zui", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,ij,uw,zv->zuiwj", w_right, i1_right, e1_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uv,zui,zvi->zu", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}uv,uw,zvi->zuiw", w_right,
                                             e1_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
            else:
                # not so useful operation because v is summed
                ein_out = torch.einsum("ijk,zuvij->zuk", w3j_out, xx)
                ein_right = torch.einsum("ijk,uw,zvj->zuiwk", w3j_right,
                                         e1_right, x2_right)
        if ins.connection_mode == 'uvv':
            assert mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zv->zv", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,vw,zv->zuw", w_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zvj->zvj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}uv,vw,zvi->zuwi", w_right,
                                             e2_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zui,zv->zvi", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,ij,vw,zv->zuiwj", w_right, i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uv,zui,zvi->zv", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}uv,vw,zvi->zuiw", w_right,
                                             e2_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zvk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}uv,ijk,zvj->zuivk", w_right,
                                             w3j_right, x2_right)
            else:
                # not so useful operation because u is summed
                # only specialize out for this path
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        "zu,zv->zv", x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        "zu,zvj->zvj", x1_out.reshape(batch_out,
                                                      mul_ir_in1.dim), x2_out)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        "zui,zv->zvi", x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum("zui,zvi->zv", x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum("ijk,zuvij->zvk", w3j_out, xx)
                s2ones = fx.Proxy(
                    graph_right.call_function(
                        torch.ones, (mul_ir_in1.mul, ),
                        dict(device=x2_right.device.node,
                             dtype=x2_right.dtype.node)))
                ein_right = torch.einsum("u,ijk,zvj->zuivk", s2ones, w3j_right,
                                         x2_right)
        if ins.connection_mode == 'uuw':
            assert mul_ir_in1.mul == mul_ir_in2.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uw,zu,zu->zw", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uw,zu,zuj->zwj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uw,zui,zu->zwi", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uw,zui,zui->zw", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uw,ijk,zuij->zwk", w_out,
                                           w3j_out, xx)
                # TODO: specialize right()
                ein_right = torch.einsum(f"{z}uw,ijk,zuj->zuiwk", w_right,
                                         w3j_right, x2_right)
            else:
                # equivalent to tp(x, y, 'uuu').sum('u')
                assert mul_ir_out.mul == 1
                ein_out = torch.einsum("ijk,zuij->zk", w3j_out, xx)
                ein_right = torch.einsum("ijk,zuj->zuik", w3j_right, x2_right)
        if ins.connection_mode == 'uuu':
            assert mul_ir_in1.mul == mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}u,zu,zu->zu", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}u,uw,zu->zuw", w_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and key == (
                        1, 1, 1) and normalization == "component":
                    ein_out = torch.einsum(f"{z}u,zui->zui", w_out,
                                           torch.cross(x1_out, x2_out,
                                                       dim=2)) / sqrt(2)
                    # For cross product, use the general case right()
                    ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}u,zu,zuj->zuj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}u,uw,zui->zuwi", w_right,
                                             e2_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}u,zui,zu->zui", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}u,ij,uw,zu->zuiwj", w_right, i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}u,zui,zui->zu", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}u,uw,zui->zuiw", w_right,
                                             e2_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}u,ijk,zuij->zuk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
            else:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        "zu,zu->zu", x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        "uw,zu->zuw", e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and key == (
                        1, 1, 1) and normalization == "component":
                    ein_out = torch.cross(x1_out, x2_out,
                                          dim=2) * (1.0 / sqrt(2))
                    # For cross product, use the general case right()
                    ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right,
                                             e1_right, x2_right)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        "zu,zuj->zuj", x1_out.reshape(batch_out,
                                                      mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum("uw,zui->zuwi", e2_right,
                                             x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        "zui,zu->zui", x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        "ij,uw,zu->zuiwj", i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum("zui,zui->zu", x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum("uw,zui->zuiw", e2_right,
                                             x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum("ijk,zuij->zuk", w3j_out, xx)
                    ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right,
                                             e1_right, x2_right)
        if ins.connection_mode == 'uvuv':
            assert mul_ir_in1.mul * mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                # TODO implement specialized code
                ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuvk", w_out,
                                       w3j_out, xx)
                ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwvk", w_right,
                                         w3j_right, e1_right, x2_right)
            else:
                # TODO implement specialized code
                ein_out = torch.einsum("ijk,zuvij->zuvk", w3j_out, xx)
                ein_right = torch.einsum("ijk,uw,zvj->zuiwvk", w3j_right,
                                         e1_right, x2_right)

        ein_out = alpha * ein_out
        ein_right = alpha * ein_right

        out_list_out += [ein_out.reshape(batch_out, mul_ir_out.dim)]
        out_list_right += [
            ein_right.reshape(batch_right, mul_ir_in1.dim, mul_ir_out.dim)
        ]

        # Close the profiler block
        graph_out.call_function(torch.ops.profiler._record_function_exit,
                                (handle_out, ))
        graph_right.call_function(torch.ops.profiler._record_function_exit,
                                  (handle_right, ))

        # Remove unused w3js:
        if len(w3j_out.node.users) == 0 and len(w3j_right.node.users) == 0:
            del w3j[-1]
            # The w3j nodes are reshapes, so we have to remove them from the graph
            # Although they are dead code, they try to reshape to dimensions that don't exist
            # (since the corresponding w3js are not in w3j)
            # so they screw up the shape propagation, even though they would be removed later as dead code by TorchScript.
            graph_out.erase_node(w3j_dict_out.pop(key).node)
            graph_right.erase_node(w3j_dict_right.pop(key).node)

    # = Return the result =
    out_out = [
        _sum_tensors([
            out for ins, out in zip(instructions, out_list_out)
            if ins.i_out == i_out
        ],
                     shape=(batch_out, mul_ir_out.dim),
                     like=x1s_out)
        for i_out, mul_ir_out in enumerate(irreps_out) if mul_ir_out.mul > 0
    ]
    if len(out_out) > 1:
        out_out = torch.cat(out_out, dim=1)
    else:
        # Avoid an unnecessary copy in a size one torch.cat
        out_out = out_out[0]

    out_right = [
        torch.cat([
            _sum_tensors([
                out for ins, out in zip(instructions, out_list_right)
                if (ins.i_in1, ins.i_out) == (i_in1, i_out)
            ],
                         shape=(batch_right, mul_ir_in1.dim, mul_ir_out.dim),
                         like=x2s_right)
            for i_out, mul_ir_out in enumerate(irreps_out)
            if mul_ir_out.mul > 0
        ],
                  dim=2) for i_in1, mul_ir_in1 in enumerate(irreps_in1)
        if mul_ir_in1.mul > 0
    ]
    if len(out_right) > 1:
        out_right = torch.cat(out_right, dim=1)
    else:
        out_right = out_right[0]

    out_out = out_out.reshape(outsize_out)
    out_right = out_right.reshape(outsize_right)

    graph_out.output(out_out.node, torch.Tensor)
    graph_right.output(out_right.node, torch.Tensor)

    # check graphs
    graph_out.lint()
    graph_right.lint()

    # Make GraphModules
    wigner_mats = {}
    for l_1, l_2, l_out in w3j:
        wig = o3.wigner_3j(l_1, l_2, l_out)

        if normalization == 'component':
            wig *= (2 * l_out + 1)**0.5
        if normalization == 'norm':
            wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

        wigner_mats[f"_w3j_{l_1}_{l_2}_{l_out}"] = wig

    # By putting the constants in a Module rather than a dict,
    # we force FX to copy them as buffers instead of as attributes.
    #
    # FX seems to have resolved this issue for dicts in 1.9, but we support all the way back to 1.8.0.
    constants_root = torch.nn.Module()
    for wkey, wmat in wigner_mats.items():
        constants_root.register_buffer(wkey, wmat)
    graphmod_out = fx.GraphModule(constants_root,
                                  graph_out,
                                  class_name="tp_forward")
    graphmod_right = fx.GraphModule(constants_root,
                                    graph_right,
                                    class_name="tp_right")

    # == Optimize ==
    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # Note that for our einsums, we can optimize _once_ for _any_ batch dimension
        # and still get the right path for _all_ batch dimensions.
        # This is because our einsums are essentially of the form:
        #    zuvw,ijk,zuvij->zwk    OR     uvw,ijk,zuvij->zwk
        # In the first case, all but one operands have the batch dimension
        #    => The first contraction gains the batch dimension
        #    => All following contractions have batch dimension
        #    => All possible contraction paths have cost that scales linearly in batch size
        #    => The optimal path is the same for all batch sizes
        # For the second case, this logic follows as long as the first contraction is not between the first two operands. Since those two operands do not share any indexes, contracting them first is a rare pathological case. See
        # https://github.com/dgasmith/opt_einsum/issues/158
        # for more details.
        #
        # TODO: consider the impact maximum intermediate result size on this logic
        #         \- this is the `memory_limit` option in opt_einsum
        # TODO: allow user to choose opt_einsum parameters?
        #
        # We use float32 and zeros to save memory and time, since opt_einsum_fx looks only at traced shapes, not values or dtypes.
        batchdim = 4
        example_inputs = (
            torch.zeros((batchdim, irreps_in1.dim)),
            torch.zeros((batchdim, irreps_in2.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                flat_weight_index,
            ),
        )
        graphmod_out = jitable(
            optimize_einsums_full(graphmod_out, example_inputs))
        graphmod_right = jitable(
            optimize_einsums_full(graphmod_right, example_inputs[1:]))

    return graphmod_out, graphmod_right
Example #18
0
def main():
    """Benchmark a tensor product from the command line.

    Builds either an ``ElementwiseTensorProduct`` (``--elementwise``) or a
    ``FullyConnectedTensorProduct`` from the given irreps, moves it to the
    selected device, optionally passes it through ``compile()`` (``--jit``),
    and times ``args.n`` executions of a forward pass (plus a backward pass
    when ``--backward``) with ``torch.utils.benchmark.Timer``.
    """
    import itertools

    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3o")
    parser.add_argument("--irreps-in1", type=str, default=None)
    parser.add_argument("--irreps-in2", type=str, default=None)
    parser.add_argument("--irreps-out", type=str, default=None)
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    parser.add_argument("--elementwise", action='store_true')
    parser.add_argument("-n", type=int, default=1000)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    # Reflect the device actually used in the printed settings.
    args.cuda = device == 'cuda'

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    # Per-operand irreps fall back to the shared --irreps value.
    irreps_in1 = Irreps(args.irreps_in1 if args.irreps_in1 else args.irreps)
    irreps_in2 = Irreps(args.irreps_in2 if args.irreps_in2 else args.irreps)
    irreps_out = Irreps(args.irreps_out if args.irreps_out else args.irreps)

    if args.elementwise:
        tp = ElementwiseTensorProduct(irreps_in1,
                                      irreps_in2,
                                      _specialized_code=args.specialized_code,
                                      _optimize_einsums=args.opt_ein)
        if args.backward:
            print(
                "Elementwise TP has no weights, cannot backward. Setting --backward False."
            )
            args.backward = False
    else:
        tp = FullyConnectedTensorProduct(
            irreps_in1,
            irreps_in2,
            irreps_out,
            _specialized_code=args.specialized_code,
            _optimize_einsums=args.opt_ein)
    tp = tp.to(device=device)
    assert len(tp.instructions) > 0, "Bad irreps, no instructions"
    print(f"Tensor product: {tp}")
    print("Instructions:")
    for ins in tp.instructions:
        print(f"  {ins}")

    # Timer.timeit() performs its own warmup runs before the timed loop
    # (see https://pytorch.org/docs/master/_modules/torch/utils/benchmark/utils/timer.html#Timer.timeit),
    # and those runs also call next(inputs). PyTorch uses max(n // 100, 2)
    # warmup runs there, so reserve at least 2 extra input pairs. With a
    # minimum of 1, -n < 200 would exhaust the iterator and raise
    # StopIteration inside the timed statement.
    warmup = max(int(args.n // 100), 2)

    # cycle() is a safety net in case a PyTorch version uses a different
    # warmup count. For the first n + warmup draws it yields exactly the
    # same sequence as iterating the list directly.
    inputs = itertools.cycle([(irreps_in1.randn(args.batch, -1).to(device=device),
                               irreps_in2.randn(args.batch, -1).to(device=device))
                              for _ in range(args.n + warmup)])

    # compile
    if args.jit:
        tp = compile(tp)

    print("starting...")

    # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
    t = Timer(
        stmt=("tp.zero_grad()\n"
              "out = tp(*next(inputs))\n" +
              ("out.tanh().sum().backward()\n" if args.backward else '')),
        globals={
            'tp': tp,
            'inputs': inputs
        })

    perloop = t.timeit(args.n)

    print()
    print(perloop)
Example #19
0
def test_empty_irreps():
    """Check the output shape when the second operand has no irreps.

    The two inputs have batch shapes (1, 2) and (2, 1); the asserted result
    has batch shape (2, 2) and feature size 4 for the '0e + 1e' output.
    """
    product = FullyConnectedTensorProduct('0e + 1e', Irreps([]), '0e + 1e')
    lhs = torch.randn(1, 2, 4)
    rhs = torch.randn(2, 1, 0)
    expected_shape = (2, 2, 4)
    assert product(lhs, rhs).shape == expected_shape
Example #20
0
    def __init__(
        self,
        num_atoms,  # not used
        bond_feat_dim,  # not used
        num_targets,  # not used
        in_features=9,
        out_features=1,
        hidden_features=256,
        N=7,
        dim=3,
        lmax_h=2,
        lmax_pos=2,
        update_pos=False,
        recurrent=True,
        regress_forces=False,
        use_pbc=True,
        otf_graph=False
    ):
        """Assemble the embedding, the SEGNN layer stack, the heads and the atom map.

        Args:
            num_atoms, bond_feat_dim, num_targets: accepted for interface
                compatibility; not used here.
            in_features: number of scalar (``0e``) input channels.
            out_features: number of scalar (``0e``) output channels.
            hidden_features: width passed to ``BalancedIrreps`` and the number
                of channels of the scalar hidden representation.
            N: total number of SEGNN layers. The stack is one first layer,
                ``N - 2`` intermediate layers and one last layer, so at least
                two layers are always built.
            dim: stored as ``self.dim``; not otherwise used in this method.
            lmax_h: maximum order passed to ``BalancedIrreps`` for the hidden
                irreps.
            lmax_pos: maximum order of ``Irreps.spherical_harmonics`` used for
                relative positions.
            update_pos: forwarded to every SEGNN layer.
            recurrent: forwarded to the intermediate layers only; the first and
                last layers are built with ``recurrent=False``.
            regress_forces, use_pbc, otf_graph: stored as attributes.
        """
        super(SEGNNModel, self).__init__()

        # Hyper-parameters.
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.N = N
        self.regress_forces = regress_forces
        self.otf_graph = otf_graph
        self.use_pbc = use_pbc
        self.update_pos = update_pos
        self.recurrent = recurrent
        self.dim = dim
        self.lmax_h = lmax_h
        self.lmax_pos = lmax_pos

        # Representations used throughout the model.
        self.irreps_in = Irreps(f"{self.in_features}x0e")
        self.irreps_hidden = BalancedIrreps(self.lmax_h, self.hidden_features)
        self.irreps_hidden_scalar = Irreps(f"{self.hidden_features}x0e")
        self.irreps_out = Irreps(f"{self.out_features}x0e")
        self.irreps_rel_pos = Irreps.spherical_harmonics(self.lmax_pos)

        scalar_irreps = self.irreps_hidden_scalar
        hidden_irreps = self.irreps_hidden
        pos_irreps = self.irreps_rel_pos

        # Point-wise embedding: there is no orientation information at this
        # stage, so only scalar irreps are used.
        self.embedding = nn.Sequential(
            O3LinearSwishGate(self.irreps_in, scalar_irreps),
            O3Linear(scalar_irreps, scalar_irreps),
        )

        # Submodules are constructed in the same order as the layer stack so
        # that weight initialisation draws from the RNG in a fixed sequence.
        # NOTE(review): the positional arguments below are
        # (in, hidden, rel_pos, hidden/out). If SEGNN's signature is
        # (node_in_irreps, node_hidden_irreps, node_out_irreps, attr_irreps),
        # the 3rd and 4th arguments look swapped. Confirm against the SEGNN
        # class actually in use.
        first_layer = SEGNN(scalar_irreps, hidden_irreps, pos_irreps, hidden_irreps,
                            update_pos=self.update_pos, recurrent=False)
        middle_layers = [
            SEGNN(hidden_irreps, hidden_irreps, pos_irreps, hidden_irreps,
                  update_pos=self.update_pos, recurrent=self.recurrent)
            for _ in range(self.N - 2)
        ]
        last_layer = SEGNN(hidden_irreps, scalar_irreps, pos_irreps, scalar_irreps,
                           update_pos=self.update_pos, recurrent=False)
        self.layers = nn.ModuleList([first_layer, *middle_layers, last_layer])

        # Output heads act point-wise on scalar irreps, before and after pooling.
        self.head_pre_pool = nn.Sequential(
            O3LinearSwishGate(scalar_irreps, scalar_irreps),
            O3Linear(scalar_irreps, scalar_irreps),
        )
        self.head_post_pool = nn.Sequential(
            O3LinearSwishGate(scalar_irreps, scalar_irreps),
            O3Linear(scalar_irreps, self.irreps_out),
        )

        # Lookup table with one row of 9 continuous features per index 0..100,
        # filled from CONTINUOUS_EMBEDDINGS.
        table = torch.zeros(101, 9)
        for row in range(101):
            table[row] = torch.tensor(CONTINUOUS_EMBEDDINGS[row])

        # Row 0 is set to NaN. Rows whose first feature is NaN are excluded
        # from the column statistics, and row 0 stays NaN after scaling.
        table[0] = np.nan
        valid_rows = table[~torch.isnan(table[:, 0])]
        col_min = valid_rows.min(dim=0).values
        col_max = valid_rows.max(dim=0).values

        # Min-max scale each feature column using the valid rows' range.
        table = (table - col_min.view(1, -1)) / (col_max - col_min).view(1, -1)
        self.atom_map = torch.nn.Parameter(table, requires_grad=False)
Example #21
0
def _codegen_linear(
    irreps_in: o3.Irreps,
    irreps_out: o3.Irreps,
    instructions: List[Instruction],
    biases: List[bool],
    f_in: Optional[int] = None,
    f_out: Optional[int] = None,
    shared_weights: bool = False,
    optimize_einsums: bool = True,
) -> Tuple[fx.GraphModule, int, int]:
    """Generate a ``torch.fx`` module computing a block-wise linear map.

    The generated module takes three positional tensors ``(x, w, b)``: the
    input features, the flattened weights and the flattened biases.

    Args:
        irreps_in: irreps describing the last axis of ``x``.
        irreps_out: irreps describing the last axis of the output.
        instructions: paths ``irreps_in[i_in] -> irreps_out[i_out]``; each
            consumes ``prod(path_shape)`` entries of the flat weight vector,
            in order. Paths whose ``path_shape`` contains a 0 are dropped.
        biases: flags paired with the entries of ``irreps_out``; a ``True``
            flag adds a bias block of ``mul_ir.dim`` entries to that output.
        f_in, f_out: optional extra channel axes placed just before the
            irreps axis of the input / output. NOTE(review): these appear to
            be meant to be given together — confirm with callers.
        shared_weights: if true, weights (and the einsum) carry no batch index.
        optimize_einsums: if true, optimize the traced graph with
            ``optimize_einsums_full`` on zero example inputs and wrap it with
            ``jitable``.

    Returns:
        ``(graph_module, weight_numel, bias_numel)``.
    """
    graph_out = fx.Graph()

    # = Function definitions =
    x = fx.Proxy(graph_out.placeholder('x', torch.Tensor))
    ws = fx.Proxy(graph_out.placeholder('w', torch.Tensor))
    bs = fx.Proxy(graph_out.placeholder('b', torch.Tensor))

    # Optional channel axes inserted before the irreps axis (empty if unused).
    f_in_dims = () if f_in is None else (f_in, )
    f_out_dims = () if f_out is None else (f_out, )

    if f_in is None:
        size = x.shape[:-1]
        outsize = size + (irreps_out.dim, )
    else:
        size = x.shape[:-2]
        outsize = size + (f_out, irreps_out.dim)

    bias_numel = sum(mul_ir.dim
                     for bias, mul_ir in zip(biases, irreps_out) if bias)
    if bias_numel > 0:
        if f_out is None:
            bs = bs.reshape(-1, bias_numel)
        else:
            bs = bs.reshape(-1, f_out, bias_numel)

    # = Short-circuit for nothing to do =
    # Paths with a zero-sized weight produce no code.
    instructions = [ins for ins in instructions if 0 not in ins.path_shape]

    if len(instructions) == 0 and bias_numel == 0:
        out = x.new_zeros(outsize)

        graph_out.output(out.node, torch.Tensor)
        # Nothing to learn: both weight_numel and bias_numel are 0.
        return fx.GraphModule({}, graph_out, "linear_forward"), 0, 0

    # Flatten all leading batch dimensions into a single one.
    if f_in is None:
        x = x.reshape(-1, irreps_in.dim)
    else:
        x = x.reshape(-1, f_in, irreps_in.dim)
    batch_out = x.shape[0]

    # = Biases: one (possibly empty) list of terms per output irrep =
    n_biased = sum(biases)
    out_bias_list = []
    bias_index = 0
    for bias, mul_ir_out in zip(biases, irreps_out):
        if not bias:
            out_bias_list.append([])
            continue
        if n_biased == 1:
            # The only bias block is the whole bias vector; no narrow needed.
            b = bs
        else:
            b = bs.narrow(-1, bias_index, mul_ir_out.dim)
            bias_index += mul_ir_out.dim
        if f_out is None:
            out_bias_list.append([b.expand(batch_out, -1)])
        else:
            out_bias_list.append([b.expand(batch_out, f_out, -1)])

    weight_numel = sum(prod(ins.path_shape) for ins in instructions)
    if weight_numel > 0:
        if f_in is None:
            ws = ws.reshape(-1, weight_numel)
        else:
            ws = ws.reshape(-1, f_in, f_out, weight_numel)

    # = Extract individual input irreps as (batch, [f_in,] mul, ir.dim) =
    if len(irreps_in) == 1:
        x_list = [
            x.reshape(batch_out, *f_in_dims, irreps_in[0].mul,
                      irreps_in[0].ir.dim)
        ]
    else:
        x_list = [
            x.narrow(-1, i.start, mul_ir.dim).reshape(
                batch_out, *f_in_dims, mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(irreps_in.slices(), irreps_in)
        ]

    # Batch index of the weights in the einsums; absent for shared weights.
    z = '' if shared_weights else 'z'

    # Number of (non-empty) paths feeding each output irrep; used to scale
    # each contribution. Computed once instead of rescanning per path.
    n_paths_to_output = {}
    for ins in instructions:
        n_paths_to_output[ins.i_out] = n_paths_to_output.get(ins.i_out, 0) + 1

    flat_weight_index = 0

    # (i_out, tensor) pairs, so each contribution stays attached to its
    # output even when a path is skipped below.
    out_list = []

    for ins in instructions:
        mul_ir_in = irreps_in[ins.i_in]
        mul_ir_out = irreps_out[ins.i_out]
        path_nweight = prod(ins.path_shape)

        # Short-circuit for empty irreps. The path's weights are still
        # counted in weight_numel, so the offset must advance past them.
        if mul_ir_in.dim == 0 or mul_ir_out.dim == 0:
            flat_weight_index += path_nweight
            continue

        # Extract this path's weight from the flattened weight tensor
        if len(instructions) == 1:
            # Avoid an unnecessary view when there is only one weight block
            w = ws
        else:
            w = ws.narrow(-1, flat_weight_index, path_nweight)
        w = w.reshape((() if shared_weights else (-1, )) +
                      (() if f_in is None else (f_in, f_out)) + ins.path_shape)
        flat_weight_index += path_nweight

        if f_in is None:
            ein_out = torch.einsum(f"{z}uw,zui->zwi", w, x_list[ins.i_in])
        else:
            ein_out = torch.einsum(f"{z}xyuw,zxui->zywi", w, x_list[ins.i_in])
        # Scale by 1/sqrt(f_in * input multiplicity * paths into this output).
        alpha = 1.0 / math.sqrt((f_in or 1) * mul_ir_in.mul *
                                n_paths_to_output[ins.i_out])
        ein_out = alpha * ein_out

        out_list.append(
            (ins.i_out,
             ein_out.reshape(batch_out, *f_out_dims, mul_ir_out.dim)))

    # = Return the result =
    out = [
        _sum_tensors(
            [t for i, t in out_list if i == i_out] + out_bias_list[i_out],
            shape=(batch_out, *f_out_dims, mul_ir_out.dim),
            like=x,
        )
        for i_out, mul_ir_out in enumerate(irreps_out)
        if mul_ir_out.mul > 0
    ]
    out = torch.cat(out, dim=-1) if len(out) > 1 else out[0]

    out = out.reshape(outsize)

    graph_out.output(out.node, torch.Tensor)

    # check graphs
    graph_out.lint()

    graphmod_out = fx.GraphModule({}, graph_out, "linear_forward")

    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # See _tensor_product/_codegen.py for notes
        batchdim = 4
        example_inputs = (
            torch.zeros((batchdim, *f_in_dims, irreps_in.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                f_in or 1,
                f_out or 1,
                weight_numel,
            ),
            torch.zeros(
                1 if shared_weights else batchdim,
                f_out or 1,
                bias_numel,
            ),
        )
        graphmod_out = jitable(
            optimize_einsums_full(graphmod_out, example_inputs))

    return graphmod_out, weight_numel, bias_numel
def main():
    """Benchmark and profile a ``FullyConnectedTensorProduct`` from the CLI.

    Builds the tensor product from the ``--irreps-*`` options, then runs it
    (optionally followed by a backward pass) on ``1 + w + n`` pre-generated
    input pairs under ``torch.profiler``. When the profile of the active
    steps is ready, a per-op table is printed and a Chrome trace is exported.
    """
    parser = argparse.ArgumentParser(prog="tensor_product_benchmark")
    # NOTE(review): t_or_f is defined elsewhere; presumably it parses
    # "true"/"false"-style strings into bools — confirm.
    parser.add_argument("--jit", type=t_or_f, default=True)
    parser.add_argument("--irreps-in1",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-in2",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--irreps-out",
                        type=str,
                        default="8x0e + 8x1e + 8x2e + 8x3e")
    parser.add_argument("--cuda", type=t_or_f, default=True)
    parser.add_argument("--backward", type=t_or_f, default=True)
    parser.add_argument("--opt-ein", type=t_or_f, default=True)
    parser.add_argument("--specialized-code", type=t_or_f, default=True)
    # -w: profiler warmup steps; -n: profiled (active) steps.
    parser.add_argument("-w", type=int, default=10)
    parser.add_argument("-n", type=int, default=3)
    parser.add_argument("--batch", type=int, default=10)

    args = parser.parse_args()

    # Use CUDA only if requested AND available; store the effective choice
    # back into args so later checks and the settings printout reflect it.
    device = 'cuda' if (torch.cuda.is_available() and args.cuda) else 'cpu'
    args.cuda = device == 'cuda'

    if args.cuda:
        # Workaround for CUDA driver issues
        # See https://github.com/pytorch/pytorch/issues/60158#issuecomment-866294291
        with torch.profiler.profile() as _:
            pass

    print("======= Benchmark with settings: ======")
    for key, val in vars(args).items():
        print(f"{key:>18} : {val}")
    print("=" * 40)

    irreps_in1 = Irreps(args.irreps_in1)
    irreps_in2 = Irreps(args.irreps_in2)
    irreps_out = Irreps(args.irreps_out)
    tp = FullyConnectedTensorProduct(irreps_in1,
                                     irreps_in2,
                                     irreps_out,
                                     _specialized_code=args.specialized_code,
                                     _optimize_einsums=args.opt_ein)
    tp = tp.to(device=device)

    # One input pair per profiler step (1 wait + w warmup + n active),
    # generated up front — presumably so input creation stays out of the
    # profiled loop.
    inputs = [(irreps_in1.randn(args.batch, -1).to(device=device),
               irreps_in2.randn(args.batch, -1).to(device=device))
              for _ in range(1 + args.w + args.n)]
    if args.backward:
        for tmp in inputs:
            for t in tmp:
                t.requires_grad_(True)
    inputs = iter(inputs)

    # compile
    # NOTE(review): `compile` is presumably a project-level module compiler
    # (the builtin would reject a module) — confirm the import.
    if args.jit:
        print("JITing...")
        tp = compile(tp)

    print("starting...")

    # Mutable counter so the nested handler can number its trace files.
    called_num = [0]

    def trace_handler(p):
        # Print per-op averages sorted by self CUDA time, then export a
        # Chrome trace to test_trace_<k>.json.
        print(p.key_averages().table(sort_by="self_cuda_time_total",
                                     row_limit=-1))
        p.export_chrome_trace("test_trace_" + str(called_num[0]) + ".json")
        called_num[0] += 1

    # Schedule: 1 wait step, args.w warmup steps, args.n recorded steps —
    # the same count as the pre-generated inputs; p.step() advances it.
    with torch.profiler.profile(activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA
    ],
                                schedule=torch.profiler.schedule(
                                    wait=1, warmup=args.w, active=args.n),
                                on_trace_ready=trace_handler) as p:
        for _ in range(1 + args.w + args.n):
            out = tp(*next(inputs))
            if args.backward:
                # tanh() forces it to realize the grad as a full size matrix rather than expanded (stride 0) ones
                out.tanh().sum().backward()
            p.step()