Ejemplo n.º 1
0
    def find_best_model_initialization(self, num_kernel_samples: int) -> None:
        """
        Test `num_kernel_samples` models with sampled kernel parameters. The model's kernel
        parameters are then set to the sample achieving maximal likelihood.

        The hyperparameters are first squeezed, and the resulting training loss is the
        baseline. Each sample randomizes the hyperparameters and evaluates
        ``self.model.training_loss()``. A sample whose evaluation raises
        ``tf.errors.InvalidArgumentError`` is skipped. Finally, the best parameter values
        seen (possibly the squeezed baseline) are written back with ``multiple_assign``.

        :param num_kernel_samples: Number of randomly sampled kernels to evaluate. Values
            ``<= 0`` evaluate no samples and only re-assign the squeezed parameters.
        """

        @tf.function
        def evaluate_loss_of_model_parameters() -> tf.Tensor:
            randomize_hyperparameters(self.model)
            return self.model.training_loss()

        squeeze_hyperparameters(self.model)
        current_best_parameters = read_values(self.model)
        min_loss = self.model.training_loss()

        # The loop counter is unused, so a plain Python range is sufficient.
        for _ in range(num_kernel_samples):
            try:
                train_loss = evaluate_loss_of_model_parameters()
            except tf.errors.InvalidArgumentError:  # allow badly specified kernel params
                # Skip the sample outright rather than scoring it with a sentinel loss:
                # a sentinel could still beat a non-finite baseline and cause the broken
                # randomized parameters to be kept as "best".
                continue

            # NaN never compares less-than anything, so a NaN baseline would otherwise
            # never be replaced; let any evaluated sample replace it.
            if tf.math.is_nan(min_loss) or train_loss < min_loss:  # only keep best kernel params
                min_loss = train_loss
                current_best_parameters = read_values(self.model)

        multiple_assign(self.model, current_best_parameters)
Ejemplo n.º 2
0
def test_multiple_assign_updates_correct_values(model, var_update_dict):
    """
    `multiple_assign` must set every leaf variable named in `var_update_dict` to its new
    value and leave every other leaf variable of `model` unchanged.
    """
    # Snapshot the values as numpy arrays *before* assigning. A shallow copy of the dict
    # returned by `leaf_components` would still hold the same variable objects, which
    # `multiple_assign` updates in place. The "unchanged" check below would then compare
    # each variable with itself and could never fail.
    old_values = {
        path: variable.value().numpy()
        for path, variable in leaf_components(model).items()
    }
    multiple_assign(model, var_update_dict)
    for path, variable in leaf_components(model).items():
        if path in var_update_dict:
            np.testing.assert_almost_equal(variable.value().numpy(),
                                           var_update_dict[path],
                                           decimal=7)
        else:
            np.testing.assert_equal(variable.value().numpy(), old_values[path])
Ejemplo n.º 3
0
def test_multiple_assign_fails_with_invalid_values(model,
                                                   wrong_var_update_dict):
    """Assigning an update dict with invalid values through `multiple_assign` raises ValueError."""
    # Callable form of pytest.raises: invokes multiple_assign(model, wrong_var_update_dict)
    # and fails the test if ValueError is not raised.
    pytest.raises(ValueError, multiple_assign, model, wrong_var_update_dict)
Ejemplo n.º 4
0
def _exact_kernel_class(kernel: Kernel, options: tuple) -> type:
    """
    Return the class in `options` that is exactly ``type(kernel)``.

    Subclasses are deliberately not matched: the copy is rebuilt by calling the matched
    class's constructor with a fixed argument set, which a subclass might not accept.

    :param kernel: Kernel whose exact class is looked up.
    :param options: Candidate kernel classes.
    :return: The matching class.
    :raises ValueError: If ``type(kernel)`` is not one of `options`.
    """
    for option in options:
        if type(kernel) is option:
            return option
    raise ValueError(
        f"BNQDflow's copy_kernel function doesn't support this kernel type: {type(kernel)}"
    )


