def convert_bn_layer(keras_layer, objax_layer):
    """Converts variables of batch normalization layer from Keras to Objax.

    Copies ``gamma`` / ``beta`` into fresh ``objax.TrainVar`` objects and the moving mean / variance
    into fresh ``objax.StateVar`` objects on ``objax_layer``. Every array is reshaped to the shape of
    the existing ``objax_layer.gamma``.

    tf.keras stores ``None`` for ``gamma`` (``scale=False``) or ``beta`` (``center=False``) when the
    layer was built without that affine term. In that case the identity parameter (ones for
    ``gamma``, zeros for ``beta``) is used instead of failing on ``None.numpy()``.

    Args:
        keras_layer: Keras batch normalization layer; its weights are read from ``__dict__``.
        objax_layer: Objax batch normalization layer, updated in place.
    """
    shape = objax_layer.gamma.value.shape

    def _weight(name):
        # Named Keras weight as a jax array in the Objax parameter shape, or None if Keras has none.
        weight = keras_layer.__dict__[name]
        if weight is None:
            return None
        return jn.array(weight.numpy()).reshape(shape)

    gamma = _weight('gamma')
    beta = _weight('beta')
    objax_layer.gamma = objax.TrainVar(jn.ones(shape) if gamma is None else gamma)
    objax_layer.beta = objax.TrainVar(jn.zeros(shape) if beta is None else beta)
    objax_layer.running_mean = objax.StateVar(_weight('moving_mean'))
    objax_layer.running_var = objax.StateVar(_weight('moving_variance'))
def test_state_var_custom_reduce(self):
    """StateVar.reduce applies the reduce function given at construction time."""
    data = jn.array([[0, 4, 2], [3, 1, 5]], dtype=jn.float32)
    cases = [
        (lambda x: x[0], [0, 4, 2]),      # keep the first row
        (lambda x: x.min(0), [0, 1, 2]),  # column-wise minimum
        (lambda x: x.max(0), [3, 4, 5]),  # column-wise maximum
    ]
    for reducer, expected in cases:
        state = objax.StateVar(data, reduce=reducer)
        state.reduce(state.value)
        self.assertEqual(state.value.tolist(), expected)
def __init__(self, temporal_kernel, spatial_kernel, z=None, conditional=None, sparse=True, opt_z=False,
             spatial_dims=None):
    """Combine a temporal and a spatial kernel with spatial inducing inputs ``z``.

    Args:
        temporal_kernel: stored as ``self.temporal_kernel``.
        spatial_kernel: stored as ``self.spatial_kernel``.
        z: spatial inducing inputs, shape (M,) or (M, D); (M,) is promoted to (M, 1). If None, a
            fixed initialisation on [-3, 3] is used: 15 points for ``spatial_dims == 1``, or a
            5 x 5 grid flattened to shape (25, 2) for ``spatial_dims == 2``.
        conditional: conditional covariance method, one of 'DTC'/'dtc', 'FIC'/'FITC'/'fic'/'fitc'
            or 'Full'/'full'. Defaults to 'Full' when ``sparse`` and 'DTC' otherwise.
        sparse: stored as ``self.sparse``.
        opt_z: if True, ``z`` is a ``objax.TrainVar``; otherwise a ``objax.StateVar``. Forced to
            False (with a warning) when ``sparse`` is False.
        spatial_dims: number of spatial input dimensions, i.e. the number of columns of ``z``.
            Inferred from ``z`` when None.

    Raises:
        NotImplementedError: if ``z`` is None and ``spatial_dims`` is not 1 or 2, or if
            ``conditional`` is not recognised.
        ValueError: if ``spatial_dims`` does not match the number of columns of ``z``.
    """
    dtc_names = ['DTC', 'dtc']
    self.temporal_kernel = temporal_kernel
    self.spatial_kernel = spatial_kernel
    if conditional is None:
        conditional = 'Full' if sparse else 'DTC'
    if opt_z and (not sparse):  # z should not be optimised if the model is not sparse
        warn("spatial inducing inputs z will not be optimised because sparse=False")
        opt_z = False
    self.sparse = sparse
    if z is None:
        # initialise z
        # TODO: smart initialisation
        if spatial_dims == 1:
            z = np.linspace(-3., 3., num=15)
        elif spatial_dims == 2:
            z1 = np.linspace(-3., 3., num=5)
            zA, zB = np.meshgrid(z1, z1)  # Adding additional dimension to inducing points grid
            z = np.hstack((zA.reshape(-1, 1), zB.reshape(-1, 1)))  # Flattening grid for use in kernel functions
        else:
            raise NotImplementedError('please provide an initialisation for inducing inputs z')
    if z.ndim < 2:
        z = z[:, np.newaxis]
    # Each row of z is one inducing location, so the spatial dimensionality is the trailing axis
    # size. The previous check (spatial_dims == z.ndim - 1) always rejected the default 2-D grid,
    # whose shape is (25, 2).
    if spatial_dims is None:
        spatial_dims = z.shape[-1]
    if spatial_dims != z.shape[-1]:
        raise ValueError(
            f'spatial_dims={spatial_dims} does not match the inducing inputs z of shape {z.shape}')
    self.M = z.shape[0]
    if opt_z:
        self.z = objax.TrainVar(z)
    else:
        self.z = objax.StateVar(z)
    if conditional in dtc_names:
        self.conditional_covariance = self.deterministic_training_conditional
    elif conditional in ['FIC', 'FITC', 'fic', 'fitc']:
        self.conditional_covariance = self.fully_independent_conditional
    elif conditional in ['Full', 'full']:
        self.conditional_covariance = self.full_conditional
    else:
        raise NotImplementedError('conditional method not recognised')
    # Accept every DTC spelling the dispatch above accepts ('dtc' used to trigger a spurious warning).
    # NOTE(review): this constructor does not itself switch conditional_covariance to DTC when the
    # model is not sparse; presumably that is handled where self.sparse is read. Confirm.
    if (not sparse) and (conditional not in dtc_names):
        warn("You chose a non-deterministic conditional, but \'DTC\' will be used because the model is not sparse")
