def piecewise_constant(boundaries, values):
  boundaries = np.array(boundaries)
  values = np.array(values)
  if not boundaries.ndim == values.ndim == 1:
    raise ValueError("boundaries and values must be sequences")
  if not boundaries.shape[0] == values.shape[0] - 1:
    raise ValueError("boundaries length must be one shorter than values length")

  def schedule(i):
    return values[np.sum(i > boundaries)]

  return schedule
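# Hedged usage sketch (illustration only, not part of the library): assuming
# `np` above is a NumPy-compatible backend, `piecewise_constant` returns a
# step -> value schedule, e.g. a learning rate that drops at fixed boundaries.
def _piecewise_constant_example():
  # Hypothetical helper, never called by library code.
  lr_schedule = piecewise_constant(boundaries=[100, 1000],
                                   values=[1.0, 0.1, 0.01])
  assert float(lr_schedule(0)) == 1.0      # before the first boundary
  assert float(lr_schedule(500)) == 0.1    # between the two boundaries
  assert float(lr_schedule(5000)) == 0.01  # after the last boundary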
def testSize(self):

  def run_test(arr):
    onp_arr = arr.numpy() if isinstance(arr, tf.Tensor) else arr
    self.assertEqual(np_size(arr), onp.size(onp_arr))

  run_test(np.array([1]))
  run_test(np.array([1, 2, 3, 4, 5]))
  run_test(np.ones((2, 3, 2)))
  run_test(np.ones((3, 2)))
  run_test(np.zeros((5, 6, 7)))
  run_test(1)
  run_test(onp.ones((3, 2, 1)))
  run_test(tf.constant(5))
  run_test(tf.constant([1, 1, 1]))
def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
  t = np.array(t) * learning_rate
  t_shape, t_ndim = t.shape, t.ndim
  t = t.reshape((-1, 1))

  rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
  rhs = np.moveaxis(rhs, trace_axes, last_t_axes).reshape((-1,) + rhs_shape)
  shape = t_shape + k_train_train.shape[1::2] + rhs_shape

  if fx_train_0 is not None:
    dfx_train = expm1_fn(rhs, t).reshape(shape)
    dfx_train = np.moveaxis(dfx_train, last_t_axes, trace_axes)
    fx_train_t = fx_train_0 + dfx_train

  if fx_test_0 is not None:
    dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
    dfx_test = np.tensordot(k_test_train, dfx_test, (odd, non_t_axes))
    dfx_test = np.moveaxis(
        dfx_test,
        tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) + last_t_axes,
        tuple(range(t_ndim)) + trace_axes)
    fx_test_t = fx_test_0 + dfx_test

  if fx_train_0 is not None and fx_test_0 is not None:
    return fx_train_t, fx_test_t
  if fx_test_0 is None:
    return fx_train_t
  return fx_test_t
def testGradientDescentMseEnsembleGet(self, train_shape, test_shape, network,
                                      out_logits):
  _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                 train_shape)
  _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)
  predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                    y_train, diag_reg=0.)
  for x in [None, 'x_test']:
    with self.subTest(x=x):
      x = x if x is None else x_test

      out = predictor(None, x, 'ntk', compute_cov=True)
      assert isinstance(out, predict.Gaussian)

      out = predictor(1., x, 'nngp', compute_cov=True)
      assert isinstance(out, predict.Gaussian)

      out = predictor(np.array([0., 1.]), x, ('ntk',), compute_cov=True)
      assert len(out) == 1 and isinstance(out[0], predict.Gaussian)

      out = predictor(2., x, ('ntk', 'nngp'), compute_cov=True)
      assert (len(out) == 2 and isinstance(out[0], predict.Gaussian)
              and isinstance(out[1], predict.Gaussian))

      out2 = predictor(2., x, ('nngp', 'ntk'), compute_cov=True)
      self.assertAllClose(out[0], out2[1])
      self.assertAllClose(out[1], out2[0])
def testNTKMeanCovPrediction(self, train_shape, test_shape, network,
                             out_logits):
  key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
  init_fn, f, kernel_fn = stax.serial(
      stax.Dense(512, W_std=1.2, b_std=0.05),
      stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  reg = 1e-6
  predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                    y_train, diag_reg=reg)
  ts = np.array([1., 5., 10.])

  fx_test_inf, cov_test_inf = predictor(ts, x_test, 'ntk', True)
  self.assertEqual(cov_test_inf.shape[1], x_test.shape[0])
  self.assertGreater(np.min(np.linalg.eigh(cov_test_inf)[0]), -1e-8)

  fx_train_inf, cov_train_inf = predictor(ts, None, 'ntk', True)
  self.assertEqual(cov_train_inf.shape[1], x_train.shape[0])
  self.assertGreater(np.min(np.linalg.eigh(cov_train_inf)[0]), -1e-8)

  _kernel_fn = empirical.empirical_kernel_fn(f)
  kernel_fn = jit(lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params))

  def predict_empirical(key):
    _, params = init_fn(key, train_shape)
    g_dd = kernel_fn(x_train, None, params)
    g_td = kernel_fn(x_test, x_train, params)
    predict_fn = predict.gradient_descent_mse(g_dd, y_train, diag_reg=reg)
    fx_train_0 = f(params, x_train)
    fx_test_0 = f(params, x_test)
    return predict_fn(ts, fx_train_0, fx_test_0, g_td)

  def predict_mc(count, key):
    key = tf_random_split(key, count)
    fx_train, fx_test = vmap(predict_empirical)(key)
    fx_train_mean = np.mean(fx_train, axis=0)
    fx_test_mean = np.mean(fx_test, axis=0)

    fx_train_centered = fx_train - fx_train_mean
    fx_test_centered = fx_test - fx_test_mean
    cov_train = PredictTest._cov_empirical(fx_train_centered)
    cov_test = PredictTest._cov_empirical(fx_test_centered)
    return fx_train_mean, fx_test_mean, cov_train, cov_test

  fx_train_mc, fx_test_mc, cov_train_mc, cov_test_mc = predict_mc(4096, key)

  rtol = 0.05
  self._assertAllClose(fx_train_mc, fx_train_inf, rtol)
  self._assertAllClose(cov_train_mc, cov_train_inf, rtol)
  self._assertAllClose(cov_test_mc, cov_test_inf, rtol)
  self._assertAllClose(fx_test_mc, fx_test_inf, rtol)
def apply_fun(params, inputs, **kwargs):
  inputs = onp.moveaxis(inputs, (batch_dim, channel_dim), (0, dim + 1))
  output = reduce_window(inputs, init_val, reducer, window_shape, strides,
                         padding)
  output = rescale(output, inputs, spec) if rescale else output
  return tfnp.array(output)
def ntk_fn(x1: np.ndarray,
           x2: Optional[np.ndarray],
           params: PyTree,
           keys: Union[PRNGKey,
                       Tuple[PRNGKey, PRNGKey],
                       np.ndarray] = None,
           **apply_fn_kwargs) -> np.ndarray:
  """Computes a single sample of the empirical NTK (implicit differentiation).

  Args:
    x1: first batch of inputs.
    x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
      matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
    params: A `PyTree` of parameters about which we would like to compute the
      neural tangent kernel.
    keys: `None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
      dtype `uint32`. If `keys=None`, then the function `f` is deterministic
      and requires no PRNG key; else if `keys` is a single PRNG key, then `x1`
      and `x2` must be the same and share the same PRNG key; else `x1` and
      `x2` use two different PRNG keys.
    **apply_fn_kwargs: keyword arguments passed to `apply_fn`.

  Returns:
    A single sample of the empirical NTK. The shape of the kernel is "almost"
    `zip(f(x1).shape, f(x2).shape)` except for:
    1) `trace_axes` are absent as they are contracted over.
    2) `diagonal_axes` are present only once.
    All other axes are present twice.
  """
  key1, key2 = _read_keys(keys)
  # TODO(xlc): find a good way to check utils.x1_is_x2(x1, x2) == (key1 == key2)

  f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
  f2 = f1 if x2 is None else _get_f_params(f, x2, key2, **apply_fn_kwargs)

  def delta_vjp_jvp(delta):
    def delta_vjp(delta):
      return vjp(f2, params)[1](delta)
    return _jvp(f1, _tf_to_np((params,)), delta_vjp(delta))[1]

  # Since we are taking the Jacobian of a linear function (which does not
  # depend on its coefficients), it is more efficient to substitute `fx_dummy`
  # for the outputs of the network. `fx_dummy` has the same shape as the
  # output of the network on a single piece of input data.
  fx2_struct = eval_on_shapes(f2)(params)
  fx_dummy = np.ones(fx2_struct.shape, dtype=tf.float32)

  # ntk = jacobian(delta_vjp_jvp)(fx_dummy)
  with tf.GradientTape() as tape:
    tape.watch(fx_dummy.data)
    y = delta_vjp_jvp(fx_dummy.data)
  ntk = np.array(tape.jacobian(y, fx_dummy.data))
  return _index_and_contract(ntk, trace_axes, diagonal_axes)
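# Hedged usage sketch (illustration only): `ntk_fn` above is a closure over
# the network's `apply_fn` `f`, `trace_axes` and `diagonal_axes`, so it is
# obtained from the enclosing factory and then called on data and parameters,
# e.g.
#
#   _, params = init_fn(key, x_train.shape)
#   k_train_train = ntk_fn(x_train, None, params)    # ~ (n_train, n_train)
#   k_test_train = ntk_fn(x_test, x_train, params)   # ~ (n_test, n_train)
#
# `init_fn`, `key`, `x_train` and `x_test` are placeholders for whatever model
# and data the caller has; output axes beyond the batch ones follow the
# `trace_axes` / `diagonal_axes` rules described in the docstring.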
def get_masked_array(x: ArrayOrList,
                     mask_constant: float = None) -> MaskedArray:
  """Return `x` with entries equal to `mask_constant` zeroed-out, and the mask.

  The mask returned is a boolean `np.ndarray` with masked indices having
  `True`.

  Args:
    x: `np.ndarray` to mask. If `x` is a `MaskedArray`, treat it as
      `(masked_x, mask)` and pass it through.
    mask_constant: an optional `float`, the value in inputs to be considered
      as masked (e.g. padding in a batch of sentences). `None` means no
      masking. Can also be `np.nan`, `np.inf` etc.

  Returns:
    A `MaskedArray` of `(masked_x, boolean_mask)`.
  """
  if isinstance(x, list):
    x_array = []
    mask_array = []
    for x_ in x:
      masked_array = get_masked_array(x_, mask_constant)
      x_array.append(masked_array.masked_value)
      mask_array.append(masked_array.mask)
    return MaskedArray(x_array, mask_array)

  if x is None:
    mask = None

  elif isinstance(x, MaskedArray):
    masked_value = x.masked_value
    mask = x.mask
    x = masked_value

  elif isinstance(x, (np.ndarray, onp.ndarray)):
    x = np.asarray(x)
    if mask_constant is None:
      mask = None
    else:
      choice_a = lambda: np.array(tf.math.is_nan(x))
      choice_b = lambda: x == mask_constant
      mask = tf.cond(tf.math.is_nan(mask_constant), choice_a, choice_b)

  else:
    raise TypeError(x, type(x))

  if mask is not None:
    x = np.where(mask, np.zeros((), x.dtype), x)

  return MaskedArray(x, mask)  # pytype: disable=wrong-arg-count
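# Hedged usage sketch (illustration only): with a padding constant of `-1.`,
# padded positions are zeroed out in `masked_value` and flagged in `mask`, e.g.
#
#   batch = np.asarray([[1., 2., -1.],
#                       [3., -1., -1.]])
#   masked = get_masked_array(batch, mask_constant=-1.)
#   # masked.masked_value -> [[1., 2., 0.], [3., 0., 0.]]
#   # masked.mask         -> [[False, False, True], [False, True, True]]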
def evaluate(self, x, y):
  """Returns the number of correct predictions.

  Args:
    x: 2-d array of size batch_size x image_size.
    y: 2-d array of size batch_size x num_classes.

  Returns:
    A scalar, the number of correct predictions.
  """
  y_actual = np.argmax(y, axis=1)
  y_predicted = np.argmax(self.forward(x), axis=1)
  correct = int(np.sum(np.array(y_actual == y_predicted)))
  return correct
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
  """Compute the shape tuple of a conv given input shapes in canonical order."""
  if isinstance(pads, str):
    pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
  if len(pads) != len(lhs_shape) - 2:
    msg = 'Wrong number of explicit pads for convolution: expected {}, got {}.'
    raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))

  lhs_padded = onp.add(lhs_shape[2:],
                       np.sum(np.array(pads).reshape(-1, 2), axis=1))
  out_space = np.floor_divide(
      np.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
  out_space = np.maximum(0, out_space)
  assert lhs_shape[0] % batch_group_count == 0
  out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0])
  return tuple(out_shape + tuple(out_space))
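# Hedged worked example (illustration only): canonical order here is
# (batch, channel, spatial...), so a 32x32 RGB image convolved with eight
# 3x3 kernels at stride 2 and one pixel of padding on each side gives
# floor((32 + 2 - 3) / 2) + 1 = 16 output pixels per spatial dimension.
def _conv_shape_tuple_example():
  # Hypothetical helper, never called by library code.
  out = conv_shape_tuple(lhs_shape=(1, 3, 32, 32),  # NCHW input
                         rhs_shape=(8, 3, 3, 3),    # OIHW kernel
                         strides=(2, 2),
                         pads=((1, 1), (1, 1)))
  # out -> (1, 8, 16, 16)
  return out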
def testGradientDescentMseEnsembleTrain(self):
  key = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                          dtype=tf.int32)
  x = np.asarray(normal((8, 4, 6, 3), seed=key))
  _, _, kernel_fn = stax.serial(stax.Conv(1, (2, 2)), stax.Relu(),
                                stax.Conv(1, (2, 1)))
  y = np.asarray(normal((8, 2, 5, 1), seed=key))
  predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

  for t in [None, np.array([0., 1., 10.])]:
    with self.subTest(t=t):
      y_none = predictor(t, None, None, compute_cov=True)
      y_x = predictor(t, x, None, compute_cov=True)
      self._assertAllClose(y_none, y_x, 0.04)
def test_tf_dot_general(self, lhs_np, rhs_np, dims):
  ans = jax.lax.dot_general(lhs_np, rhs_np, dims)
  result = lax.dot_general(lhs_np, rhs_np, dims)
  self.assertAllClose(result, tfnp.array(ans))
def predict_fn(
    t: ArrayOrScalar = None,
    fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
    fx_test_0: ArrayOrScalar = None,
    k_test_train: np.ndarray = None
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
  """Return output predictions on train [and test] set[s] at time[s] `t`.

  Args:
    t: a scalar or array of scalars of any shape in strictly increasing order.
      `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent of
      training steps (but can be fractional).
    fx_train_or_state_0: either (a) output of the network at `t == 0` on the
      training set or (b) complete ODE state (`predict.ODEState`). Pass an ODE
      state if you want to operate on the full ODE state instead of output
      variables only (useful for inspecting auxiliary variables or resuming an
      optimizer with auxiliary variables from a specific state). Note that
      only the `momentum != None` optimizer currently has auxiliary variables.
      To initialize an ODE state from scratch, call
      `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed, an
      ODE state is returned. `fx_train_0=None` means to not compute
      predictions on the training set.
    fx_test_0: output of the network at `t == 0` on the test set.
      `fx_test_0=None` means to not compute predictions on the test set.
    k_test_train: kernel relating test data with training data. Must have the
      shape of `zip(y_test.shape, y_train.shape)` with `trace_axes` absent.
      Pass `k_test_train=None` if you only need predictions on the training
      set.

  Returns:
    `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
    potentially additional leading time dimensions matching `t.shape`.
    Alternatively can return an `ODEState` at time[s] `t`.

  Raises:
    ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
  """
  _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

  t = np.array(t if t is not None else np.inf, dtype) * learning_rate
  t_shape = t.shape
  t = t.reshape((-1,))

  # ODE solver requires `t[0]` to be the time where `fx_train_0` [and
  # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
  # timesteps, so we always temporarily append an [almost] `0` at the start.
  t0 = np.where(t[0] == 0,
                np.full((1,), -1e-24, t.dtype),
                np.zeros((1,), t.dtype))
  t = np.concatenate([t0, t])

  # Solve the ODE.
  fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
  state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
  state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

  # Remove the added `t0`.
  trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
  trim_tree = lambda tree: tree_map(trim, tree)
  state_t = trim_tree(state_t)

  # `ODEState` -> `ODEState`
  if isinstance(fx_train_or_state_0, ODEState):
    return state_t

  # `np.ndarray` -> `np.ndarray`
  fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

  if fx_train_or_state_0 is not None and fx_test_0 is None:
    return fx_train_t
  if fx_test_0 is not None and fx_train_or_state_0 is None:
    return fx_test_t
  return fx_train_t, fx_test_t
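# Hedged usage sketch (illustration only): `predict_fn` above is a closure
# produced by the surrounding predictor factory (it closes over `y_train`,
# `learning_rate`, `trace_axes`, the ODE right-hand side, etc.). Grounded in
# its docstring, a typical call on network outputs at initialization is:
#
#   fx_train_t, fx_test_t = predict_fn(t=np.array([1., 5., 10.]),
#                                      fx_train_or_state_0=fx_train_0,
#                                      fx_test_0=fx_test_0,
#                                      k_test_train=k_test_train)
#
# while passing `predict.ODEState(fx_train_0, fx_test_0)` as
# `fx_train_or_state_0` returns the full `ODEState` at the requested times
# instead, e.g. to resume a momentum optimizer from a saved state.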
def predict_fn(t: ArrayOrScalar = None,
               x_test: np.ndarray = None,
               get: Get = None,
               compute_cov: bool = False) -> Dict[str, Gaussian]:
  """Return output mean and covariance on the test set at time[s] `t`.

  Args:
    t: a scalar or array of scalars of any shape. `t=None` is treated as
      infinity and returns the same result as `t=np.inf`, but is computed
      using linear solve for test predictions instead of eigendecomposition,
      saving time and precision.
    x_test: test inputs. `None` means to return non-regularized (`diag_reg=0`)
      predictions on the train-set inputs. For regularized predictions, pass
      `x_test=x_train`.
    get: string, the mode of the Gaussian process, either "nngp" or "ntk", or
      a tuple. `get=None` is equivalent to `get=("nngp", "ntk")`.
    compute_cov: if `True`, compute both the mean and the covariance; only the
      mean otherwise.

  Returns:
    `fx_test_mean_t` or `(fx_test_mean_t, fx_test_cov_t)` if
    `compute_cov == True` with potentially additional leading time dimensions.
  """
  if get is None:
    get = ('nngp', 'ntk')

  # train-train, test-train, test-test.
  k_dd, k_td, nngp_tt = get_matrices(get, x_test, compute_cov)

  # Infinite time.
  if t is None:
    return predict_inf(get)(get=get,
                            k_test_train=k_td,
                            nngp_test_test=nngp_tt)

  # Finite time.
  t = np.array(t) * learning_rate
  t_shape = t.shape
  t = t.reshape((-1, 1))

  def reshape_mean(mean):
    k = _get_first(k_dd if k_td is None else k_td)
    mean = mean.reshape(t_shape + k.shape[::2] + trace_shape)
    mean = np.moveaxis(mean, last_t_axes, trace_axes)
    return mean

  def reshape_cov(cov):
    k = _get_first(k_dd if k_td is None else k_td)
    cov_shape_t = t_shape + k.shape[::2] * 2
    return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))

  out = {}

  for g in get:
    evals, evecs = eigenspace(g)

    # Training set.
    if k_td is None:
      mean = tf.einsum('ji,ti,ki,k...->tj...',
                       evecs, -expm1(evals, t), evecs, y_train_flat,
                       optimize=True)

    # Test set.
    else:
      neg_inv_expm1 = -inv_expm1(evals, t)
      ktd_g = utils.make_2d(getattr(k_td, g))
      mean = tf.einsum('lj,ji,ti,ki,k...->tl...',
                       ktd_g, evecs, neg_inv_expm1, evecs, y_train_flat,
                       optimize=True)

    mean = reshape_mean(mean)

    if nngp_tt is not None:
      nngp_dd = utils.make_2d(k_dd.nngp)

      # Training set.
      if k_td is None:
        if g == 'nngp':
          cov = np.einsum(
              'ji,ti,ki->tjk',
              evecs,
              (np.maximum(evals, 0.) *
               np.exp(-2 * np.maximum(evals, 0.) * t / y_train.size)),
              evecs,
              optimize=True)

        elif g == 'ntk':
          exp = np.einsum(
              'mi,ti,ki->tmk',
              evecs,
              np.exp(-np.maximum(evals, 0.) * t / y_train.size),
              evecs,
              optimize=True)
          cov = np.einsum('tmk,kl,tnl->tmn', exp, nngp_dd, exp, optimize=True)

        else:
          raise ValueError(g)

      # Test set.
      else:
        _nngp_tt = utils.make_2d(nngp_tt)

        if g == 'nngp':
          cov = _nngp_tt - np.einsum(
              'mj,ji,ti,ki,lk->tml',
              ktd_g, evecs, -inv_expm1(evals, 2 * t), evecs, ktd_g,
              optimize=True)

        elif g == 'ntk':
          term_1 = np.einsum('mi,ti,ki,lk->tml',
                             evecs, neg_inv_expm1, evecs, ktd_g,
                             optimize=True)
          term_2 = np.einsum(
              'mj,ji,ti,ki,lk->tml',
              ktd_g, evecs, neg_inv_expm1, evecs,
              utils.make_2d(k_td.nngp),  # pytype: disable=attribute-error
              optimize=True)
          term_2 += np.moveaxis(term_2, 1, 2)
          cov = np.einsum('tji,jk,tkl->til',
                          term_1, nngp_dd, term_1, optimize=True)
          cov += -term_2 + _nngp_tt

        else:
          raise ValueError(g)

      out[g] = Gaussian(mean, reshape_cov(cov))

    else:
      out[g] = mean

  return out
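# Hedged usage sketch (illustration only): `predict_fn` above is the closure
# returned by `predict.gradient_descent_mse_ensemble`, so it is used as in the
# tests in this section, e.g.
#
#   predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
#                                                     y_train, diag_reg=1e-6)
#   mean = predictor(t=1., x_test=x_test, get='ntk')
#   mean, cov = predictor(t=np.array([1., 5., 10.]), x_test=x_test,
#                         get='ntk', compute_cov=True)
#
# With `compute_cov=True` each requested `get` yields a `predict.Gaussian`
# (mean and covariance), as checked in `testGradientDescentMseEnsembleGet`.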
def test_jit_or_pmap_broadcast(self):

  def kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=None,
                p=0.65):
    res = np.abs(np.matmul(x1, x2))
    if do_square:
      res *= res
    if do_flip:
      res = -res

    res *= stateless_uniform(shape=[], seed=keys) * p
    return [res, params]

  params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
  x2 = np.arange(0, 10).reshape((10,))
  keys = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                           dtype=tf.int32)

  kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=0)
  x1 = np.arange(0, 10).reshape((1, 10))
  for do_flip in [True, False]:
    for do_square in [True, False]:
      with self.subTest(do_flip=do_flip, do_square=do_square, device_count=0):
        res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params,
                          _unused=True, p=0.65)
        res_2 = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square, params,
                                  _unused=True)
        self.assertAllClose(res_1, res_2)

  test_utils.stub_out_pmap(batch, 1)
  x1 = np.arange(0, 10).reshape((1, 10))
  kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=1)
  for do_flip in [True, False]:
    for do_square in [True, False]:
      with self.subTest(do_flip=do_flip, do_square=do_square, device_count=1):
        res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params,
                          _unused=False, p=0.65)
        res_2 = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square, params,
                                  _unused=None)
        self.assertAllClose(res_1[0], res_2[0])
        self.assertAllClose(
            tree_map(partial(np.expand_dims, axis=0), res_1[1]), res_2[1])

  kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=2)
  x1 = np.arange(0, 20).reshape((2, 10))
  test_utils.stub_out_pmap(batch, 2)

  def broadcast(arg):
    return np.broadcast_to(arg, (2,) + arg.shape)

  for do_flip in [True, False]:
    for do_square in [True, False]:
      with self.subTest(do_flip=do_flip, do_square=do_square, device_count=2):
        res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params, p=0.2)
        res_2 = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square, params,
                                  _unused=None, p=0.2)
        self.assertAllClose(res_1[0][0], res_2[0][0])
        self.assertAllClose(res_1[0][1], res_2[0][1])
        self.assertAllClose(tree_map(broadcast, res_1[1]), res_2[1])
def testNTK_NTKNNGPAgreement(self, train_shape, test_shape, network,
                             out_logits):
  _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                 train_shape)
  _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

  reg = 1e-7
  predictor = predict.gradient_descent_mse_ensemble(ker_fun, x_train, y_train,
                                                    diag_reg=reg)

  ts = np.logspace(-2, 8, 10).reshape((5, 2))

  for t in (None, 'ts'):
    for x in (None, 'x_test'):
      with self.subTest(t=t, x=x):
        x = x if x is None else x_test
        t = t if t is None else ts

        ntk = predictor(t=t, get='ntk', x_test=x)

        # Test time broadcasting.
        if t is not None:
          ntk_ind = np.array([
              predictor(t=t, get='ntk', x_test=x) for t in t.ravel()
          ]).reshape(t.shape + ntk.shape[2:])
          self.assertAllClose(ntk_ind, ntk)

        # Create a hacked kernel function that always returns the ntk kernel.
        def always_ntk(x1, x2, get=('nngp', 'ntk')):
          out = ker_fun(x1, x2, get=('nngp', 'ntk'))
          if get == 'nngp' or get == 'ntk':
            return out.ntk
          else:
            return out._replace(nngp=out.ntk)

        predictor_ntk = predict.gradient_descent_mse_ensemble(
            always_ntk, x_train, y_train, diag_reg=reg)

        ntk_nngp = predictor_ntk(t=t, get='nngp', x_test=x)

        # Test that using the nngp equations with the ntk gives the same mean.
        self.assertAllClose(ntk, ntk_nngp)

        # Next, test that going through the NTK code path, but with only the
        # NNGP kernel, recreates the NNGP dynamics.
        # Create a hacked kernel function that always returns the nngp kernel.
        def always_nngp(x1, x2, get=('nngp', 'ntk')):
          out = ker_fun(x1, x2, get=('nngp', 'ntk'))
          if get == 'nngp' or get == 'ntk':
            return out.nngp
          else:
            return out._replace(ntk=out.nngp)

        predictor_nngp = predict.gradient_descent_mse_ensemble(
            always_nngp, x_train, y_train, diag_reg=reg)

        nngp_cov = predictor(t=t, get='nngp', x_test=x,
                             compute_cov=True).covariance

        # Test time broadcasting for covariance.
        nngp_ntk_cov = predictor_nngp(t=t, get='ntk', x_test=x,
                                      compute_cov=True).covariance
        if t is not None:
          nngp_ntk_cov_ind = np.array([
              predictor_nngp(t=t, get='ntk', x_test=x,
                             compute_cov=True).covariance for t in t.ravel()
          ]).reshape(t.shape + nngp_cov.shape[2:])
          self.assertAllClose(nngp_ntk_cov_ind, nngp_ntk_cov)

        # Test that using the ntk equations with the nngp gives the same
        # covariance, though only roughly, due to accumulation of numerical
        # errors.
        self.assertAllClose(nngp_cov, nngp_ntk_cov)
def conv_transpose(lhs, rhs, strides, padding, rhs_dilation=None,
                   dimension_numbers=None, transpose_kernel=False,
                   precision=None):
  """Convenience wrapper for calculating the N-d convolution "transpose".

  This function directly calculates a fractionally strided conv rather than
  indirectly calculating the gradient (transpose) of a forward convolution.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    strides: sequence of `n` integers, sets fractional stride.
    padding: 'SAME', 'VALID' will set as transpose of corresponding forward
      conv, or a sequence of `n` integer 2-tuples describing before-and-after
      padding for each `n` spatial dimension.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
      factor to apply in each spatial dimension of `rhs`. RHS dilation is also
      known as atrous convolution.
    dimension_numbers: tuple of dimension descriptors as in
      lax.conv_general_dilated. Defaults to tensorflow convention.
    transpose_kernel: if True flips spatial axes and swaps the input/output
      channel axes of the kernel. This makes the output of this function
      identical to the gradient-derived functions like
      keras.layers.Conv2DTranspose applied to the same kernel. For typical use
      in neural nets this is completely pointless and just makes input/output
      channel specification confusing.
    precision: Optional. Either `None`, which means the default precision for
      the backend, or a `Precision` enum value.

  Returns:
    Transposed N-d convolution, with output padding following the conventions
    of keras.layers.Conv2DTranspose.
  """
  assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) > 2
  ndims = len(lhs.shape)
  one = (1,) * (ndims - 2)
  # Set dimensional layout defaults if not specified.
  if dimension_numbers is None:
    if ndims == 3:
      dimension_numbers = ('NHC', 'HIO', 'NHC')
    elif ndims == 4:
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    elif ndims == 5:
      dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
    else:
      raise ValueError('No 4+ dimensional dimension_number defaults.')
  dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  k_shape = np.take(rhs.shape, dn.rhs_spec)
  k_sdims = k_shape[2:]
  # Calculate correct output shape given padding and strides.
  pads: Union[str, Sequence[Tuple[int, int]]]
  if padding in {'SAME', 'VALID'}:
    if rhs_dilation is None:
      rhs_dilation = (1,) * (rhs.ndim - 2)
    effective_k_size = map(lambda k, r: (k - 1) * r + 1, k_sdims, rhs_dilation)
    pads = [_conv_transpose_padding(k, s, padding)
            for k, s in zip(effective_k_size, strides)]
  else:
    pads = padding
  if transpose_kernel:
    # flip spatial dims and swap input / output channel axes
    rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
    rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
  return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn)
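# Hedged shape example (illustration only), following the keras
# Conv2DTranspose conventions mentioned in the docstring: with the default
# NHWC/HWIO layout, an input `lhs` of shape (1, 8, 8, 3), a (3, 3, 3, 16)
# kernel `rhs` and strides (2, 2),
#
#   conv_transpose(lhs, rhs, strides=(2, 2), padding='SAME')
#
# upsamples each spatial dimension by the stride, giving an output of shape
# (1, 16, 16, 16), while padding='VALID' gives (1, 17, 17, 16)
# ((8 - 1) * 2 + 3 = 17 per spatial dimension).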