def __init__(
    self,
    in_channels,
    out_channels,
    dim,
    kernel_size,
    hidden_channels=None,
    dilation=1,
    bias=True,
    **kwargs,
):
    """Build the layer's ``mlp1``, ``mlp2`` and ``conv`` sub-networks.

    Args:
        in_channels: Input channel count. Added to ``hidden_channels``
            to size the grouped convolution in ``self.conv``.
        out_channels: Output size of the final linear layer.
        dim: Input size of the first linear layer of ``mlp1``;
            ``dim * kernel_size`` is the input size of ``mlp2``.
        kernel_size: Kernel size used by every ``Conv1d`` and by the
            ``Reshape`` targets.
        hidden_channels: Width of ``mlp1``. Falls back to
            ``in_channels // 4`` when ``None``; must end up positive.
        dilation: Stored on the instance; not used by this constructor.
        bias: Forwarded to the final linear layer of ``self.conv``.
        **kwargs: Stored unchanged as ``self.kwargs``.

    Raises:
        AssertionError: If the resolved ``hidden_channels`` is not
            positive (e.g. the default with ``in_channels < 4``).
    """
    super(XConv, self).__init__()

    self.in_channels = in_channels
    hidden_channels = (
        in_channels // 4 if hidden_channels is None else hidden_channels
    )
    assert hidden_channels > 0
    self.hidden_channels = hidden_channels
    self.out_channels = out_channels
    self.dim = dim
    self.kernel_size = kernel_size
    self.dilation = dilation
    self.kwargs = kwargs

    width = hidden_channels
    k = kernel_size
    k_squared = k ** 2

    # Layers are instantiated in the same order as they are registered so
    # that parameter ordering is unaffected by how the lists are built.
    mlp1_layers = [
        L(dim, width),
        ELU(),
        BN(width),
        L(width, width),
        ELU(),
        BN(width),
        Reshape(-1, k, width),
    ]
    self.mlp1 = S(*mlp1_layers)

    mlp2_layers = [
        L(dim * k, k_squared),
        ELU(),
        BN(k_squared),
        Reshape(-1, k, k),
        Conv1d(k, k_squared, k, groups=k),
        ELU(),
        BN(k_squared),
        Reshape(-1, k, k),
        Conv1d(k, k_squared, k, groups=k),
        BN(k_squared),
        Reshape(-1, k, k),
    ]
    self.mlp2 = S(*mlp2_layers)

    # The final stage works on the input channels plus the mlp1 width.
    combined = in_channels + width
    depth_multiplier = int(ceil(out_channels / combined))
    expanded = combined * depth_multiplier
    conv_layers = [
        Conv1d(combined, expanded, k, groups=combined),
        Reshape(-1, expanded),
        L(expanded, out_channels, bias=bias),
    ]
    self.conv = S(*conv_layers)

    self.reset_parameters()
