def ker_fun(kernels): """Compute kernels.""" var1, nngp, var2, ntk, is_gaussian, _ = kernels if nngp.ndim == 2: return kernels var1 = np.mean(var1, axis=(1, 2)) var2 = var2 if var2 is None else np.mean(var2, axis=(1, 2)) if nngp.ndim == 4: nngp = np.mean(nngp, axis=(2, 3)) if _is_array(ntk): ntk = np.mean(ntk, axis=(2, 3)) return Kernel(var1, nngp, var2, ntk, is_gaussian, True) if nngp.ndim == 6: def trace(x): count = x.shape[2] * x.shape[4] y = np.trace(x, axis1=4, axis2=5) z = np.trace(y, axis1=2, axis2=3) return z / count nngp = trace(nngp) if _is_array(ntk): ntk = trace(ntk) return Kernel(var1, nngp, var2, ntk, is_gaussian, True) raise ValueError('`nngp` array must be 2d or 6d.')
def _transform_kernels_ab_relu(kernels, a, b, do_backprop, do_stabilize):
  """Compute new kernels after an `ABRelu` layer.

  See https://arxiv.org/pdf/1711.09090.pdf for the leaky ReLU derivation.
  """
  var1, nngp, var2, ntk, _, is_height_width = kernels

  if do_stabilize:
    factor = np.max([np.max(np.abs(nngp)), 1e-12])
    nngp /= factor
    var1 /= factor
    if var2 is not None:
      var2 /= factor

  prod = _get_var_prod(var1, nngp, var2)
  cosines = nngp / np.sqrt(prod)
  angles = _arccos(cosines, do_backprop)

  dot_sigma = (a**2 + b**2 - (a - b)**2 * angles / np.pi) / 2
  if ntk is not None:
    ntk *= dot_sigma

  nngp = ((a - b)**2 * _sqrt(prod - nngp**2, do_backprop) / (2 * np.pi) +
          dot_sigma * nngp)
  if do_stabilize:
    nngp *= factor

  var1 *= (a**2 + b**2) / 2
  if var2 is not None:
    var2 *= (a**2 + b**2) / 2

  if do_stabilize:
    var1 *= factor
    if var2 is not None:
      var2 *= factor

  return Kernel(var1, nngp, var2, ntk, a == b, is_height_width)
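# Illustrative Monte Carlo sanity check (not library code) of the `ABRelu`
# covariance map above: for phi(x) = a * min(x, 0) + b * max(x, 0) and (u, v)
# jointly Gaussian with covariance [[var1, nngp], [nngp, var2]], the sample
# mean of phi(u) * phi(v) should match the closed form used in
# `_transform_kernels_ab_relu`. The slopes and moments below are arbitrary
# assumptions chosen for the demo; `onp` is plain NumPy.
def _ab_relu_mc_check(a=-0.3, b=1.0, var1=1.0, var2=2.0, nngp=0.7):
  import numpy as onp
  cov = onp.array([[var1, nngp], [nngp, var2]])
  u, v = onp.random.multivariate_normal([0., 0.], cov, size=10**6).T
  phi = lambda x: a * onp.minimum(x, 0.) + b * onp.maximum(x, 0.)
  mc = onp.mean(phi(u) * phi(v))

  prod = var1 * var2
  angle = onp.arccos(nngp / onp.sqrt(prod))
  dot_sigma = (a**2 + b**2 - (a - b)**2 * angle / onp.pi) / 2
  closed_form = ((a - b)**2 * onp.sqrt(prod - nngp**2) / (2 * onp.pi) +
                 dot_sigma * nngp)
  assert onp.isclose(mc, closed_form, atol=1e-2), (mc, closed_form)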
def parallel_fn_kernel(kernel, *args, **kwargs):
  n1 = kernel.cov1.shape[0]

  _device_count = device_count

  n1_per_device, ragged = divmod(n1, device_count)
  if n1_per_device and ragged:
    raise ValueError(
        'Dataset size ({}) must be divisible by the number of '
        'physical devices ({}).'.format(n1, device_count))
  elif not n1_per_device:
    # Fewer examples than devices: use one example per device.
    _device_count = ragged
    n1_per_device = 1

  kernel_dict = kernel._asdict()

  cov2 = kernel_dict['cov2']
  cov2_is_none = cov2 is None
  if cov2 is None:
    cov2 = kernel_dict['cov1']

  kernel_dict['cov2'] = np.broadcast_to(cov2, (_device_count,) + cov2.shape)
  kernel_dict['x1_is_x2'] = np.broadcast_to(
      kernel_dict['x1_is_x2'],
      (_device_count,) + kernel_dict['x1_is_x2'].shape)

  for k, v in kernel_dict.items():
    if k in ('nngp', 'ntk', 'cov1'):
      kernel_dict[k] = np.reshape(
          v, (_device_count, n1_per_device) + v.shape[1:])
    if k in ('shape1',):
      kernel_dict[k] = (n1_per_device,) + v[1:]

  kernel = kernel_fn(Kernel(**kernel_dict), *args, **kwargs)
  if cov2_is_none:
    kernel = kernel._replace(cov2=None)
  return _flatten_kernel(kernel, cov2_is_none, True)
def ker_fun(kernels): """Kernel transformation.""" var1, nngp, var2, ntk, is_gaussian, is_height_width = kernels if not is_height_width: window_shape_nngp = window_shape[::-1] strides_nngp = strides[::-1] else: window_shape_nngp = window_shape strides_nngp = strides nngp = _average_pool_nngp_6d(nngp, window_shape_nngp, strides_nngp, padding) ntk = _average_pool_nngp_6d(ntk, window_shape_nngp, strides_nngp, padding) if var2 is None: var1 = _diagonal_nngp_6d(nngp) else: # TODO(romann) warnings.warn( 'Pooling for different inputs `x1` and `x2` is not ' 'implemented and will only work if there are no ' 'nonlinearities in the network anywhere after the pooling ' 'layer. `var1` and `var2` will have wrong values. ' 'This will be fixed soon.') return Kernel(var1, nngp, var2, ntk, is_gaussian, is_height_width)
def ker_fun(kernels):
  is_gaussian = all(ker.is_gaussian for ker in kernels)
  if not is_gaussian:
    raise NotImplementedError(
        '`FanInSum` layer is only implemented for the case where all input '
        'layers are guaranteed to be mean-zero Gaussian, i.e. having all '
        '`is_gaussian` set to `True`.')

  # If kernels have different height/width orders, transpose the minority.
  n_kernels = len(kernels)
  n_height_width = sum(ker.is_height_width for ker in kernels)

  if n_height_width == n_kernels:
    is_height_width = True

  elif n_height_width >= n_kernels / 2:
    is_height_width = True
    for i in range(n_kernels):
      if not kernels[i].is_height_width:
        kernels[i] = _flip_height_width(kernels[i])

  else:
    is_height_width = False
    for i in range(n_kernels):
      if kernels[i].is_height_width:
        kernels[i] = _flip_height_width(kernels[i])

  kers = tuple(None if all(ker[i] is None for ker in kernels) else
               sum(ker[i] for ker in kernels) for i in range(4))
  return Kernel(*(kers + (is_gaussian, is_height_width)))
def _move_kernel_to_cpu(kernel):
  """Moves data in a kernel from an accelerator to the CPU."""
  if isinstance(kernel, Kernel):
    return Kernel(device_get(kernel.var1),
                  device_get(kernel.nngp),
                  device_get(kernel.var2),
                  device_get(kernel.ntk),
                  kernel.is_gaussian,
                  kernel.is_height_width)
  else:
    return device_get(kernel)
def _f(_x_or_kernel_np, *_args_np):
  # Merge Kernel.
  if is_input_kernel:
    _x_or_kernel_np = {**_x_or_kernel_np, **x_or_kernel_other}
    _x_or_kernel_np = Kernel(**_x_or_kernel_np)
  # Merge args.
  _args_np = {i: _arg_np for i, _arg_np in zip(args_np_idxs, _args_np)}
  _args = {**_args_np, **args_other}
  _args = tuple(v for k, v in sorted(_args.items()))
  return f(_x_or_kernel_np, *_args, **kwargs)
def new_ker_fun(x1_or_kernel, x2=None, compute_nngp=True, compute_ntk=True):
  """Returns the `Kernel` resulting from applying `ker_fun` to given inputs.

  Inputs can be either a pair of `np.ndarray`s, or a `Kernel`. If `n_samples`
  is positive, `ker_fun` is estimated by Monte Carlo sampling of random
  networks defined by `(init_fun, apply_fun)`.

  Args:
    x1_or_kernel: either a `np.ndarray` with shape
      `[batch_size_1] + input_shape`, or a `Kernel`.
    x2: an optional `np.ndarray` with shape `[batch_size_2] + input_shape`.
      `None` means `x2 == x1` or that `x1_or_kernel` is a `Kernel`.
    compute_nngp: a boolean, `True` to compute the NNGP kernel.
    compute_ntk: a boolean, `True` to compute the NTK kernel.

  Returns:
    A `Kernel`.
  """
  if (isinstance(x1_or_kernel, Kernel) or
      (isinstance(x1_or_kernel, list) and
       all(isinstance(k, Kernel) for k in x1_or_kernel))):
    kernel = x1_or_kernel

  elif isinstance(x1_or_kernel, np.ndarray):
    if x2 is None or isinstance(x2, np.ndarray):
      if not compute_nngp:
        if compute_ntk:
          raise ValueError('NNGP has to be computed to compute NTK. Please '
                           'set `compute_nngp=True`.')
        else:
          return Kernel(None, None, None, None, None, None)

      use_pooling = getattr(ker_fun, _USE_POOLING, True)
      kernel = _inputs_to_kernel(x1_or_kernel, x2, use_pooling, compute_ntk)
    else:
      raise TypeError('`x2` to a kernel propagation function should be '
                      '`None` or a `np.ndarray`, got %s.' % type(x2))

  else:
    raise TypeError('Inputs to a kernel propagation function should be '
                    'a `Kernel`, a `list` of `Kernel`s, or a (tuple of) '
                    '`np.ndarray`(s), got %s.' % type(x1_or_kernel))

  return ker_fun(kernel)
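# Hypothetical usage sketch (illustrative only; the input shapes below are
# assumptions, not part of the library): `new_ker_fun` accepts either raw
# inputs or an already-computed `Kernel`.
def _new_ker_fun_usage_example():
  x1 = np.ones((10, 32, 32, 3))
  x2 = np.ones((5, 32, 32, 3))
  # From arrays: builds an input `Kernel` first, then propagates it.
  k = new_ker_fun(x1, x2, compute_nngp=True, compute_ntk=True)
  # A `Kernel` (or list of `Kernel`s) is passed straight through to `ker_fun`.
  return new_ker_fun(k)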
def ker_fun(kernels): """Compute the transformed kernels after a conv layer.""" # var1: batch_1 * height * width # var2: batch_2 * height * width # nngp, ntk: batch_1 * batch_2 * height * height * width * width (pooling) # or batch_1 * batch_2 * height * width (flattening) var1, nngp, var2, ntk, _, is_height_width = kernels if nngp.ndim == 4: def conv_var(x): x = _conv_var_3d(x, filter_shape, strides, padding) x = _affine(x, W_std, b_std) return x def conv_nngp(x): if _is_array(x): x = _conv_nngp_4d(x, filter_shape, strides, padding) x = _affine(x, W_std, b_std) return x elif nngp.ndim == 6: if not is_height_width: filter_shape_nngp = filter_shape[::-1] strides_nngp = strides[::-1] else: filter_shape_nngp = filter_shape strides_nngp = strides def conv_var(x): x = _conv_var_3d(x, filter_shape_nngp, strides_nngp, padding) if x is not None: x = np.transpose(x, (0, 2, 1)) x = _affine(x, W_std, b_std) return x def conv_nngp(x): if _is_array(x): x = _conv_nngp_6d_double_conv(x, filter_shape_nngp, strides_nngp, padding) x = _affine(x, W_std, b_std) return x is_height_width = not is_height_width else: raise ValueError('`nngp` array must be either 4d or 6d, got %d.' % nngp.ndim) var1 = conv_var(var1) var2 = conv_var(var2) nngp = conv_nngp(nngp) ntk = conv_nngp(ntk) + nngp - b_std**2 if ntk is not None else ntk return Kernel(var1, nngp, var2, ntk, True, is_height_width)
def ker_fun(kernels): """Compute the transformed kernels after a dense layer.""" var1, nngp, var2, ntk, _, _ = kernels def fc(x): return _affine(x, W_std, b_std) var1, nngp, var2, ntk = map(fc, (var1, nngp, var2, ntk)) if ntk is not None: ntk += nngp - b_std**2 return Kernel(var1, nngp, var2, ntk, True, True)
def _flatten_kernel(k, x2_is_none, store_on_device, is_parallel):
  """Flattens a kernel array or a `Kernel` along the batch dimension."""
  if hasattr(k, '_asdict'):
    return k._replace(**_flatten_kernel_dict(
        k._asdict(), x2_is_none, store_on_device, is_parallel))

  if isinstance(k, Kernel):
    return Kernel(**_flatten_kernel_dict(
        k.asdict(), x2_is_none, store_on_device, is_parallel))

  if isinstance(k, np.ndarray):
    return _flatten_batch_dimensions(k)

  raise TypeError(
      ('Expected kernel to be either a namedtuple, Kernel, or `np.ndarray`, '
       'got %s.') % type(k))
def _flatten_kernel(k):
  """Flattens a kernel array or a `Kernel` along the batch dimension."""
  if isinstance(k, Kernel):
    return Kernel(
        _flatten_batch_dimensions(k.var1, discard_axis=1),
        _flatten_batch_dimensions(k.nngp),
        _flatten_batch_dimensions(k.var2, discard_axis=0),
        _flatten_batch_dimensions(k.ntk),
        np.all(k.is_gaussian) if k.is_gaussian is not None else None,
        np.all(k.is_height_width) if k.is_height_width is not None else None)
  elif isinstance(k, np.ndarray):
    return _flatten_batch_dimensions(k)
  else:
    raise TypeError(
        'Expected kernel to be either a `Kernel` or a `np.ndarray`, got %s.'
        % type(k))
def _flatten_kernel(k: _KernelType,
                    x2_is_none: bool,
                    is_parallel: bool) -> _KernelType:
  """Flattens a kernel array or a `Kernel` along the batch dimension."""
  if hasattr(k, '_asdict'):
    return k._replace(**_flatten_kernel_dict(k._asdict(), x2_is_none,
                                             is_parallel))

  elif isinstance(k, Kernel):
    # pytype:disable=attribute-error
    return Kernel(**_flatten_kernel_dict(k.asdict(), x2_is_none, is_parallel))
    # pytype:enable=attribute-error

  elif isinstance(k, np.ndarray):
    return _flatten_batch_dimensions(k)

  raise TypeError(f'Expected kernel to be either a namedtuple, `Kernel`, or '
                  f'`np.ndarray`, got {type(k)}.')
def _transform_kernels_erf(kernels, do_backprop):
  """Compute new kernels after an `Erf` layer."""
  var1, nngp, var2, ntk, _, is_height_width = kernels

  _var1_denom = 1 + 2 * var1
  _var2_denom = None if var2 is None else 1 + 2 * var2

  prod = _get_var_prod(_var1_denom, nngp, _var2_denom)
  dot_sigma = 4 / (np.pi * np.sqrt(prod - 4 * nngp**2))
  if ntk is not None:
    ntk *= dot_sigma

  nngp = _arcsin(2 * nngp / np.sqrt(prod), do_backprop) * 2 / np.pi

  var1 = np.arcsin(2 * var1 / _var1_denom) * 2 / np.pi
  if var2 is not None:
    var2 = np.arcsin(2 * var2 / _var2_denom) * 2 / np.pi

  return Kernel(var1, nngp, var2, ntk, False, is_height_width)
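# Illustrative Monte Carlo sanity check (not library code) of the arcsine
# (`Erf`) kernel above, following Williams (1998), "Computing with infinite
# networks": for (u, v) jointly Gaussian with covariance
# [[var1, nngp], [nngp, var2]],
#   E[erf(u) erf(v)] = (2 / pi) * arcsin(2 nngp / sqrt((1 + 2 var1)(1 + 2 var2))).
# The moments below are arbitrary assumptions chosen for the demo.
def _erf_mc_check(var1=1.0, var2=2.0, nngp=0.7):
  import numpy as onp
  from scipy.special import erf
  cov = onp.array([[var1, nngp], [nngp, var2]])
  u, v = onp.random.multivariate_normal([0., 0.], cov, size=10**6).T
  mc = onp.mean(erf(u) * erf(v))
  closed_form = 2 / onp.pi * onp.arcsin(
      2 * nngp / onp.sqrt((1 + 2 * var1) * (1 + 2 * var2)))
  assert onp.isclose(mc, closed_form, atol=1e-2), (mc, closed_form)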
def ker_fun(kernels):
  var1, nngp, var2, ntk, is_gaussian, _ = kernels

  pixel_axes = tuple(range(2, nngp.ndim))
  nngp = np.mean(nngp, axis=pixel_axes)
  ntk = np.mean(ntk, axis=pixel_axes) if _is_array(ntk) else ntk

  if var2 is None:
    var1 = np.diagonal(nngp)
  else:
    # TODO(romann)
    warnings.warn(
        'Pooling for different inputs `x1` and `x2` is not implemented and '
        'will only work if there are no nonlinearities in the network '
        'anywhere after the pooling layer. `var1` and `var2` will have '
        'wrong values. This will be fixed soon.')

  return Kernel(var1, nngp, var2, ntk, is_gaussian, True)
def _flip_height_width(kernels):
  """Flips the order of spatial axes in the covariance matrices.

  Args:
    kernels: a `Kernel` object.

  Returns:
    A `Kernel` object with `height` and `width` axes order flipped in all
    covariance matrices. For example, if `kernels.nngp` has shape
    `[batch_size_1, batch_size_2, height, height, width, width]`, then
    `_flip_height_width(kernels).nngp` has shape
    `[batch_size_1, batch_size_2, width, width, height, height]`.
  """
  var1, nngp, var2, ntk, is_gaussian, is_height_width = kernels

  var1 = np.transpose(var1, (0, 2, 1))
  var2 = np.transpose(var2, (0, 2, 1)) if var2 is not None else var2
  nngp = np.transpose(nngp, (0, 1, 4, 5, 2, 3))
  ntk = np.transpose(ntk, (0, 1, 4, 5, 2, 3)) if ntk is not None else ntk

  return Kernel(var1, nngp, var2, ntk, is_gaussian, not is_height_width)
def get_sampled_kernel(x1, x2, key, n_samples):
  if x2 is not None:
    assert x1.shape[1:] == x2.shape[1:]

  if key.shape == (2,):
    key = random.split(key, n_samples)
  elif n_samples is not None:
    raise ValueError('Got both `n_samples=%d` and %d RNG keys. Please '
                     'provide either a single key together with `n_samples`, '
                     'or an array of keys.' % (n_samples, key.shape[0]))

  ker_sampled = Kernel(var1=None,
                       nngp=0. if compute_nngp else None,
                       var2=None,
                       ntk=0. if compute_ntk else None,
                       is_gaussian=None,
                       is_height_width=None)
  for subkey in key:
    ker_sampled += ker_fun_sample_once(x1, x2, subkey)

  return ker_sampled / len(key)
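# Hypothetical usage sketch (illustrative only): either pass a single PRNG
# key together with `n_samples`, or a pre-split array of keys with
# `n_samples=None`. `random` here is `jax.random`, as elsewhere in this file.
def _get_sampled_kernel_usage_example(x1, x2):
  # Single key plus a sample count: the key is split internally.
  k_a = get_sampled_kernel(x1, x2, random.PRNGKey(0), n_samples=16)
  # Pre-split keys: `n_samples` must be None, or a `ValueError` is raised.
  keys = random.split(random.PRNGKey(1), 16)
  k_b = get_sampled_kernel(x1, x2, keys, n_samples=None)
  return k_a, k_b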
def ker_fun(x1, x2, params):
  return Kernel(None, nngp_fun(x1, x2, params),
                None, ntk_fun(x1, x2, params),
                None, None)
def _inputs_to_kernel(x1, x2, use_pooling, compute_ntk):
  """Transforms (batches of) inputs to a `Kernel`.

  This is a private method. Docstring and example are for internal reference.

  The kernel contains the empirical covariances between different inputs and
  their entries (pixels) necessary to compute the covariance of the Gaussian
  Process corresponding to an infinite Bayesian or gradient-flow-trained
  neural network.

  The smallest necessary number of covariance entries is tracked. For
  example, all networks are assumed to have i.i.d. weights along the channel
  / feature / logits dimensions, hence covariance between different entries
  along these dimensions is known to be 0 and is not tracked.

  Args:
    x1: a 2D `np.ndarray` of shape `[batch_size_1, n_features]` (dense
      network) or 4D of shape `[batch_size_1, height, width, channels]`
      (conv-nets).
    x2: an optional `np.ndarray` with the same shape as `x1` apart from
      possibly different leading batch size. `None` means `x2 == x1`.
    use_pooling: a boolean, indicating whether pooling will be used somewhere
      in the model. If so, more covariance entries need to be tracked. Is set
      automatically based on the network topology. Specifically, is set to
      `False` if a `serial` or `parallel` network contains a `Flatten` layer
      and no pooling layers (`AvgPool` or `GlobalAvgPool`). Has no effect for
      non-convolutional models.
    compute_ntk: a boolean, `True` to compute both NTK and NNGP kernels,
      `False` to only compute NNGP.

  Example:
    ```python
    >>> x = np.ones((10, 32, 16, 3))
    >>> _inputs_to_kernel(x, None, use_pooling=True,
    >>>                   compute_ntk=True).ntk.shape
    (10, 10, 32, 32, 16, 16)
    >>> _inputs_to_kernel(x, None, use_pooling=False,
    >>>                   compute_ntk=True).ntk.shape
    (10, 10, 32, 16)
    >>> x1 = np.ones((10, 128))
    >>> x2 = np.ones((20, 128))
    >>> _inputs_to_kernel(x1, x2, use_pooling=True,
    >>>                   compute_ntk=False).nngp.shape
    (10, 20)
    >>> _inputs_to_kernel(x1, x2, use_pooling=False,
    >>>                   compute_ntk=False).nngp.shape
    (10, 20)
    >>> _inputs_to_kernel(x1, x2, use_pooling=False,
    >>>                   compute_ntk=False).ntk
    None
    ```

  Returns:
    a `Kernel` object.
  """
  x1 = x1.astype(xla_bridge.canonicalize_dtype(np.float64))
  var1 = _get_variance(x1)

  if x2 is None:
    x2 = x1
    var2 = None
  else:
    if x1.shape[1:] != x2.shape[1:]:
      raise ValueError('`x1` and `x2` are expected to be batches of inputs '
                       'with the same shape (apart from the batch size), '
                       'got %s and %s.' % (str(x1.shape), str(x2.shape)))
    x2 = x2.astype(xla_bridge.canonicalize_dtype(np.float64))
    var2 = _get_variance(x2)

  if use_pooling and x1.ndim == 4:
    x2 = np.expand_dims(x2, -1)
    nngp = np.dot(x1, x2) / x1.shape[-1]
    nngp = np.transpose(np.squeeze(nngp, -1), (0, 3, 1, 4, 2, 5))
  elif x1.ndim == 4 or x1.ndim == 2:
    nngp = _batch_uncentered_covariance(x1, x2)
  else:
    raise ValueError('Inputs must be 2D or 4D `np.ndarray`s of shape '
                     '`[batch_size, n_features]` or '
                     '`[batch_size, height, width, channels]`, '
                     'got %s.' % str(x1.shape))

  ntk = 0. if compute_ntk else None
  is_gaussian = False
  is_height_width = True
  return Kernel(var1, nngp, var2, ntk, is_gaussian, is_height_width)