def __init__(self, variance, lengthscale, fix_variance=False, fix_lengthscale=False):
    """Store the lengthscale and variance in softplus-inverse space.

    A hyperparameter marked as fixed is wrapped in ``objax.StateVar`` so it is not optimised.
    Otherwise it is wrapped in ``objax.TrainVar``.
    """
    lengthscale_cls = objax.StateVar if fix_lengthscale else objax.TrainVar
    self.transformed_lengthscale = lengthscale_cls(softplus_inv(np.array(lengthscale)))

    variance_cls = objax.StateVar if fix_variance else objax.TrainVar
    self.transformed_variance = variance_cls(softplus_inv(np.array(variance)))
def __init__(self, kernel, likelihood, X, Y, Z, opt_z=False):
    """Set up inducing inputs ``Z`` and a zero-mean / identity posterior over them.

    ``Z`` of shape (M,) is promoted to (M, 1). It is held in ``objax.TrainVar`` when ``opt_z`` is
    True, and in ``objax.StateVar`` otherwise.
    """
    super().__init__(kernel=kernel, likelihood=likelihood, X=X, Y=Y)
    inducing = Z[:, None] if Z.ndim < 2 else Z
    holder = objax.TrainVar if opt_z else objax.StateVar
    self.Z = holder(inducing)
    self.num_inducing = inducing.shape[0]

    # Posterior state: per-inducing-point mean/variance over func_dim latents, plus an M x M covariance.
    num_points = self.num_inducing
    self.posterior_mean = objax.StateVar(np.zeros([num_points, self.func_dim, 1]))
    self.posterior_variance = objax.StateVar(np.tile(np.eye(self.func_dim), [num_points, 1, 1]))
    self.posterior_covariance = objax.StateVar(np.eye(num_points))
def test_state_var(self):
    """StateVar supports in-place arithmetic, assign() and the default (mean) reduce()."""
    var = objax.StateVar(jn.arange(6, dtype=jn.float32).reshape((2, 3)))
    self.assertEqual(var.value.shape, (2, 3))
    self.assertEqual(var.value.sum(), 15)  # 0 + 1 + ... + 5

    var.value += 2  # six entries, each shifted by 2
    self.assertEqual(var.value.sum(), 27)

    var.assign(var.value - 1)  # values are now [[1, 2, 3], [4, 5, 6]]
    self.assertEqual(var.value.sum(), 21)

    # The default reduction averages over the leading axis.
    var.reduce(var.value)
    self.assertEqual(var.value.shape, (3,))
    self.assertEqual(var.value.tolist(), [2.5, 3.5, 4.5])