def __init__(self, in_channels, out_channels, dim, kernel_size,
             hidden_channels=None, dilation=1, bias=True,
             BiLinear=BiLinear, BiConv1d=BiConv1d, ifFirst=False,
             **kwargs):
    """Build the layer's ``mlp1``, ``mlp2`` and ``conv`` sub-networks.

    Args:
        in_channels: Input channel count. Added to ``hidden_channels``
            to size the grouped convolution in ``self.conv``.
        out_channels: Output size of the final linear layer.
        dim: Input size of the first linear layer of ``mlp1``;
            ``dim * kernel_size`` is the input size of ``mlp2``.
        kernel_size: Kernel size used by every convolution and by the
            ``Reshape`` targets.
        hidden_channels: Width of ``mlp1``. Falls back to
            ``in_channels // 4`` when ``None``; must end up positive.
        dilation: Stored on the instance; not used by this constructor.
        bias: Forwarded to the final linear layer of ``self.conv``.
        BiLinear: Callable used to create the linear layers; defaults
            to the module-level ``BiLinear``.
        BiConv1d: Callable used to create the 1-D convolutions;
            defaults to the module-level ``BiConv1d``.
        ifFirst: When true, the first linear layer of ``mlp1`` and of
            ``mlp2`` is created with ``Lin`` instead of ``BiLinear``.
        **kwargs: Stored unchanged as ``self.kwargs``.

    Raises:
        ImportError: If ``knn_graph`` is ``None``.
        AssertionError: If the resolved ``hidden_channels`` is not
            positive (e.g. the default with ``in_channels < 4``).
    """
    super(BiXConv, self).__init__()

    if knn_graph is None:
        # Name this class in the message (it previously said `XConv`).
        raise ImportError('`BiXConv` requires `torch-cluster`.')

    self.in_channels = in_channels
    if hidden_channels is None:
        hidden_channels = in_channels // 4
    assert hidden_channels > 0, (
        '`hidden_channels` must be positive, got {}'.format(hidden_channels)
    )
    self.hidden_channels = hidden_channels
    self.out_channels = out_channels
    self.dim = dim
    self.kernel_size = kernel_size
    self.dilation = dilation
    self.kwargs = kwargs

    C_in, C_delta, C_out = in_channels, hidden_channels, out_channels
    D, K = dim, kernel_size

    # Only the very first linear layer of each MLP depends on `ifFirst`.
    Lin1 = Lin if ifFirst else BiLinear

    self.mlp1 = S(
        Lin1(dim, C_delta),
        Hardtanh(),
        BN(C_delta),
        BiLinear(C_delta, C_delta),
        Hardtanh(),
        BN(C_delta),
        Reshape(-1, K, C_delta),
    )

    self.mlp2 = S(
        Lin1(D * K, K**2),
        Hardtanh(),
        BN(K**2),
        Reshape(-1, K, K),
        BiConv1d(K, K**2, K, groups=K),
        Hardtanh(),
        BN(K**2),
        Reshape(-1, K, K),
        BiConv1d(K, K**2, K, groups=K),
        BN(K**2),
        Reshape(-1, K, K),
    )

    # The final stage works on the input channels plus the mlp1 width.
    C_in = C_in + C_delta
    depth_multiplier = int(ceil(C_out / C_in))
    self.conv = S(
        BiConv1d(C_in, C_in * depth_multiplier, K, groups=C_in),
        Reshape(-1, C_in * depth_multiplier),
        BiLinear(C_in * depth_multiplier, C_out, bias=bias),
    )

    self.reset_parameters()
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    dim: int,
    kernel_size: int,
    hidden_channels: Optional[int] = None,
    dilation: int = 1,
    bias: bool = True,
    num_workers: int = 1,
):
    """Build the layer's ``mlp1``, ``mlp2`` and ``conv`` sub-networks.

    Args:
        in_channels: Input channel count. Added to ``hidden_channels``
            to size the grouped convolution in ``self.conv``.
        out_channels: Output size of the final linear layer.
        dim: Input size of the first linear layer of ``mlp1``;
            ``dim * kernel_size`` is the input size of ``mlp2``.
        kernel_size: Kernel size used by every ``Conv1d`` and by the
            ``Reshape`` targets.
        hidden_channels: Width of ``mlp1``. Falls back to
            ``in_channels // 4`` when ``None``; must end up positive.
        dilation: Stored on the instance; not used by this constructor.
        bias: Forwarded to the final linear layer of ``self.conv``.
        num_workers: Stored on the instance; not used by this
            constructor.

    Raises:
        ImportError: If ``knn_graph`` is ``None``.
        AssertionError: If the resolved ``hidden_channels`` is not
            positive (e.g. the default with ``in_channels < 4``).
    """
    super(XConv, self).__init__()

    if knn_graph is None:
        raise ImportError('`XConv` requires `torch-cluster`.')

    self.in_channels = in_channels
    hidden_channels = (
        in_channels // 4 if hidden_channels is None else hidden_channels
    )
    assert hidden_channels > 0
    self.hidden_channels = hidden_channels
    self.out_channels = out_channels
    self.dim = dim
    self.kernel_size = kernel_size
    self.dilation = dilation
    self.num_workers = num_workers

    delta = hidden_channels
    kernel = kernel_size
    kernel_sq = kernel ** 2

    self.mlp1 = S(
        L(dim, delta),
        ELU(),
        BN(delta),
        L(delta, delta),
        ELU(),
        BN(delta),
        Reshape(-1, kernel, delta),
    )

    self.mlp2 = S(
        L(dim * kernel, kernel_sq),
        ELU(),
        BN(kernel_sq),
        Reshape(-1, kernel, kernel),
        Conv1d(kernel, kernel_sq, kernel, groups=kernel),
        ELU(),
        BN(kernel_sq),
        Reshape(-1, kernel, kernel),
        Conv1d(kernel, kernel_sq, kernel, groups=kernel),
        BN(kernel_sq),
        Reshape(-1, kernel, kernel),
    )

    # The final stage works on the input channels plus the mlp1 width.
    total_in = in_channels + delta
    depth_multiplier = int(ceil(out_channels / total_in))
    widened = total_in * depth_multiplier
    self.conv = S(
        Conv1d(total_in, widened, kernel, groups=total_in),
        Reshape(-1, widened),
        L(widened, out_channels, bias=bias),
    )

    self.reset_parameters()