def padtype_to_pads(in_shape, window_shape, window_strides, padding):
  """Convert a padding-type string into explicit per-dimension pads.

  Args:
    in_shape: per-dimension input sizes (e.g. the spatial dims of a conv
      input).
    window_shape: per-dimension window sizes, aligned with `in_shape`.
    window_strides: per-dimension strides, aligned with `in_shape`.
    padding: either 'SAME' or 'VALID'.

  Returns:
    A list with one `(low, high)` padding pair per entry of `in_shape`.
    'VALID' yields all-zero pads. 'SAME' pads so that the output size is
    `ceil(in / stride)` along every dimension.

  Raises:
    ValueError: if `padding` is neither 'SAME' nor 'VALID'. Previously such
      values silently returned `None`.
  """
  if padding == 'SAME':
    out_shape = _ceil_divide(in_shape, window_strides)
    pad_sizes = np.maximum(
        0, (out_shape - 1) * window_strides + window_shape - in_shape)
    # Split the total padding as evenly as possible; an odd leftover unit
    # goes on the high side.
    return [(pad_size // 2, pad_size - pad_size // 2)
            for pad_size in pad_sizes]
  if padding == 'VALID':
    return [(0, 0)] * len(in_shape)
  raise ValueError(
      "Unrecognized padding type: expected 'VALID' or 'SAME', got {!r}."
      .format(padding))
def assert_close_matrices(self, expected, actual, rtol):
  """Fail the test unless `actual` is close to `expected` in relative norm.

  The relative error is `norm(actual - expected) / max(norm(expected), 1e-12)`,
  computed with `tf.linalg.norm`. The floor keeps an all-zero `expected`
  from dividing by zero.

  Args:
    expected: reference value; its `.shape` must equal `actual.shape`.
    actual: value under test.
    rtol: maximum allowed relative error.

  The comparison result is always reported through `_log`, with a final
  flag of True on pass and False on failure. A NaN relative error counts as
  a failure.
  """
  self.assertEqual(expected.shape, actual.shape)
  relative_error = (tf.linalg.norm(actual - expected) /
                    np.maximum(tf.linalg.norm(expected), 1e-12))
  # NaN compares False against rtol, so it must be checked explicitly.
  if relative_error > rtol or np.isnan(relative_error):
    _log(relative_error, expected, actual, False)
    # `TestCase.fail` takes a message string. Passing an exception instance
    # (as before) got it wrapped in a second failureException and rendered
    # as an unreadable tuple.
    self.fail('Relative ERROR: {}\nEXPECTED:\n{}\nACTUAL:\n{}'.format(
        float(relative_error), expected, actual))
  else:
    _log(relative_error, expected, actual, True)
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
  """Compute the shape tuple of a conv given input shapes in canonical order.

  Args:
    lhs_shape: input shape; index 0 is treated as the batch dim and
      `lhs_shape[2:]` as the spatial dims.
    rhs_shape: kernel shape; index 0 becomes the output feature dim and
      `rhs_shape[2:]` are the window sizes.
    strides: per-spatial-dim strides.
    pads: either a padding-type string (forwarded to `padtype_to_pads`) or an
      explicit sequence of `(low, high)` pairs, one per spatial dim.
    batch_group_count: positive divisor of `lhs_shape[0]`; the output batch
      dim is `lhs_shape[0] // batch_group_count`.

  Returns:
    `(batch, rhs_shape[0], *out_spatial)`. Each spatial size is
    `(padded - window) // stride + 1`, clamped at 0.

  Raises:
    TypeError: if the number of explicit pads does not match the number of
      spatial dims.
    ValueError: if `batch_group_count` is not a positive divisor of
      `lhs_shape[0]`. This used to be a bare `assert`, which disappears
      under `python -O`.
  """
  if isinstance(pads, str):
    pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
  if len(pads) != len(lhs_shape) - 2:
    msg = 'Wrong number of explicit pads for convolution: expected {}, got {}.'
    raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))
  if batch_group_count < 1 or lhs_shape[0] % batch_group_count != 0:
    raise ValueError(
        'batch_group_count must be a positive divisor of the batch size {}, '
        'got {}.'.format(lhs_shape[0], batch_group_count))

  # Total padding per spatial dim is low + high.
  lhs_padded = onp.add(lhs_shape[2:],
                       np.sum(np.array(pads).reshape(-1, 2), axis=1))
  out_space = np.floor_divide(
      np.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
  # A window larger than the padded input gives an empty output dim, not a
  # negative one.
  out_space = np.maximum(0, out_space)
  out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0])
  return tuple(out_shape + tuple(out_space))