def __init__(self, kernel, likelihood, X, Y, func_dim=1):
    """Store the data and initialise the pseudo-likelihood and posterior state.

    Args:
        kernel: prior kernel. If it is an ``Independent`` kernel, the pseudo-likelihood has
            ``func_dim`` components; otherwise it has ``obs_dim`` components.
        likelihood: stored as ``self.likelihood``.
        X: inputs, any array-like of shape (N,) or (N, D); (N,) is promoted to (N, 1).
        Y: observations, any array-like of shape (N,) or (N, obs_dim); (N,) is promoted to (N, 1).
            ``self.mask`` is True where ``Y`` is NaN.
        func_dim: number of latent dimensions.
    """
    # Convert first, so array-likes without an ``ndim`` attribute (lists, tuples) are accepted too.
    X = np.array(X)
    Y = np.array(Y)
    if X.ndim < 2:
        X = X[:, None]
    if Y.ndim < 2:
        Y = Y[:, None]
    self.X = X
    self.Y = Y
    self.kernel = kernel
    self.likelihood = likelihood
    self.num_data = self.X.shape[0]  # number of data
    self.func_dim = func_dim  # number of latent dimensions
    self.obs_dim = self.Y.shape[1]  # dimensionality of the observations, Y
    self.mask = np.isnan(self.Y).reshape(self.Y.shape[0], self.Y.shape[1])
    if isinstance(self.kernel, Independent):
        pseudo_lik_size = self.func_dim  # the multi-latent case
    else:
        pseudo_lik_size = self.obs_dim
    eye_pl = np.tile(np.eye(pseudo_lik_size), [self.num_data, 1, 1])
    self.pseudo_likelihood_nat1 = objax.StateVar(np.zeros([self.num_data, pseudo_lik_size, 1]))
    self.pseudo_likelihood_nat2 = objax.StateVar(1e-2 * eye_pl)
    self.pseudo_y = objax.StateVar(np.zeros([self.num_data, pseudo_lik_size, 1]))
    self.pseudo_var = objax.StateVar(1e2 * eye_pl)
    self.posterior_mean = objax.StateVar(np.zeros([self.num_data, self.func_dim, 1]))
    self.posterior_variance = objax.StateVar(np.tile(np.eye(self.func_dim), [self.num_data, 1, 1]))
    self.ind = np.arange(self.num_data)
    self.num_neighbours = np.ones(self.num_data)
def __init__(self, variance=1.0, lengthscale=1.0, radial_frequency=1.0, fix_variance=False,
             fix_lengthscale=False, fix_radial_frequency=False):
    """Subband Matern-1/2 hyperparameters, stored as ``softplus_inv`` of their values.

    Args:
        variance: initial variance.
        lengthscale: initial lengthscale.
        radial_frequency: initial radial frequency.
        fix_variance: if True, hold the variance in ``objax.StateVar`` (not optimised) instead of
            ``objax.TrainVar``.
        fix_lengthscale: same as ``fix_variance``, for the lengthscale. The default keeps the
            previous (trainable) behaviour.
        fix_radial_frequency: same as ``fix_variance``, for the radial frequency. The default keeps
            the previous (trainable) behaviour.
    """
    def _hyperparameter(value, fixed):
        # Fixed hyperparameters live in StateVar, trainable ones in TrainVar.
        holder = objax.StateVar if fixed else objax.TrainVar
        return holder(np.array(softplus_inv(value)))

    self.transformed_lengthscale = _hyperparameter(lengthscale, fix_lengthscale)
    self.transformed_variance = _hyperparameter(variance, fix_variance)
    self.transformed_radial_frequency = _hyperparameter(radial_frequency, fix_radial_frequency)
    super().__init__()
    self.name = 'Subband Matern-1/2'
