import torch
from e3nn.o3 import Irreps, FullyConnectedTensorProduct


def test_weight_views():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    batchdim = 3
    x1 = irreps_in1.randn(batchdim, -1)
    x2 = irreps_in2.randn(batchdim, -1)

    # shared weights: zeroing every per-instruction weight view must zero the output
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    with torch.no_grad():
        for w in m.weight_views():
            w.zero_()
    assert torch.all(m(x1, x2) == 0.0)

    # unshared weights: views into an externally supplied weight tensor
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out, shared_weights=False)
    weights = torch.randn(batchdim, m.weight_numel)
    with torch.no_grad():
        for w in m.weight_views(weights):
            w.zero_()
    assert torch.all(m(x1, x2, weights) == 0.0)
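# Illustrative sketch (not part of the original test; assumes e3nn is installed):
# each view yielded by weight_views() corresponds to one fully connected ('uvw')
# instruction of the tensor product and has shape (mul_in1, mul_in2, mul_out).
# The helper name is hypothetical, for illustration only.
def inspect_weight_views():
    m = FullyConnectedTensorProduct(Irreps("1e + 2e"), Irreps("1e + 2e"), Irreps("1e + 2e"))
    for w, instr in zip(m.weight_views(), m.instructions):
        print(instr.i_in1, instr.i_in2, instr.i_out, tuple(w.shape))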
class O3TensorProduct(torch.nn.Module):
    def __init__(self, irreps_in1, irreps_out, irreps_in2=None, tp_rescale=True) -> None:
        super().__init__()

        self.irreps_in1 = irreps_in1
        self.irreps_out = irreps_out
        # Init irreps_in2: default to a single scalar ("1x0e") input if none is provided
        if irreps_in2 is None:
            self.irreps_in2_provided = False
            self.irreps_in2 = Irreps("1x0e")
        else:
            self.irreps_in2_provided = True
            self.irreps_in2 = irreps_in2
        self.tp_rescale = tp_rescale

        # Build the layers
        self.tp = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_in1,
            irreps_in2=self.irreps_in2,
            irreps_out=self.irreps_out,
            shared_weights=True,
            normalization='component')

        # For each zeroth-order output irrep we need a bias,
        # so first determine the order and dimension of each output tensor.
        self.irreps_out_orders = [int(irrep_str[-2]) for irrep_str in str(irreps_out).split('+')]
        self.irreps_out_dims = [int(irrep_str.split('x')[0]) for irrep_str in str(irreps_out).split('+')]
        self.irreps_out_slices = irreps_out.slices()

        # Store tuples of slices and corresponding biases in a list
        self.biases = []
        self.biases_slices = []
        self.biases_slice_idx = []
        for slice_idx in range(len(self.irreps_out_orders)):
            if self.irreps_out_orders[slice_idx] == 0:
                out_slice = irreps_out.slices()[slice_idx]
                out_bias = torch.nn.Parameter(
                    torch.zeros(self.irreps_out_dims[slice_idx], dtype=self.tp.weight.dtype))
                self.biases += [out_bias]
                self.biases_slices += [out_slice]
                self.biases_slice_idx += [slice_idx]
        self.biases = torch.nn.ParameterList(self.biases)

        # Initialize the correction factors
        self.slices_sqrt_k = {}

        # Initialize similarly to torch.nn.Linear
        self.tensor_product_init()

    def tensor_product_init(self) -> None:
        with torch.no_grad():
            # Determine fan_in for each slice; an output slice may be updated by several instructions
            slices_fan_in = {}  # fan_in per slice
            for weight, instr in zip(self.tp.weight_views(), self.tp.instructions):
                slice_idx = instr[2]
                mul_1, mul_2, mul_out = weight.shape
                fan_in = mul_1 * mul_2
                slices_fan_in[slice_idx] = (slices_fan_in[slice_idx] + fan_in
                                            if slice_idx in slices_fan_in else fan_in)

            # Do the initialization of the weights in each instruction
            for weight, instr in zip(self.tp.weight_views(), self.tp.instructions):
                # The tensor product in e3nn already normalizes proportional to 1 / sqrt(fan_in), and the
                # weights are by default initialized with unif(-1, 1). However, we want to be consistent
                # with torch.nn.Linear and initialize the weights with unif(-sqrt(k), sqrt(k)),
                # with k = 1 / fan_in.
                slice_idx = instr[2]
                if self.tp_rescale:
                    sqrt_k = 1 / slices_fan_in[slice_idx] ** 0.5
                else:
                    sqrt_k = 1.
                weight.data.uniform_(-sqrt_k, sqrt_k)
                self.slices_sqrt_k[slice_idx] = (self.irreps_out_slices[slice_idx], sqrt_k)

            # Initialize the biases
            for (out_slice_idx, out_slice, out_bias) in zip(self.biases_slice_idx, self.biases_slices, self.biases):
                sqrt_k = 1 / slices_fan_in[out_slice_idx] ** 0.5
                out_bias.uniform_(-sqrt_k, sqrt_k)

    def forward_tp_rescale_bias(self, data_in1, data_in2=None) -> torch.Tensor:
        # Default second input: a single scalar one per batch element
        if data_in2 is None:
            data_in2 = torch.ones_like(data_in1[:, 0:1])

        data_out = self.tp(data_in1, data_in2)
        # Apply the rescaling corrections
        if self.tp_rescale:
            for (out_slice, slice_sqrt_k) in self.slices_sqrt_k.values():
                data_out[:, out_slice] /= slice_sqrt_k
        # Add the biases to the zeroth-order output slices
        for (_, out_slice, bias) in zip(self.biases_slice_idx, self.biases_slices, self.biases):
            data_out[:, out_slice] += bias
        return data_out

    def forward(self, data_in1, data_in2=None) -> torch.Tensor:
        # Apply the tensor product, the rescaling and the bias
        return self.forward_tp_rescale_bias(data_in1, data_in2)
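# Minimal usage sketch (illustrative, not from the original module): apply the
# layer to a batch of random equivariant features. The irreps and batch size
# chosen here are arbitrary examples, and the helper name is hypothetical.
def example_o3_tensor_product():
    irreps_node = Irreps("8x0e + 8x1o")
    layer = O3TensorProduct(irreps_node, irreps_node)  # irreps_in2 defaults to "1x0e"
    x = irreps_node.randn(16, -1)  # batch of 16 flattened feature vectors
    out = layer(x)                 # data_in2 defaults to a column of ones
    assert out.shape == (16, irreps_node.dim)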