def _conv_dimension_numbers(input_shape): """Computes the dimension numbers based on the input shape.""" ndim = len(input_shape) lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) out_spec = lhs_spec return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
def _conv_dim_nums(n: int, s2: Tuple[int, ...]): dims = itertools.permutations(range(n)) dns = [] for i in itertools.product(dims, repeat=3): dn = lax.ConvDimensionNumbers(*i) if all(s2[s] != 0 for s in dn[1][2:]): dns += [dn] return random.sample(dns, min(50, len(dns)))
def _conv_dimension_numbers(input_shape): """DEPRECATION WARNING: The `flax.nn` module is Deprecated, use `flax.linen` instead. Learn more and find an upgrade guide at https://github.com/google/flax/blob/main/flax/linen/README.md" Computes the dimension numbers based on the input shape.""" ndim = len(input_shape) lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) out_spec = lhs_spec return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
def to_dimension_numbers(
    num_spatial_dims: int,
    channels_last: bool,
    transpose: bool,
) -> lax.ConvDimensionNumbers:
    """Create a `lax.ConvDimensionNumbers` for the given inputs.

    Args:
      num_spatial_dims: number of spatial axes. Operands have rank
        ``num_spatial_dims + 2``.
      channels_last: if True, the activation feature axis is the last axis.
        Otherwise it is axis 1. The batch axis is always axis 0.
      transpose: if True, swap which of the two trailing kernel axes lax treats
        as the output-feature axis and which as the input-feature axis.

    Returns:
      Dimension numbers whose lhs and out specs share one activation layout.
      The rhs spec always places the kernel's spatial axes first.
    """
    rank = num_spatial_dims + 2

    if channels_last:
        feature_axis, first_spatial = rank - 1, 1
    else:
        feature_axis, first_spatial = 1, 2
    spatial = tuple(range(first_spatial, first_spatial + num_spatial_dims))
    # lax lhs/out spec order: (batch axis, feature axis, *spatial axes).
    image_layout = (0, feature_axis, *spatial)

    # lax rhs spec order: (out-feature axis, in-feature axis, *spatial axes).
    out_feat, in_feat = (rank - 2, rank - 1) if transpose else (rank - 1, rank - 2)
    kernel_layout = (out_feat, in_feat, *range(rank - 2))

    return lax.ConvDimensionNumbers(
        lhs_spec=image_layout, rhs_spec=kernel_layout, out_spec=image_layout
    )