def update(i, g, state):
  """Apply one Adamax-style optimizer step.

  `b1`, `b2`, `eps` and `step_size` are read from the enclosing scope, which
  is not visible here.

  Args:
    i: step index, used for `step_size(i)` and the bias correction
      `1 - b1 ** (i + 1)`.
    g: gradient for the current step.
    state: tuple `(x, m, u)` of parameters, first-moment estimate and
      exponentially weighted infinity norm.

  Returns:
    The updated `(x, m, u)` tuple.
  """
  params, moment, inf_norm = state
  # Exponential moving average of the gradient (first moment).
  moment = (1 - b1) * g + b1 * moment
  # Decayed running maximum of |g| (exponentially weighted infinity norm).
  inf_norm = np.maximum(b2 * inf_norm, np.abs(g))
  # Step size with the first-moment bias correction folded in.
  corrected_lr = step_size(i) / (1 - b1 ** (i + 1))
  params = params - corrected_lr * moment / (inf_norm + eps)
  return params, moment, inf_norm
def predict_fn(t: ArrayOrScalar = None,
               x_test: np.ndarray = None,
               get: Get = None,
               compute_cov: bool = False) -> Dict[str, Gaussian]:
  """Return output mean and covariance on the test set at time[s] `t`.

  Args:
    t: a scalar or an array of scalars of any shape. `t=None` is treated as
      infinity and returns the same result as `t=np.inf`. It is computed
      with a linear solve for test predictions instead of an
      eigendecomposition, which saves time and precision.
    x_test: test inputs. `None` means to return non-regularized
      (`diag_reg=0`) predictions on the train-set inputs. For regularized
      predictions, pass `x_test=x_train`.
    get: string, the mode of the Gaussian process, either "nngp" or "ntk", or
      a tuple. `get=None` is equivalent to `get=("nngp", "ntk")`.
    compute_cov: if `True`, compute both the `mean` and the covariance;
      otherwise compute only the `mean`.

  Returns:
    If `t is None`, the result of `predict_inf(get)(...)`.

    Otherwise, a dict keyed by each mode `g` in `get`:
      - `Gaussian(mean, cov)` when a test-test NNGP kernel is available
        (`nngp_tt is not None`, presumably only when `compute_cov=True`);
      - just the mean array otherwise.

    Outputs may carry additional time dimensions taken from the shape of `t`.
  """
  if get is None:
    get = ('nngp', 'ntk')

  # train-train, test-train, test-test.
  k_dd, k_td, nngp_tt = get_matrices(get, x_test, compute_cov)

  # Infinite time.
  if t is None:
    return predict_inf(get)(get=get, k_test_train=k_td,
                            nngp_test_test=nngp_tt)

  # Finite time.
  # Scale time by the learning rate and flatten it into a column (T, 1) so it
  # can be combined with the eigenvalues in `expm1` / `inv_expm1`. The
  # original shape `t_shape` is restored by `reshape_mean` / `reshape_cov`.
  t = np.array(t) * learning_rate
  t_shape = t.shape
  t = t.reshape((-1, 1))

  def reshape_mean(mean):
    # Unflatten the time axis back to `t_shape`, then move axes
    # `last_t_axes` to positions `trace_axes`.
    k = _get_first(k_dd if k_td is None else k_td)
    mean = mean.reshape(t_shape + k.shape[::2] + trace_shape)
    mean = np.moveaxis(mean, last_t_axes, trace_axes)
    return mean

  def reshape_cov(cov):
    # `k.shape[::2] * 2` is tuple repetition (the sample dims listed twice),
    # not arithmetic.
    k = _get_first(k_dd if k_td is None else k_td)
    cov_shape_t = t_shape + k.shape[::2] * 2
    return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))

  out = {}

  # NOTE(review): this iterates over `get`, so a bare string such as "ntk"
  # would iterate its characters. Presumably a decorator or caller outside
  # this view canonicalizes `get` to a tuple -- confirm.
  for g in get:
    evals, evecs = eigenspace(g)

    # Training set.
    if k_td is None:
      mean = tf.einsum('ji,ti,ki,k...->tj...',
                       evecs, -expm1(evals, t), evecs, y_train_flat,
                       optimize=True)

    # Test set.
    else:
      neg_inv_expm1 = -inv_expm1(evals, t)
      ktd_g = utils.make_2d(getattr(k_td, g))
      mean = tf.einsum('lj,ji,ti,ki,k...->tl...',
                       ktd_g, evecs, neg_inv_expm1, evecs, y_train_flat,
                       optimize=True)

    mean = reshape_mean(mean)

    if nngp_tt is not None:
      # NOTE(review): the mean above uses `tf.einsum` while every covariance
      # below uses `np.einsum`. Confirm the mix is intended (and that
      # `tf.einsum` accepts `optimize=True`).
      nngp_dd = utils.make_2d(k_dd.nngp)

      # Training set.
      if k_td is None:
        if g == 'nngp':
          # Negative eigenvalues are clipped to 0 (presumably numerical
          # noise) before use.
          cov = np.einsum('ji,ti,ki->tjk',
                          evecs,
                          (np.maximum(evals, 0.) *
                           np.exp(-2 * np.maximum(evals, 0.) * t /
                                  y_train.size)),
                          evecs,
                          optimize=True)

        elif g == 'ntk':
          # `exp` is a per-time matrix (t, m, k); the covariance sandwiches
          # the train NNGP between two copies of it.
          exp = np.einsum('mi,ti,ki->tmk',
                          evecs,
                          np.exp(-np.maximum(evals, 0.) * t / y_train.size),
                          evecs,
                          optimize=True)
          cov = np.einsum('tmk,kl,tnl->tmn', exp, nngp_dd, exp, optimize=True)

        else:
          raise ValueError(g)

      # Test set.
      else:
        _nngp_tt = utils.make_2d(nngp_tt)

        if g == 'nngp':
          cov = _nngp_tt - np.einsum('mj,ji,ti,ki,lk->tml',
                                     ktd_g, evecs, -inv_expm1(evals, 2 * t),
                                     evecs, ktd_g,
                                     optimize=True)

        elif g == 'ntk':
          term_1 = np.einsum('mi,ti,ki,lk->tml',
                             evecs, neg_inv_expm1, evecs, ktd_g,
                             optimize=True)
          term_2 = np.einsum(
              'mj,ji,ti,ki,lk->tml',
              ktd_g, evecs, neg_inv_expm1, evecs, utils.make_2d(k_td.nngp),  # pytype:disable=attribute-error
              optimize=True)
          # Add the transpose over the last two axes (symmetrize per time).
          term_2 += np.moveaxis(term_2, 1, 2)
          cov = np.einsum('tji,jk,tkl->til',
                          term_1, nngp_dd, term_1,
                          optimize=True)
          cov += -term_2 + _nngp_tt

        else:
          raise ValueError(g)

      out[g] = Gaussian(mean, reshape_cov(cov))

    else:
      out[g] = mean

  return out
def expm1_fn(evals: np.ndarray, t: np.ndarray):
  """Return `expm1(-max(evals, 0) * t / normalization)` elementwise.

  The matrix behind `evals` should be positive semidefinite, so negative
  eigenvalues are assumed to be numerical noise and are clipped to zero.
  `normalization` is read from the enclosing scope, which is not visible
  here.
  """
  nonneg_evals = np.maximum(evals, 0.)
  return np.expm1(-nonneg_evals * t / normalization)