Example #1
def convert_bn_layer(keras_layer, objax_layer):
    """Converts the variables of a batch normalization layer from Keras to Objax."""
    shape = objax_layer.gamma.value.shape
    objax_layer.gamma = objax.TrainVar(jn.array(keras_layer.gamma.numpy()).reshape(shape))
    objax_layer.beta = objax.TrainVar(jn.array(keras_layer.beta.numpy()).reshape(shape))
    objax_layer.running_mean = objax.StateVar(jn.array(keras_layer.moving_mean.numpy()).reshape(shape))
    objax_layer.running_var = objax.StateVar(jn.array(keras_layer.moving_variance.numpy()).reshape(shape))
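A minimal usage sketch (the input width of 16 and the Objax BatchNorm2D counterpart are illustrative assumptions, not part of the original snippet):

import objax
import tensorflow as tf

keras_bn = tf.keras.layers.BatchNormalization()
keras_bn.build((None, 32, 32, 16))        # creates gamma, beta and the moving statistics
objax_bn = objax.nn.BatchNorm2D(nin=16)   # gamma/beta have shape (1, 16, 1, 1)
convert_bn_layer(keras_bn, objax_bn)      # objax_bn now carries the Keras weights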
Example #2
def test_state_var_custom_reduce(self):
    """Test custom StateVar reducing."""
    array = jn.array([[0, 4, 2], [3, 1, 5]], dtype=jn.float32)
    v = objax.StateVar(array, reduce=lambda x: x[0])
    v.reduce(v.value)
    self.assertEqual(v.value.tolist(), [0, 4, 2])
    v = objax.StateVar(array, reduce=lambda x: x.min(0))
    v.reduce(v.value)
    self.assertEqual(v.value.tolist(), [0, 1, 2])
    v = objax.StateVar(array, reduce=lambda x: x.max(0))
    v.reduce(v.value)
    self.assertEqual(v.value.tolist(), [3, 4, 5])
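For comparison: the reduce function is what Objax uses to merge the per-device copies of a variable (for instance under objax.Parallel), and the default is the mean over the leading axis. A minimal sketch with the same array:

import jax.numpy as jn
import objax

v = objax.StateVar(jn.array([[0., 4., 2.], [3., 1., 5.]]))
v.reduce(v.value)  # the default reduce_mean averages the two rows
assert v.value.tolist() == [1.5, 2.5, 3.5]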
Example #3
def __init__(self,
             temporal_kernel,
             spatial_kernel,
             z=None,
             conditional=None,
             sparse=True,
             opt_z=False,
             spatial_dims=None):
    self.temporal_kernel = temporal_kernel
    self.spatial_kernel = spatial_kernel
    if conditional is None:
        conditional = 'Full' if sparse else 'DTC'
    if opt_z and not sparse:  # z should not be optimised if the model is not sparse
        warn("spatial inducing inputs z will not be optimised because sparse=False")
        opt_z = False
    self.sparse = sparse
    if z is None:  # initialise z
        # TODO: smart initialisation
        if spatial_dims == 1:
            z = np.linspace(-3., 3., num=15)
        elif spatial_dims == 2:
            z1 = np.linspace(-3., 3., num=5)
            zA, zB = np.meshgrid(z1, z1)  # build a 2-D grid of inducing points
            z = np.hstack((zA.reshape(-1, 1), zB.reshape(-1, 1)))  # flatten the grid for use in kernel functions
        else:
            raise NotImplementedError('please provide an initialisation for inducing inputs z')
    if z.ndim < 2:
        z = z[:, np.newaxis]
    if spatial_dims is None:
        spatial_dims = z.ndim - 1
    assert spatial_dims == z.ndim - 1
    self.M = z.shape[0]
    if opt_z:
        self.z = objax.TrainVar(z)
    else:
        self.z = objax.StateVar(z)
    if conditional in ['DTC', 'dtc']:
        self.conditional_covariance = self.deterministic_training_conditional
    elif conditional in ['FIC', 'FITC', 'fic', 'fitc']:
        self.conditional_covariance = self.fully_independent_conditional
    elif conditional in ['Full', 'full']:
        self.conditional_covariance = self.full_conditional
    else:
        raise NotImplementedError('conditional method not recognised')
    if (not sparse) and (conditional != 'DTC'):
        warn("You chose a non-deterministic conditional, but 'DTC' will be used because the model is not sparse")
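To make the 2-D branch concrete: meshgrid plus hstack turns the 5-point axis into a flattened 5x5 grid, one row per inducing point:

import numpy as np

z1 = np.linspace(-3., 3., num=5)
zA, zB = np.meshgrid(z1, z1)
z = np.hstack((zA.reshape(-1, 1), zB.reshape(-1, 1)))
print(z.shape)  # (25, 2): 25 inducing points, 2 spatial dimensions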
Example #4
def __init__(self,
             variance,
             lengthscale,
             fix_variance=False,
             fix_lengthscale=False):
    # check whether the parameters are to be optimised
    if fix_lengthscale:
        self.transformed_lengthscale = objax.StateVar(softplus_inv(np.array(lengthscale)))
    else:
        self.transformed_lengthscale = objax.TrainVar(softplus_inv(np.array(lengthscale)))
    if fix_variance:
        self.transformed_variance = objax.StateVar(softplus_inv(np.array(variance)))
    else:
        self.transformed_variance = objax.TrainVar(softplus_inv(np.array(variance)))
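Storing softplus_inv(parameter) keeps a positive parameter unconstrained during optimisation: the optimiser updates the transformed value freely, and softplus maps it back to (0, inf). A sketch of the transform pair, assuming softplus_inv is the plain inverse softplus (the source library may add numerical safeguards):

import numpy as np

def softplus(x):
    return np.log(1.0 + np.exp(x))

def softplus_inv(x):
    return np.log(np.exp(x) - 1.0)  # inverse of softplus, valid for x > 0

assert np.isclose(softplus(softplus_inv(0.5)), 0.5)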
Example #5
def __init__(self,
             kernel,
             likelihood,
             X,
             Y,
             Z,
             opt_z=False):
    super().__init__(kernel=kernel,
                     likelihood=likelihood,
                     X=X,
                     Y=Y)
    if Z.ndim < 2:
        Z = Z[:, None]
    if opt_z:
        self.Z = objax.TrainVar(Z)
    else:
        self.Z = objax.StateVar(Z)
    self.num_inducing = Z.shape[0]
    self.posterior_mean = objax.StateVar(np.zeros([self.num_inducing, self.func_dim, 1]))
    self.posterior_variance = objax.StateVar(np.tile(np.eye(self.func_dim), [self.num_inducing, 1, 1]))
    self.posterior_covariance = objax.StateVar(np.eye(self.num_inducing))
Example #6
def test_state_var(self):
    """Test StateVar behavior."""
    v = objax.StateVar(jn.arange(6, dtype=jn.float32).reshape((2, 3)))
    self.assertEqual(v.value.shape, (2, 3))
    self.assertEqual(v.value.sum(), 15)
    v.value += 2
    self.assertEqual(v.value.sum(), 27)
    v.assign(v.value - 1)
    self.assertEqual(v.value.sum(), 21)
    v.reduce(v.value)
    self.assertEqual(v.value.shape, (3,))
    self.assertEqual(v.value.tolist(), [2.5, 3.5, 4.5])
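StateVar instances are collected like any other Objax variable: assigning one to a Module attribute makes it appear in Module.vars(). A minimal sketch:

import jax.numpy as jn
import objax

class Counter(objax.Module):
    def __init__(self):
        self.step = objax.StateVar(jn.zeros([]))

    def __call__(self):
        self.step.value += 1
        return self.step.value

c = Counter()
c()
print(c.vars())  # the collection includes (Counter).step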
Example #7
def __init__(self,
             kernel,
             likelihood,
             X,
             Y,
             func_dim=1):
    if X.ndim < 2:
        X = X[:, None]
    if Y.ndim < 2:
        Y = Y[:, None]
    self.X = np.array(X)
    self.Y = np.array(Y)
    self.kernel = kernel
    self.likelihood = likelihood
    self.num_data = self.X.shape[0]  # number of data points
    self.func_dim = func_dim  # number of latent dimensions
    self.obs_dim = Y.shape[1]  # dimensionality of the observations, Y
    self.mask = np.isnan(self.Y).reshape(Y.shape[0], Y.shape[1])  # flag missing observations
    if isinstance(self.kernel, Independent):
        pseudo_lik_size = self.func_dim  # the multi-latent case
    else:
        pseudo_lik_size = self.obs_dim
    self.pseudo_likelihood_nat1 = objax.StateVar(np.zeros([self.num_data, pseudo_lik_size, 1]))
    self.pseudo_likelihood_nat2 = objax.StateVar(1e-2 * np.tile(np.eye(pseudo_lik_size), [self.num_data, 1, 1]))
    self.pseudo_y = objax.StateVar(np.zeros([self.num_data, pseudo_lik_size, 1]))
    self.pseudo_var = objax.StateVar(1e2 * np.tile(np.eye(pseudo_lik_size), [self.num_data, 1, 1]))
    self.posterior_mean = objax.StateVar(np.zeros([self.num_data, self.func_dim, 1]))
    self.posterior_variance = objax.StateVar(np.tile(np.eye(self.func_dim), [self.num_data, 1, 1]))
    self.ind = np.arange(self.num_data)
    self.num_neighbours = np.ones(self.num_data)
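The NaN mask computed above flags missing observations in Y so they can be excluded later; with illustrative values:

import numpy as np

Y = np.array([[1.0], [np.nan], [3.0]])
mask = np.isnan(Y).reshape(Y.shape[0], Y.shape[1])
print(mask.ravel())  # [False  True False]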
Example #8
def __init__(self,
             variance=1.0,
             lengthscale=1.0,
             radial_frequency=1.0,
             fix_variance=False):
    self.transformed_lengthscale = objax.TrainVar(np.array(softplus_inv(lengthscale)))
    if fix_variance:
        self.transformed_variance = objax.StateVar(np.array(softplus_inv(variance)))
    else:
        self.transformed_variance = objax.TrainVar(np.array(softplus_inv(variance)))
    self.transformed_radial_frequency = objax.TrainVar(np.array(softplus_inv(radial_frequency)))
    super().__init__()
    self.name = 'Subband Matern-1/2'
Example #9
def test_vars(self):
    t = objax.TrainVar(jn.zeros([1, 2, 3, 2, 1]))
    tv = '\n'.join(['objax.TrainVar(DeviceArray([[[[[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]]],',
                    '              [[[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]]]]], dtype=float32), reduce=reduce_mean)'])
    self.assertEqual(repr(t), tv)
    r = objax.TrainRef(t)
    rv = '\n'.join(['objax.TrainRef(ref=objax.TrainVar(DeviceArray([[[[[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]]],',
                    '              [[[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]]]]], dtype=float32), reduce=reduce_mean))'])
    self.assertEqual(repr(r), rv)
    t = objax.StateVar(jn.zeros([1, 2, 3, 2, 1]))
    tv = '\n'.join(['objax.StateVar(DeviceArray([[[[[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]]],',
                    '              [[[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]],',
                    '               [[0.],',
                    '                [0.]]]]], dtype=float32), reduce=reduce_mean)'])
    self.assertEqual(repr(t), tv)
    self.assertEqual(repr(objax.random.Generator().key), 'objax.RandomState(DeviceArray([0, 0], dtype=uint32))')
Example #10
def test_var_hierarchy(self):
    """Test variable hierarchy."""
    t = objax.TrainVar(jn.zeros(2))
    s = objax.StateVar(jn.zeros(2))
    r = objax.TrainRef(t)
    x = objax.RandomState(0)
    self.assertIsInstance(t, objax.TrainVar)
    self.assertIsInstance(t, objax.BaseVar)
    self.assertNotIsInstance(t, objax.BaseState)
    self.assertIsInstance(s, objax.BaseVar)
    self.assertIsInstance(s, objax.BaseState)
    self.assertNotIsInstance(s, objax.TrainVar)
    self.assertIsInstance(r, objax.BaseVar)
    self.assertIsInstance(r, objax.BaseState)
    self.assertNotIsInstance(r, objax.TrainVar)
    self.assertIsInstance(x, objax.BaseVar)
    self.assertIsInstance(x, objax.BaseState)
    self.assertIsInstance(x, objax.StateVar)
    self.assertNotIsInstance(x, objax.TrainVar)
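Stated directly, the hierarchy these assertions pin down is the following (a sketch; the class names are taken from the test itself):

import objax

assert issubclass(objax.TrainVar, objax.BaseVar)      # trainable leaf variable
assert issubclass(objax.StateVar, objax.BaseState)    # non-trainable state
assert issubclass(objax.RandomState, objax.StateVar)  # PRNG state is a StateVar
assert issubclass(objax.TrainRef, objax.BaseState)    # a reference, not a TrainVar
assert not issubclass(objax.TrainVar, objax.BaseState)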
Example #11
def __init__(self,
             kernel,
             likelihood,
             X,
             Y,
             R=None,
             Z=None):
    super().__init__(kernel=kernel,
                     likelihood=likelihood,
                     X=X,
                     Y=Y,
                     R=R)
    if Z is None:
        Z = self.X
    else:
        if Z.ndim < 2:
            Z = Z[:, None]
        Z = np.sort(Z, axis=0)
    inf = np.array([[1e10]])
    self.Z = objax.StateVar(np.concatenate([-inf, Z, inf], axis=0))
    self.dz = np.array(np.diff(self.Z.value[:, 0]))
    self.num_transitions = self.dz.shape[0]
    zeros = np.zeros([self.num_transitions, 2 * self.state_dim, 1])
    eyes = np.tile(np.eye(2 * self.state_dim), [self.num_transitions, 1, 1])

    # nat2 = 1e-8 * eyes

    # initialise to match MarkovGP / GP on the first step (when Z=X):
    nat2 = index_update(1e-8 * eyes, index[:-1, self.state_dim, self.state_dim], 1e-2)

    # initialise to match the old implementation:
    # nat2 = (1 / 99) * eyes

    self.pseudo_likelihood_nat1 = objax.StateVar(zeros)
    self.pseudo_likelihood_nat2 = objax.StateVar(nat2)
    self.pseudo_y = objax.StateVar(zeros)
    self.pseudo_var = objax.StateVar(vmap(inv)(nat2))
    self.posterior_mean = objax.StateVar(zeros)
    self.posterior_variance = objax.StateVar(eyes)
    self.mask = None
    self.conditional_mean = None
    # TODO: if training Z this needs to be done at every training step (as well as sorting and computing dz)
    self.ind, self.num_neighbours = set_z_stats(self.X, self.Z.value)
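Note that index_update and index come from the old jax.ops module, which has been removed in recent JAX releases. Assuming nothing else changes, the same update can be written with the .at[] syntax (names reused from the __init__ above; the asarray call guards against np being plain NumPy rather than jax.numpy here):

import jax.numpy as jnp

nat2 = jnp.asarray(1e-8 * eyes).at[:-1, self.state_dim, self.state_dim].set(1e-2)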