def _copy_kernel(kernel: Kernel) -> Kernel:
    """
    Returns a copy of the input kernel with the same values but different pointers.

    Sum and Product kernels are copied by recursively copying their sub-kernels and
    recombining them with ``+`` / ``*``. Every other supported kernel is rebuilt through
    its constructor, and then all of its parameter values are copied over with
    ``multiple_assign(res, parameter_dict(kernel))``.

    Doesn't check whether or not the tf.Module objects are actually different.

    :param kernel: The kernel to copy.
    :return: A new kernel of the same type with the same parameter values.
    :raises ValueError: If the kernel's exact type is not supported.
    """
    # Case for when the kernel is a sum of multiple kernels
    if type(kernel) is Sum:
        res = _copy_kernel(kernel.kernels[0])
        for sub_kernel in kernel.kernels[1:]:
            res += _copy_kernel(sub_kernel)
        return res

    # Case for when the kernel is a product of multiple kernels
    if type(kernel) is Product:
        res = _copy_kernel(kernel.kernels[0])
        for sub_kernel in kernel.kernels[1:]:
            res *= _copy_kernel(sub_kernel)
        return res

    # Case for when the kernel is convolutional
    if type(kernel) is Convolutional:
        # The base kernel is copied exactly once; copying an existing copy again
        # would only create an extra, discarded module.
        base_kernel = _copy_kernel(kernel.base_kernel)
        res = Convolutional(base_kernel,
                            kernel.image_shape,
                            kernel.patch_shape,
                            colour_channels=kernel.colour_channels)
        multiple_assign(res, parameter_dict(kernel))
        return res

    # Case for when the kernel is a change-point kernel
    if type(kernel) is ChangePoints:
        kernels = list(map(_copy_kernel, kernel.kernels))
        locations = kernel.locations
        steepness = kernel.steepness
        name = kernel.name
        res = ChangePoints(kernels, locations, steepness=steepness, name=name)
        multiple_assign(res, parameter_dict(kernel))
        return res

    # Case for when the kernel is periodic
    if type(kernel) is Periodic:
        base_kernel = _copy_kernel(kernel.base_kernel)
        period = kernel.period
        res = Periodic(base_kernel, period=period)
        multiple_assign(res, parameter_dict(kernel))
        return res

    # Case for when the kernel is (an instance of) a linear kernel
    if isinstance(kernel, Linear):
        # Raises ValueError (rather than an assert, which -O strips) for other subclasses.
        kernel_class = _exact_kernel_class(kernel, (Linear, Polynomial))
        if kernel_class is Polynomial:
            # Polynomial additionally needs its degree passed to the constructor.
            res = Polynomial(variance=kernel.variance,
                             active_dims=kernel.active_dims,
                             degree=kernel.degree)
        else:
            res = Linear(variance=kernel.variance,
                         active_dims=kernel.active_dims)
        multiple_assign(res, parameter_dict(kernel))
        return res

    # Case for when the kernel is an instance of a static kernel
    if isinstance(kernel, Static):
        kernel_class = _exact_kernel_class(kernel, (White, Constant))
        res = kernel_class(active_dims=kernel.active_dims)
        multiple_assign(res, parameter_dict(kernel))
        return res

    # Case for when the kernel is an instance of a stationary kernel
    if isinstance(kernel, Stationary):
        kernel_class = _exact_kernel_class(
            kernel,
            (SquaredExponential, Cosine, Exponential, Matern12, Matern32,
             Matern52, RationalQuadratic),
        )
        res = kernel_class(variance=kernel.variance,
                           lengthscales=kernel.lengthscales,
                           active_dims=kernel.active_dims,
                           name=kernel.name)
        multiple_assign(res, parameter_dict(kernel))
        return res

    raise ValueError(
        f"BNQDflow's copy_kernel function doesn't support this kernel type: {type(kernel)}"
    )