def test_vars(self):
    """repr() of TrainVar, TrainRef, StateVar and RandomState shows the value and reduce function."""
    train_var = objax.TrainVar(jn.zeros([1, 2, 3, 2, 1]))
    expected_train_var = '\n'.join([
        'objax.TrainVar(DeviceArray([[[[[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]]],',
        ' [[[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]]]]], dtype=float32), reduce=reduce_mean)',
    ])
    self.assertEqual(repr(train_var), expected_train_var)

    train_ref = objax.TrainRef(train_var)
    expected_train_ref = '\n'.join([
        'objax.TrainRef(ref=objax.TrainVar(DeviceArray([[[[[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]]],',
        ' [[[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]]]]], dtype=float32), reduce=reduce_mean))',
    ])
    self.assertEqual(repr(train_ref), expected_train_ref)

    state_var = objax.StateVar(jn.zeros([1, 2, 3, 2, 1]))
    expected_state_var = '\n'.join([
        'objax.StateVar(DeviceArray([[[[[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]]],',
        ' [[[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]]]]], dtype=float32), reduce=reduce_mean)',
    ])
    self.assertEqual(repr(state_var), expected_state_var)

    rng_key = objax.random.Generator().key
    self.assertEqual(repr(rng_key), 'objax.RandomState(DeviceArray([0, 0], dtype=uint32))')
def test_var_hierarchy(self):
    """Each variable kind sits at the expected place in the objax class hierarchy."""
    train = objax.TrainVar(jn.zeros(2))
    state = objax.StateVar(jn.zeros(2))
    ref = objax.TrainRef(train)
    rng = objax.RandomState(0)
    # (object, class, whether object must be an instance of class), checked in this order.
    expectations = [
        (train, objax.TrainVar, True),
        (train, objax.BaseVar, True),
        (train, objax.BaseState, False),
        (state, objax.BaseVar, True),
        (state, objax.BaseState, True),
        (state, objax.TrainVar, False),
        (ref, objax.BaseVar, True),
        (ref, objax.BaseState, True),
        (ref, objax.TrainVar, False),
        (rng, objax.BaseVar, True),
        (rng, objax.BaseState, True),
        (rng, objax.StateVar, True),
        (rng, objax.TrainVar, False),
    ]
    for obj, cls, is_instance in expectations:
        if is_instance:
            self.assertIsInstance(obj, cls)
        else:
            self.assertNotIsInstance(obj, cls)
def __init__(self, kernel, likelihood, X, Y, R=None, Z=None):
    """Initialise the model state on inducing inputs ``Z``.

    Args:
        kernel, likelihood, X, Y, R: forwarded to the parent constructor.
        Z: inducing inputs, shape (M,) or (M, 1). They are sorted along axis 0 and padded with
            -1e10 / +1e10 end points; the (1, 1) padding requires a single column. Defaults to
            ``self.X``, used as-is (not sorted).
    """
    import jax.numpy as jnp  # functional `.at[...].set` update below

    super().__init__(kernel=kernel, likelihood=likelihood, X=X, Y=Y, R=R)
    if Z is None:
        Z = self.X
    else:
        if Z.ndim < 2:
            Z = Z[:, None]
        Z = np.sort(Z, axis=0)
    inf = np.array([[1e10]])
    self.Z = objax.StateVar(np.concatenate([-inf, Z, inf], axis=0))
    self.dz = np.array(np.diff(self.Z.value[:, 0]))  # gaps between consecutive padded inducing inputs
    self.num_transitions = self.dz.shape[0]
    zeros = np.zeros([self.num_transitions, 2 * self.state_dim, 1])
    eyes = np.tile(np.eye(2 * self.state_dim), [self.num_transitions, 1, 1])
    # Initialise to match MarkovGP / GP on the first step (when Z=X): 1e-8 * I, except entry
    # [state_dim, state_dim] of every transition block but the last, which is set to 1e-2.
    # (Other options: plain `1e-8 * eyes`, or `(1 / 99) * eyes` to match the old implementation.)
    # `jnp.asarray(x).at[idx].set(v)` is the documented replacement for the removed
    # `jax.ops.index_update(x, jax.ops.index[idx], v)`.
    nat2 = jnp.asarray(1e-8 * eyes).at[:-1, self.state_dim, self.state_dim].set(1e-2)
    self.pseudo_likelihood_nat1 = objax.StateVar(zeros)
    self.pseudo_likelihood_nat2 = objax.StateVar(nat2)
    self.pseudo_y = objax.StateVar(zeros)
    self.pseudo_var = objax.StateVar(vmap(inv)(nat2))
    self.posterior_mean = objax.StateVar(zeros)
    self.posterior_variance = objax.StateVar(eyes)
    self.mask = None
    self.conditional_mean = None
    # TODO: if training Z this needs to be done at every training step (as well as sorting and computing dz)
    self.ind, self.num_neighbours = set_z_stats(self.X, self.Z.value)