Example #1
    def __init__(self, num_outputs, num_factors=16,
                 rho_init=INITIAL_NOISE_VARIANCE,
                 encoding_type=DEFAULT_ENCODING, **kwargs):

        super(Coregionalization, self).__init__(dimension=1, **kwargs)

        self.encoding_W_flat = IdentityScalarEncoding(
            dimension=num_outputs * num_factors)
        self.encoding_rho = create_encoding(encoding_type, rho_init,
                                            NOISE_VARIANCE_LOWER_BOUND,
                                            NOISE_VARIANCE_UPPER_BOUND,
                                            dimension=1)

        self.num_outputs = num_outputs
        self.num_factors = num_factors

        with self.name_scope():
            self.W_flat_internal = self.params.get(
                "W_internal", shape=(num_outputs * num_factors,),
                init=mx.init.Normal(),  # TODO: Use Xavier initialization here
                dtype=DATA_TYPE)
            self.rho_internal = self.params.get(
                "rho_internal", shape=(1,),
                init=mx.init.Constant(self.encoding_rho.init_val_int),
                dtype=DATA_TYPE)
Example #2
    def __init__(self, initial_mean_value=INITIAL_MEAN_VALUE, **kwargs):
        super(ScalarMeanFunction, self).__init__(**kwargs)

        # Even though we do not apply a specific transformation to the mean
        # value, we use an encoding so that box constraints of Gluon parameters
        # (like bandwidths or residual noise variance) are handled in a
        # consistent way
        self.encoding = IdentityScalarEncoding(
            init_val=initial_mean_value, regularizer=Normal(0.0, 1.0))
        with self.name_scope():
            self.mean_value_internal = register_parameter(
                self.params, 'mean_value', self.encoding)
Example #3
class ScalarMeanFunction(MeanFunction):
    """
    Mean function defined as a scalar (fitted while optimizing the marginal
    likelihood).

    :param initial_mean_value: A scalar to initialize the value of the mean

    """
    def __init__(self, initial_mean_value=INITIAL_MEAN_VALUE, **kwargs):
        super(ScalarMeanFunction, self).__init__(**kwargs)

        # Even though we do not apply a specific transformation to the mean
        # value, we use an encoding so that box constraints of Gluon parameters
        # (like bandwidths or residual noise variance) are handled in a
        # consistent way
        self.encoding = IdentityScalarEncoding(
            init_val=initial_mean_value, regularizer=Normal(0.0, 1.0))
        with self.name_scope():
            self.mean_value_internal = register_parameter(
                self.params, 'mean_value', self.encoding)

    def hybrid_forward(self, F, X, mean_value_internal):
        """
        Actual computation of the scalar mean function.
        We compute mean_value * vector_of_ones, whose dimension is given by
        the first column of X.

        :param F: mx.sym or mx.nd
        :param X: input data of size (n,d) for which we want to compute the
            mean (here, only useful to extract the right dimension)

        """
        mean_value = self.encoding.get(F, mean_value_internal)
        return F.broadcast_mul(F.ones_like(F.slice_axis(
            F.BlockGrad(X), axis=1, begin=0, end=1)), mean_value)

    def param_encoding_pairs(self):
        return [(self.mean_value_internal, self.encoding)]

    def get_mean_value(self):
        return encode_unwrap_parameter(
            mx.nd, self.mean_value_internal, self.encoding).asscalar()

    def set_mean_value(self, mean_value):
        self.encoding.set(self.mean_value_internal, mean_value)

    def get_params(self):
        return {'mean_value': self.get_mean_value()}

    def set_params(self, param_dict):
        self.set_mean_value(param_dict['mean_value'])
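
For reference, a minimal NumPy sketch of what hybrid_forward above computes (the scalar mean broadcast to one value per row of X); this is an illustration only, not part of the library, and MXNet is not required:

import numpy as np

def scalar_mean(X, mean_value):
    # Equivalent of hybrid_forward: mean_value * vector_of_ones, with one
    # entry per row of X (the feature values themselves are ignored).
    return np.full((X.shape[0], 1), mean_value)

X = np.random.rand(5, 3)    # 5 inputs of dimension 3
print(scalar_mean(X, 0.7))  # column of 0.7's, shape (5, 1)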
Example #4
    def __init__(self, encoding_type=DEFAULT_ENCODING,
                 u1_init=1.0, u3_init=0.0, **kwargs):

        super(FabolasKernelFunction, self).__init__(dimension=1, **kwargs)

        self.encoding_u12 = create_encoding(
            encoding_type, u1_init, COVARIANCE_SCALE_LOWER_BOUND,
            COVARIANCE_SCALE_UPPER_BOUND, 1, None)
        # This is not really needed, but param_encoding_pairs needs an encoding
        # for each parameter
        self.encoding_u3 = IdentityScalarEncoding(init_val=u3_init)
        with self.name_scope():
            self.u1_internal = register_parameter(
                self.params, 'u1', self.encoding_u12)
            self.u2_internal = register_parameter(
                self.params, 'u2', self.encoding_u12)
            self.u3_internal = register_parameter(
                self.params, 'u3', self.encoding_u3)
Example #5
class FabolasKernelFunction(KernelFunction):
    """
    The kernel function proposed in:

        Klein, A., Falkner, S., Bartels, S., Hennig, P., & Hutter, F. (2016).
        Fast Bayesian Optimization of Machine Learning Hyperparameters
        on Large Datasets, in AISTATS 2017.
        ArXiv:1605.07079 [Cs, Stat]. Retrieved from http://arxiv.org/abs/1605.07079

    Please note this is only one of the components of the factorized kernel
    proposed in the paper. This is the finite-rank ("degenerate") kernel for
    modelling data subset fraction sizes. Defined as:

        k(x, y) = (U phi(x))^T (U phi(y)),  x, y in [0, 1],
        phi(x) = [1, (1 - x)^2]^T,  U = [[u1, u3], [0, u2]] upper triangular,
        u1, u2 > 0.
    """

    def __init__(self, encoding_type=DEFAULT_ENCODING,
                 u1_init=1.0, u3_init=0.0, **kwargs):

        super(FabolasKernelFunction, self).__init__(dimension=1, **kwargs)

        self.encoding_u12 = create_encoding(
            encoding_type, u1_init, COVARIANCE_SCALE_LOWER_BOUND,
            COVARIANCE_SCALE_UPPER_BOUND, 1, None)
        # This is not really needed, but param_encoding_pairs needs an encoding
        # for each parameter
        self.encoding_u3 = IdentityScalarEncoding(init_val=u3_init)
        with self.name_scope():
            self.u1_internal = register_parameter(
                self.params, 'u1', self.encoding_u12)
            self.u2_internal = register_parameter(
                self.params, 'u2', self.encoding_u12)
            self.u3_internal = register_parameter(
                self.params, 'u3', self.encoding_u3)

    @staticmethod
    def _compute_factor(F, x, u1, u2, u3):
        tvec = (1.0 - x) ** 2
        return F.concat(
            F.broadcast_add(F.broadcast_mul(tvec, u3), u1),
            F.broadcast_mul(tvec, u2), dim=1)

    def hybrid_forward(self, F, X1, X2, u1_internal, u2_internal, u3_internal):
        X1 = self._check_input_shape(F, X1)

        u1 = self.encoding_u12.get(F, u1_internal)
        u2 = self.encoding_u12.get(F, u2_internal)
        u3 = self.encoding_u3.get(F, u3_internal)

        mat1 = self._compute_factor(F, X1, u1, u2, u3)
        if X2 is X1:
            return F.linalg.syrk(mat1, transpose=False)
        else:
            X2 = self._check_input_shape(F, X2)
            mat2 = self._compute_factor(F, X2, u1, u2, u3)
            return F.dot(mat1, mat2, transpose_a=False, transpose_b=True)

    def _get_pars(self, F, X):
        u1 = encode_unwrap_parameter(F, self.u1_internal, self.encoding_u12, X)
        u2 = encode_unwrap_parameter(F, self.u2_internal, self.encoding_u12, X)
        u3 = encode_unwrap_parameter(F, self.u3_internal, self.encoding_u3, X)
        return (u1, u2, u3)

    def diagonal(self, F, X):
        X = self._check_input_shape(F, X)
        u1, u2, u3 = self._get_pars(F, X)
        mat = self._compute_factor(F, X, u1, u2, u3)
        return F.sum(mat ** 2, axis=1)

    def diagonal_depends_on_X(self):
        return True

    def param_encoding_pairs(self):
        return [
            (self.u1_internal, self.encoding_u12),
            (self.u2_internal, self.encoding_u12),
            (self.u3_internal, self.encoding_u3)
        ]

    def get_params(self):
        values = list(self._get_pars(mx.nd, None))
        keys = ['u1', 'u2', 'u3']
        return {k: v.reshape((1,)).asscalar() for k, v in zip(keys, values)}

    def set_params(self, param_dict):
        self.encoding_u12.set(self.u1_internal, param_dict['u1'])
        self.encoding_u12.set(self.u2_internal, param_dict['u2'])
        self.encoding_u3.set(self.u3_internal, param_dict['u3'])
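
As a sanity check (a NumPy sketch under the docstring's definitions, not library code), the rows produced by _compute_factor are exactly (U phi(x))^T, so the Gram matrix returned by hybrid_forward is the finite-rank matrix from the docstring:

import numpy as np

u1, u2, u3 = 1.0, 0.8, 0.1
U = np.array([[u1, u3],
              [0.0, u2]])   # upper-triangular factor from the docstring

def phi(x):
    # phi(x) = [1, (1 - x)^2]^T, stacked row-wise for a batch of inputs
    return np.stack([np.ones_like(x), (1.0 - x) ** 2], axis=1)

x = np.random.rand(6)
mat = phi(x) @ U.T           # row i is (U phi(x_i))^T, as in _compute_factor
K = mat @ mat.T              # k(x_i, x_j) = (U phi(x_i))^T (U phi(x_j))
assert np.allclose(np.diag(K), np.sum(mat ** 2, axis=1))  # matches diagonal()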
Example #6
    def __init__(self,
                 kernel_x: KernelFunction,
                 mean_x: MeanFunction,
                 encoding_type=DEFAULT_ENCODING,
                 alpha_init=1.0,
                 mean_lam_init=0.5,
                 gamma_init=0.5,
                 delta_fixed_value=None,
                 delta_init=0.5,
                 max_metric_value=1.0,
                 **kwargs):
        """
        :param kernel_x: Kernel k_x(x, x') over configs
        :param mean_x: Mean function mu_x(x) over configs
        :param encoding_type: Encoding used for alpha, mean_lam, gamma (positive
            values)
        :param alpha_init: Initial value alpha
        :param mean_lam_init: Initial value mean_lam
        :param gamma_init: Initial value gamma
        :param delta_fixed_value: If not None, delta is fixed to this value, and
            does not become a free parameter
        :param delta_init: Initial value delta (if delta_fixed_value is None)
        :param max_metric_value: Maximum value which the metric can attain.
            This is used as an upper bound on gamma
        """

        super(ExponentialDecayResourcesKernelFunction,
              self).__init__(dimension=kernel_x.dimension + 1, **kwargs)
        self.kernel_x = kernel_x
        self.mean_x = mean_x
        # alpha, mean_lam are parameters of a Gamma distribution, where alpha
        # is the shape parameter, and
        #   E[lambda] = mean_lam, Var[lambda] = mean_lam ** 2 / alpha
        alpha_lower, alpha_upper = 1e-6, 250.0
        alpha_init = self._wrap_initvals(alpha_init, alpha_lower, alpha_upper)
        self.encoding_alpha = create_encoding(encoding_type, alpha_init,
                                              alpha_lower, alpha_upper, 1,
                                              None)
        mean_lam_lower, mean_lam_upper = 1e-4, 50.0
        mean_lam_init = self._wrap_initvals(mean_lam_init, mean_lam_lower,
                                            mean_lam_upper)
        self.encoding_mean_lam = create_encoding(encoding_type, mean_lam_init,
                                                 mean_lam_lower,
                                                 mean_lam_upper, 1, None)
        # If f(x, 0) is the metric value at r -> 0 and f(x) the value as
        # r -> infty, then f(x, 0) = gamma (for delta = 1), or
        # f(x, 0) = gamma + f(x) (for delta = 0). gamma should not be larger
        # than the maximum metric value.
        gamma_lower = max_metric_value * 0.0001
        gamma_upper = max_metric_value
        gamma_init = self._wrap_initvals(gamma_init, gamma_lower, gamma_upper)
        self.encoding_gamma = create_encoding(encoding_type, gamma_init,
                                              gamma_lower, gamma_upper, 1,
                                              None)
        if delta_fixed_value is None:
            delta_init = self._wrap_initvals(delta_init, 0.0, 1.0)
            self.encoding_delta = IdentityScalarEncoding(constr_lower=0.0,
                                                         constr_upper=1.0,
                                                         init_val=delta_init)
        else:
            assert 0.0 <= delta_fixed_value <= 1.0, \
                "delta_fixed_value = {}, must lie in [0, 1]".format(
                    delta_fixed_value)
            self.encoding_delta = None
            self.delta_fixed_value = delta_fixed_value

        with self.name_scope():
            self.alpha_internal = register_parameter(self.params, "alpha",
                                                     self.encoding_alpha)
            self.mean_lam_internal = register_parameter(
                self.params, "mean_lam", self.encoding_mean_lam)
            self.gamma_internal = register_parameter(self.params, "gamma",
                                                     self.encoding_gamma)
            if delta_fixed_value is None:
                self.delta_internal = register_parameter(
                    self.params, "delta", self.encoding_delta)
Example #7
class ExponentialDecayResourcesKernelFunction(KernelFunction):
    """
    Variant of the kernel function for modeling exponentially decaying
    learning curves, proposed in:

        Swersky, K., Snoek, J., & Adams, R. P. (2014).
        Freeze-Thaw Bayesian Optimization.
        ArXiv:1406.3896 [Cs, Stat].
        Retrieved from http://arxiv.org/abs/1406.3896

    The argument in that paper actually justifies using a non-zero mean
    function (see ExponentialDecayResourcesMeanFunction) and centralizing
    the kernel proposed there. This is done here. Details in:

        Tiao, Klein, Archambeau, Seeger (2020)
        Model-based Asynchronous Hyperparameter Optimization
        https://arxiv.org/abs/2003.10865

    We implement a new family of kernel functions, for which the additive
    Freeze-Thaw kernel is one instance (delta = 0).
    The kernel has parameters alpha, mean_lam, gamma > 0, and delta in [0, 1].
    Note that beta = alpha / mean_lam is used in the Freeze-Thaw paper (the
    Gamma distribution over lambda is parameterized differently).
    The additive Freeze-Thaw kernel is obtained for delta = 0 (use
    delta_fixed_value = 0).

    In fact, this class is configured with a kernel and a mean function over
    inputs x (dimension d) and represents a kernel (and mean function) over
    inputs (x, r) (dimension d + 1), where the resource attribute r >= 0 is
    last.

    """
    def __init__(self,
                 kernel_x: KernelFunction,
                 mean_x: MeanFunction,
                 encoding_type=DEFAULT_ENCODING,
                 alpha_init=1.0,
                 mean_lam_init=0.5,
                 gamma_init=0.5,
                 delta_fixed_value=None,
                 delta_init=0.5,
                 max_metric_value=1.0,
                 **kwargs):
        """
        :param kernel_x: Kernel k_x(x, x') over configs
        :param mean_x: Mean function mu_x(x) over configs
        :param encoding_type: Encoding used for alpha, mean_lam, gamma (positive
            values)
        :param alpha_init: Initial value alpha
        :param mean_lam_init: Initial value mean_lam
        :param gamma_init: Initial value gamma
        :param delta_fixed_value: If not None, delta is fixed to this value, and
            does not become a free parameter
        :param delta_init: Initial value delta (if delta_fixed_value is None)
        :param max_metric_value: Maximum value which the metric can attain.
            This is used as an upper bound on gamma
        """

        super(ExponentialDecayResourcesKernelFunction,
              self).__init__(dimension=kernel_x.dimension + 1, **kwargs)
        self.kernel_x = kernel_x
        self.mean_x = mean_x
        # alpha, mean_lam are parameters of a Gamma distribution, where alpha
        # is the shape parameter, and
        #   E[lambda] = mean_lam, Var[lambda] = mean_lam ** 2 / alpha
        alpha_lower, alpha_upper = 1e-6, 250.0
        alpha_init = self._wrap_initvals(alpha_init, alpha_lower, alpha_upper)
        self.encoding_alpha = create_encoding(encoding_type, alpha_init,
                                              alpha_lower, alpha_upper, 1,
                                              None)
        mean_lam_lower, mean_lam_upper = 1e-4, 50.0
        mean_lam_init = self._wrap_initvals(mean_lam_init, mean_lam_lower,
                                            mean_lam_upper)
        self.encoding_mean_lam = create_encoding(encoding_type, mean_lam_init,
                                                 mean_lam_lower,
                                                 mean_lam_upper, 1, None)
        # If f(x, 0) is the metric value at r -> 0 and f(x) the value as
        # r -> infty, then f(x, 0) = gamma (for delta = 1), or
        # f(x, 0) = gamma + f(x) (for delta = 0). gamma should not be larger
        # than the maximum metric value.
        gamma_lower = max_metric_value * 0.0001
        gamma_upper = max_metric_value
        gamma_init = self._wrap_initvals(gamma_init, gamma_lower, gamma_upper)
        self.encoding_gamma = create_encoding(encoding_type, gamma_init,
                                              gamma_lower, gamma_upper, 1,
                                              None)
        if delta_fixed_value is None:
            delta_init = self._wrap_initvals(delta_init, 0.0, 1.0)
            self.encoding_delta = IdentityScalarEncoding(constr_lower=0.0,
                                                         constr_upper=1.0,
                                                         init_val=delta_init)
        else:
            assert 0.0 <= delta_fixed_value <= 1.0, \
                "delta_fixed_value = {}, must lie in [0, 1]".format(
                    delta_fixed_value)
            self.encoding_delta = None
            self.delta_fixed_value = delta_fixed_value

        with self.name_scope():
            self.alpha_internal = register_parameter(self.params, "alpha",
                                                     self.encoding_alpha)
            self.mean_lam_internal = register_parameter(
                self.params, "mean_lam", self.encoding_mean_lam)
            self.gamma_internal = register_parameter(self.params, "gamma",
                                                     self.encoding_gamma)
            if delta_fixed_value is None:
                self.delta_internal = register_parameter(
                    self.params, "delta", self.encoding_delta)

    @staticmethod
    def _wrap_initvals(init, lower, upper):
        return max(min(init, upper * 0.999), lower * 1.001)

    @staticmethod
    def _compute_kappa(F, x, alpha, mean_lam):
        beta = alpha / mean_lam
        return F.broadcast_power(
            F.broadcast_div(beta, F.broadcast_add(x, beta)), alpha)

    def _compute_terms(self,
                       F,
                       X,
                       alpha,
                       mean_lam,
                       gamma,
                       delta,
                       ret_mean=False):
        dim = self.kernel_x.dimension
        cfg = F.slice_axis(X, axis=1, begin=0, end=dim)
        res = F.slice_axis(X, axis=1, begin=dim, end=None)
        kappa = self._compute_kappa(F, res, alpha, mean_lam)
        kr_pref = F.reshape(gamma, shape=(1, 1))
        if ret_mean or (self.encoding_delta is not None) or delta > 0.0:
            mean = self.mean_x(cfg)
        else:
            mean = None
        if self.encoding_delta is not None:
            kr_pref = F.broadcast_sub(kr_pref, F.broadcast_mul(delta, mean))
        elif delta > 0.0:
            kr_pref = F.broadcast_sub(kr_pref, mean * delta)
        return cfg, res, kappa, kr_pref, mean

    @staticmethod
    def _unwrap(F, X, kwargs, key, enc, var_internal):
        return enc.get(
            F,
            kwargs.get(get_name_internal(key),
                       unwrap_parameter(F, var_internal, X)))

    def _get_params(self, F, X, **kwargs):
        alpha = self._unwrap(F, X, kwargs, 'alpha', self.encoding_alpha,
                             self.alpha_internal)
        mean_lam = self._unwrap(F, X, kwargs, 'mean_lam',
                                self.encoding_mean_lam, self.mean_lam_internal)
        gamma = self._unwrap(F, X, kwargs, 'gamma', self.encoding_gamma,
                             self.gamma_internal)
        if self.encoding_delta is not None:
            delta = F.reshape(self._unwrap(F, X, kwargs, 'delta',
                                           self.encoding_delta,
                                           self.delta_internal),
                              shape=(1, 1))
        else:
            delta = self.delta_fixed_value
        return (alpha, mean_lam, gamma, delta)

    def hybrid_forward(self, F, X1, X2, **kwargs):
        alpha, mean_lam, gamma, delta = self._get_params(F, X1, **kwargs)

        cfg1, res1, kappa1, kr_pref1, _ = self._compute_terms(
            F, X1, alpha, mean_lam, gamma, delta)
        if X2 is not X1:
            cfg2, res2, kappa2, kr_pref2, _ = self._compute_terms(
                F, X2, alpha, mean_lam, gamma, delta)
        else:
            cfg2, res2, kappa2, kr_pref2 = cfg1, res1, kappa1, kr_pref1
        res2 = F.reshape(res2, shape=(1, -1))
        kappa2 = F.reshape(kappa2, shape=(1, -1))
        kr_pref2 = F.reshape(kr_pref2, shape=(1, -1))
        kappa12 = self._compute_kappa(F, F.broadcast_add(res1, res2), alpha,
                                      mean_lam)

        kmat_res = F.broadcast_sub(kappa12, F.broadcast_mul(kappa1, kappa2))
        kmat_res = F.broadcast_mul(kr_pref1,
                                   F.broadcast_mul(kr_pref2, kmat_res))
        kmat_x = self.kernel_x(cfg1, cfg2)
        if self.encoding_delta is None:
            if delta > 0.0:
                tmpmat = F.broadcast_add(
                    kappa1, F.broadcast_sub(kappa2, kappa12 * delta))
                tmpmat = tmpmat * (-delta) + 1.0
            else:
                tmpmat = 1.0
        else:
            tmpmat = F.broadcast_add(
                kappa1, F.broadcast_sub(kappa2,
                                        F.broadcast_mul(kappa12, delta)))
            tmpmat = F.broadcast_mul(tmpmat, -delta) + 1.0

        return kmat_x * tmpmat + kmat_res

    def diagonal(self, F, X):
        alpha, mean_lam, gamma, delta = self._get_params(F, X)

        cfg, res, kappa, kr_pref, _ = self._compute_terms(
            F, X, alpha, mean_lam, gamma, delta)
        kappa2 = self._compute_kappa(F, res * 2, alpha, mean_lam)

        kdiag_res = F.broadcast_sub(kappa2, F.square(kappa))
        kdiag_res = F.reshape(F.broadcast_mul(kdiag_res, F.square(kr_pref)),
                              shape=(-1, ))
        kdiag_x = self.kernel_x.diagonal(F, cfg)
        if self.encoding_delta is None:
            if delta > 0.0:
                tmpvec = F.broadcast_sub(kappa * 2, kappa2 * delta)
                tmpvec = F.reshape(tmpvec * (-delta) + 1.0, shape=(-1, ))
            else:
                tmpvec = 1.0
        else:
            tmpvec = F.broadcast_sub(kappa * 2, F.broadcast_mul(kappa2, delta))
            tmpvec = F.reshape(F.broadcast_mul(tmpvec, -delta) + 1.0,
                               shape=(-1, ))

        return kdiag_x * tmpvec + kdiag_res

    def diagonal_depends_on_X(self):
        return True

    def param_encoding_pairs(self):
        enc_list = [(self.alpha_internal, self.encoding_alpha),
                    (self.mean_lam_internal, self.encoding_mean_lam),
                    (self.gamma_internal, self.encoding_gamma)]
        if self.encoding_delta is not None:
            enc_list.append((self.delta_internal, self.encoding_delta))
        enc_list.extend(self.kernel_x.param_encoding_pairs())
        enc_list.extend(self.mean_x.param_encoding_pairs())
        return enc_list

    def mean_function(self, F, X):
        alpha, mean_lam, gamma, delta = self._get_params(F, X)
        cfg, res, kappa, kr_pref, mean = self._compute_terms(F,
                                                             X,
                                                             alpha,
                                                             mean_lam,
                                                             gamma,
                                                             delta,
                                                             ret_mean=True)

        return F.broadcast_add(mean, F.broadcast_mul(kappa, kr_pref))

    def get_params(self):
        """
        Parameter keys are alpha, mean_lam, gamma, delta (only if not fixed
        to delta_fixed_value), as well as those of self.kernel_x (prefix
        'kernelx_') and of self.mean_x (prefix 'meanx_').

        """
        values = list(self._get_params(mx.nd, None))
        keys = ['alpha', 'mean_lam', 'gamma', 'delta']
        if self.encoding_delta is None:
            values.pop()
            keys.pop()
        result = {k: v.reshape((1, )).asscalar() for k, v in zip(keys, values)}
        for pref, func in [('kernelx_', self.kernel_x),
                           ('meanx_', self.mean_x)]:
            result.update({(pref + k): v
                           for k, v in func.get_params().items()})
        return result

    def set_params(self, param_dict):
        for pref, func in [('kernelx_', self.kernel_x),
                           ('meanx_', self.mean_x)]:
            len_pref = len(pref)
            stripped_dict = {
                k[len_pref:]: v
                for k, v in param_dict.items() if k.startswith(pref)
            }
            func.set_params(stripped_dict)
        self.encoding_alpha.set(self.alpha_internal, param_dict['alpha'])
        self.encoding_mean_lam.set(self.mean_lam_internal,
                                   param_dict['mean_lam'])
        self.encoding_gamma.set(self.gamma_internal, param_dict['gamma'])
        if self.encoding_delta is not None:
            self.encoding_delta.set(self.delta_internal, param_dict['delta'])
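
The decay term computed by _compute_kappa is the Laplace transform of that Gamma distribution: for lambda ~ Gamma(alpha, beta) with beta = alpha / mean_lam, we have E[exp(-lambda * r)] = (beta / (beta + r)) ** alpha, which is what links this kernel to exponentially decaying learning curves. A NumPy sketch (illustration only) verifying the identity by Monte Carlo:

import numpy as np

alpha, mean_lam, r = 2.0, 0.5, 3.0
beta = alpha / mean_lam
rng = np.random.default_rng(0)
lam = rng.gamma(shape=alpha, scale=1.0 / beta, size=1_000_000)

kappa_mc = np.exp(-lam * r).mean()       # E[exp(-lambda r)] by sampling
kappa_cf = (beta / (beta + r)) ** alpha  # closed form from _compute_kappa
print(kappa_mc, kappa_cf)                # the two agree to a few decimals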
Example #8
class Coregionalization(KernelFunction):
    """
    k(i, j) = K_{ij}, where K = W W^T + diag(rho).
    """
    def __init__(self, num_outputs, num_factors=16,
                 rho_init=INITIAL_NOISE_VARIANCE,
                 encoding_type=DEFAULT_ENCODING, **kwargs):

        super(Coregionalization, self).__init__(dimension=1, **kwargs)

        self.encoding_W_flat = IdentityScalarEncoding(
            dimension=num_outputs * num_factors)
        self.encoding_rho = create_encoding(encoding_type, rho_init,
                                            NOISE_VARIANCE_LOWER_BOUND,
                                            NOISE_VARIANCE_UPPER_BOUND,
                                            dimension=1)

        self.num_outputs = num_outputs
        self.num_factors = num_factors

        with self.name_scope():
            self.W_flat_internal = self.params.get(
                "W_internal", shape=(num_outputs * num_factors,),
                init=mx.init.Normal(),  # TODO: Use Xavier initialization here
                dtype=DATA_TYPE)
            self.rho_internal = self.params.get(
                "rho_internal", shape=(1,),
                init=mx.init.Constant(self.encoding_rho.init_val_int),
                dtype=DATA_TYPE)

    @staticmethod
    def _meshgrid(F, a, b):
        """
        Return coordinate matrices from coordinate vectors.

        Like https://docs.scipy.org/doc/numpy/reference/generated/numpy.meshgrid.html
        (with Cartesian indexing), but only supports two coordinate vectors as input.

        :param a: 1-D array representing the coordinates of a grid (length n) 
        :param b: 1-D array representing the coordinates of a grid (length m) 
        :return: coordinate matrix. 3-D array of shape (2, m, n).
        """
        aa = F.broadcast_mul(F.ones_like(F.expand_dims(a, axis=-1)), b)
        bb = F.broadcast_mul(F.ones_like(F.expand_dims(b, axis=-1)), a)
        return F.stack(bb, F.transpose(aa), axis=0)

    def _compute_gram_matrix(self, F, W_flat, rho):
        W = F.reshape(W_flat, shape=(self.num_outputs, self.num_factors))
        rho_vec = F.broadcast_mul(rho, F.ones(self.num_outputs, dtype=DATA_TYPE))
        return F.linalg.syrk(W) + F.diag(rho_vec)

    def hybrid_forward(self, F, ind1, ind2, W_flat_internal, rho_internal):
        W_flat = self.encoding_W_flat.get(F, W_flat_internal)
        rho = self.encoding_rho.get(F, rho_internal)
        K = self._compute_gram_matrix(F, W_flat, rho)
        ind1 = self._check_input_shape(F, ind1)
        if ind2 is not ind1:
            ind2 = self._check_input_shape(F, ind2)
        ind = self._meshgrid(F, ind1, ind2)
        return F.transpose(F.squeeze(F.gather_nd(K, ind)))

    def diagonal(self, F, ind):
        ind = self._check_input_shape(F, ind)
        W_flat = self.encoding_W_flat.get(F, unwrap_parameter(F, self.W_flat_internal, ind))
        rho = self.encoding_rho.get(F, unwrap_parameter(F, self.rho_internal, ind))
        K = self._compute_gram_matrix(F, W_flat, rho)
        K_diag = F.diag(K)
        return F.take(K_diag, ind)

    def diagonal_depends_on_X(self):
        return True

    def param_encoding_pairs(self):
        return [
            (self.W_flat_internal, self.encoding_W_flat),
            (self.rho_internal, self.encoding_rho),
        ]
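
A minimal NumPy sketch of the Gram matrix this kernel indexes into (names mirror the class; illustration only): K = W W^T + diag(rho) is formed once, and evaluating the kernel on integer output indices is then a lookup k(i, j) = K[i, j]:

import numpy as np

num_outputs, num_factors = 4, 2
rng = np.random.default_rng(0)
W = rng.normal(size=(num_outputs, num_factors))
rho = 0.1
K = W @ W.T + rho * np.eye(num_outputs)  # K = W W^T + diag(rho)

ind1 = np.array([0, 2, 3])               # integer output indices
ind2 = np.array([1, 2])
print(K[np.ix_(ind1, ind2)])             # k(i, j) = K_{ij}, shape (3, 2)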