Example 1
def fit(X: np.ndarray,
        z: np.ndarray,
        weights: np.ndarray,
        sp_num: np.ndarray,
        n_inducing: int,
        n_latent: int,
        log_folder: str,
        use_berman_turner: bool = True,
        X_thin: Optional[np.ndarray] = None,
        n_thin_inducing: Optional[int] = None,
        learning_rate: float = 0.01,
        steps: int = 100000,
        batch_size: int = 50000,
        save_opt_state: bool = False,
        save_every: Optional[int] = 1000,
        fix_thin_inducing: bool = False,
        cov_alpha: Optional[float] = None,
        thin_alpha: Optional[float] = 1.,
        fix_zero_w_prior_mean: bool = True,
        separate_w_prior_vars: bool = True):

    n_cov = X.shape[1]
    n_data = X.shape[0]
    n_out = len(np.unique(sp_num))

    # Initialise inducing point locations from the quadrature (z == 0) rows
    # of the first species.
    Z = find_starting_z(X[(z == 0) & (sp_num == np.unique(sp_num)[0])],
                        n_inducing)

    if X_thin is not None:
        # Make sure we were given how many thinning inducing to use
        assert n_thin_inducing is not None
        Z_thin = find_starting_z(
            X_thin[(z == 0) & (sp_num == np.unique(sp_num)[0])],
            n_thin_inducing)
    else:
        Z_thin = None

    log_cov_alpha = np.log(cov_alpha) if cov_alpha is not None else tf.cast(
        tf.constant(np.log(np.sqrt(2. / n_latent))), tf.float32)
    log_thin_alpha = np.log(thin_alpha)

    start_theta = initialise_theta(Z,
                                   n_latent,
                                   n_cov,
                                   n_out,
                                   Z_thin=Z_thin,
                                   log_cov_alpha=log_cov_alpha,
                                   log_thin_alpha=log_thin_alpha,
                                   separate_w_prior_vars=separate_w_prior_vars)

    if fix_thin_inducing:
        # Remove them from the theta dict of parameters to optimise
        start_theta = {x: y for x, y in start_theta.items() if x != 'thin_Zs'}

    if fix_zero_w_prior_mean:
        # Remove them from the theta dict of parameters to optimise
        start_theta = {
            x: y
            for x, y in start_theta.items() if x != 'w_prior_mean'
        }

    flat_theta, summary = flatten_and_summarise_tf(**start_theta)

    log_folder = os.path.join(
        log_folder,
        create_path_with_variables(lr=learning_rate,
                                   batch_size=batch_size,
                                   steps=steps))

    os.makedirs(log_folder, exist_ok=True)

    opt_step_fun = partial(adam_step, step_size_fun=lambda t: learning_rate)
    opt_state = initialise_state(flat_theta.shape[0])

    flat_theta = flat_theta.numpy()

    to_optimise = partial(objective_and_grad,
                          n_data=n_data,
                          n_latent=n_latent,
                          summary=summary,
                          use_berman_turner=use_berman_turner,
                          log_cov_alpha=log_cov_alpha)

    if fix_thin_inducing:

        to_optimise = partial(to_optimise,
                              thin_Zs=tf.constant(
                                  np.expand_dims(Z_thin.astype(np.float32),
                                                 axis=0)))

    n_w_means = n_out if separate_w_prior_vars else 1

    if fix_zero_w_prior_mean:
        to_optimise = partial(to_optimise,
                              w_prior_mean=tf.zeros((n_w_means, n_latent)))

    full_data = {'X': X, 'sp_num': sp_num, 'z': z, 'weights': weights}

    log_file = os.path.join(log_folder, 'losses.txt')

    if X_thin is not None:
        full_data['X_thin'] = X_thin
    else:
        to_optimise = partial(to_optimise, X_thin=None)

    loss_log_file = open(log_file, 'w')

    additional_vars = {}

    if fix_thin_inducing:
        # Store thin Zs for callback to save
        additional_vars['thin_Zs'] = np.expand_dims(Z_thin, axis=0)

    if fix_zero_w_prior_mean:
        additional_vars['w_prior_mean'] = np.zeros((n_w_means, n_latent))

    additional_vars['log_cov_alpha'] = log_cov_alpha
    additional_vars['log_thin_alpha'] = log_thin_alpha

    def opt_callback(step: int, loss: float, theta: np.ndarray,
                     grad: np.ndarray, opt_state: Any):

        # Save theta and the gradients
        save_theta_and_grad_callback(step,
                                     loss,
                                     theta,
                                     grad,
                                     opt_state,
                                     log_folder,
                                     summary,
                                     save_every,
                                     additional_vars=additional_vars)

        # Log the loss
        loss_log_callback(step, loss, theta, grad, opt_state, loss_log_file)

    flat_theta, loss_log, _ = optimise_minibatching(full_data,
                                                    to_optimise,
                                                    opt_step_fun,
                                                    opt_state,
                                                    flat_theta,
                                                    batch_size,
                                                    steps,
                                                    X.shape[0],
                                                    callback=opt_callback)

    # Cast to float32
    flat_theta = flat_theta.astype(np.float32)

    final_theta = reconstruct_np(flat_theta, summary)

    if fix_thin_inducing:
        final_theta['thin_Zs'] = np.expand_dims(Z_thin, axis=0)

    if fix_zero_w_prior_mean:
        final_theta['w_prior_mean'] = np.zeros((n_w_means, n_latent))

    final_theta['log_cov_alpha'] = log_cov_alpha
    final_theta['log_thin_alpha'] = log_thin_alpha

    return final_theta
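A minimal usage sketch for the point-process fit above. The array shapes, species count and log folder are illustrative assumptions, and fit (plus its helpers) is assumed to already be in scope.

import numpy as np

# Hypothetical multi-species presence/quadrature data: 2000 rows, 5 covariates,
# 3 species. z == 1 marks presence points, z == 0 marks quadrature points.
n_points, n_cov, n_species = 2000, 5, 3
X = np.random.randn(n_points, n_cov).astype(np.float32)
z = np.random.binomial(1, 0.1, size=n_points).astype(np.float32)
weights = np.ones(n_points, dtype=np.float32)
sp_num = np.random.randint(0, n_species, size=n_points)

# Short run purely for illustration; real fits use many more steps.
final_theta = fit(X, z, weights, sp_num,
                  n_inducing=100,
                  n_latent=4,
                  log_folder='./logs',
                  steps=1000,
                  batch_size=512)

print(sorted(final_theta.keys()))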
Example 2
def fit(
    X: np.ndarray,
    y: np.ndarray,
    n_inducing: int = 100,
    n_latent: int = 10,
    kernel: str = "matern_3/2",
    # Gamma priors (note tfp uses the "concentration, rate" parameterisation):
    kernel_lengthscale_prior: Tuple[float, float] = (3, 1 / 3),
    bias_variance_prior: Tuple[float, float] = (3 / 2, 3 / 2),
    w_variance_prior: Tuple[float, float] = (3 / 2, 3 / 2),
    # Normal priors
    w_mean_prior: Tuple[float, float] = (0, 1),
    bias_mean_prior: Tuple[float, float] = (0, 1),
    random_seed: int = 2,
    test_run: bool = False,
    total_kernel_variance: float = 6.0,
    verbose: bool = False,
) -> MOGPResult:

    np.random.seed(random_seed)

    # Note that input _must_ be scaled. Some way to enforce that?
    kernel_fun = kern_lookup[kernel]

    n_cov = X.shape[1]
    n_out = y.shape[1]

    # Set initial values
    start_lengthscales = np.random.uniform(2.0, 4.0, size=(n_latent, n_cov)).astype(
        np.float32
    )

    Z = find_starting_z(X, n_inducing)
    Z = np.tile(Z, (n_latent, 1, 1))
    Z = Z.astype(np.float32)

    start_kernel_funs = get_kernel_funs(
        kernel_fun,
        tf.constant(start_lengthscales),
        total_variance=tf.constant(total_kernel_variance),
    )

    init_Ls = np.stack(
        [
            get_initial_values_from_kernel(tf.constant(cur_z), cur_kernel_fun)
            for cur_z, cur_kernel_fun in zip(Z, start_kernel_funs)
        ]
    )

    init_ms = np.zeros((n_latent, n_inducing))
    w_prior_var_init = np.ones((n_latent, 1)) * 1.0
    w_prior_mean_init = np.zeros((n_latent, 1))

    start_intercept_means = np.zeros(n_out)
    start_intercept_var = np.ones(n_out)
    intercept_prior_var_init = np.array(0.4)

    init_theta = {
        "L_elts": init_Ls,
        "mu": init_ms,
        "w_prior_var": w_prior_var_init,
        "w_prior_mean": w_prior_mean_init,
        "intercept_means": start_intercept_means,
        "intercept_vars": start_intercept_var,
        "intercept_prior_var": intercept_prior_var_init,
        "intercept_prior_mean": np.array(0.0),
        "w_means": np.random.randn(n_latent, n_out) * 0.01,
        "w_vars": np.ones((n_latent, n_out)),
        "lscales": np.sqrt(start_lengthscales),
        "Z": Z,
    }

    # Make same type
    init_theta = {x: tf.constant(y.astype(np.float32)) for x, y in init_theta.items()}

    flat_theta, summary = flatten_and_summarise_tf(**init_theta)

    X = tf.constant(X.astype(np.float32))
    y = tf.constant(y.astype(np.float32))

    lscale_prior = tfp.distributions.Gamma(*kernel_lengthscale_prior)
    bias_var_prior = tfp.distributions.Gamma(*bias_variance_prior)
    w_var_prior = tfp.distributions.Gamma(*w_variance_prior)

    w_m_prior = tfp.distributions.Normal(*w_mean_prior)
    bias_m_prior = tfp.distributions.Normal(*bias_mean_prior)

    # TODO: Think about priors for W?

    def to_minimize_with_grad(x):

        with tf.GradientTape() as tape:

            x_tf = tf.constant(x)
            x_tf = tf.cast(x_tf, tf.float32)

            tape.watch(x_tf)

            theta = reconstruct_tf(x_tf, summary)

            # Square the important parameters
            (lscales, w_prior_var, intercept_vars, intercept_prior_var, w_vars) = (
                theta["lscales"] ** 2,
                theta["w_prior_var"] ** 2,
                theta["intercept_vars"] ** 2,
                theta["intercept_prior_var"] ** 2,
                theta["w_vars"] ** 2,
            )

            if verbose:
                print(lscales)
                print(intercept_prior_var)
                print(w_prior_var)
                print(theta["w_prior_mean"])
                print(theta["intercept_prior_mean"])

            Ls = create_ls(theta["L_elts"], n_inducing, n_latent)

            kern_funs = get_kernel_funs(
                kernel_fun,
                lscales,
                total_variance=tf.constant(total_kernel_variance, dtype=tf.float32),
            )

            kl = compute_kl_term(
                theta["mu"],
                Ls,
                kern_funs,
                theta["Z"],
                theta["w_means"],
                w_vars,
                theta["w_prior_mean"],
                w_prior_var,
                theta["intercept_means"],
                intercept_vars,
                theta["intercept_prior_mean"],
                intercept_prior_var,
            )

            lik = compute_likelihood_term(
                X,
                y,
                theta["Z"],
                theta["mu"],
                Ls,
                kern_funs,
                theta["w_means"],
                w_vars,
                theta["intercept_means"],
                intercept_vars,
            )

            objective = -(lik - kl)

            objective = objective - (
                tf.reduce_sum(lscale_prior.log_prob(lscales))
                + bias_var_prior.log_prob(intercept_prior_var)
                + tf.reduce_sum(w_var_prior.log_prob(w_prior_var))
                + bias_m_prior.log_prob(theta["intercept_prior_mean"])
                + tf.reduce_sum(w_m_prior.log_prob(theta["w_prior_mean"]))
            )

            grad = tape.gradient(objective, x_tf)

        if verbose:
            print(objective, np.linalg.norm(grad.numpy()))

        return (objective.numpy().astype(np.float64), grad.numpy().astype(np.float64))

    if test_run:
        additional_args = {"tol": 1}
    else:
        additional_args = {}

    result = minimize(
        to_minimize_with_grad,
        flat_theta,
        jac=True,
        method="L-BFGS-B",
        **additional_args
    )

    final_theta = reconstruct_tf(result.x, summary)
    final_theta = {x: tf.cast(y, tf.float32) for x, y in final_theta.items()}

    # Build the results
    fit_result = MOGPResult(
        L_elts=final_theta["L_elts"],
        mu=final_theta["mu"],
        kernel=kernel,
        lengthscales=final_theta["lscales"] ** 2,
        intercept_means=final_theta["intercept_means"],
        intercept_vars=final_theta["intercept_vars"] ** 2,
        w_means=final_theta["w_means"],
        w_vars=final_theta["w_vars"] ** 2,
        Z=final_theta["Z"],
        w_prior_means=final_theta["w_prior_mean"],
        w_prior_vars=final_theta["w_prior_var"] ** 2,
        intercept_prior_mean=final_theta["intercept_prior_mean"],
        intercept_prior_var=final_theta["intercept_prior_var"] ** 2,
        total_kernel_variance=tf.constant(total_kernel_variance, tf.float32),
    )

    return fit_result
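A short, hedged example of calling the multi-output GP fit above; the binary response matrix and its shape are made up, and fit is assumed to be importable from the module that defines it.

import numpy as np

# Illustrative data: 500 sites, 8 already-scaled covariates, 12 binary outputs.
X = np.random.randn(500, 8)
y = np.random.binomial(1, 0.3, size=(500, 12)).astype(np.float64)

# test_run=True loosens the L-BFGS-B tolerance so the example finishes quickly.
result = fit(X, y, n_inducing=50, n_latent=4, test_run=True)

print(result.w_means.shape)  # expected: (n_latent, n_out) = (4, 12)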
Example 3
def fit(X: np.ndarray,
        y: np.ndarray,
        n_inducing: int = 100,
        n_latent: int = 10,
        kernel: str = 'matern_3/2',
        random_seed: int = 2):

    # TODO: This is copied from the mogp_classifier.
    # Maybe instead make it a function of some sort?
    np.random.seed(random_seed)

    # Note that input _must_ be scaled. Some way to enforce that?
    kernel_fun = kern_lookup[kernel]

    n_cov = X.shape[1]
    n_out = y.shape[1]

    # Set initial values
    start_lengthscales = np.random.uniform(2., 4., size=(n_latent, n_cov))

    Z = find_starting_z(X, n_inducing)
    Z = np.tile(Z, (n_latent, 1, 1))

    start_kernel_funs = get_kernel_funs(kernel_fun,
                                        np.sqrt(start_lengthscales))

    init_Ls = np.stack([
        get_initial_values_from_kernel(cur_z, cur_kernel_fun)
        for cur_z, cur_kernel_fun in zip(Z, start_kernel_funs)
    ])

    init_ms = np.zeros((n_latent, n_inducing))

    start_prior_cov = np.eye(n_latent)
    start_prior_mean = np.zeros(n_latent)
    start_prior_cov_elts = corr_mogp.get_initial_w_elements(
        start_prior_mean, start_prior_cov, n_out)

    start_w_cov_elts = rep_vector(start_prior_cov_elts, n_out)

    init_w_means = np.random.randn(n_out, n_latent)

    start_theta = {
        'mu': init_ms,
        'L_elts': init_Ls,
        'w_means': init_w_means,
        'w_cov_elts': start_w_cov_elts,
        'lengthscales': start_lengthscales,
        'w_prior_cov_elts': start_prior_cov_elts,
        'w_prior_mean': start_prior_mean,
        'Z': Z
    }

    flat_start_theta, summary = flatten_and_summarise_tf(**start_theta)

    X_tf = tf.constant(X.astype(np.float32))
    y_tf = tf.constant(y.astype(np.float32))

    def extract_cov_matrices(theta):

        w_covs = create_pos_def_mat_from_elts_batch(theta['w_cov_elts'],
                                                    n_latent,
                                                    n_out,
                                                    jitter=JITTER)

        Ls = mogp.create_ls(theta['L_elts'], n_inducing, n_latent)

        w_prior_cov = create_pos_def_mat_from_elts(theta['w_prior_cov_elts'],
                                                   n_latent,
                                                   jitter=JITTER)

        return w_covs, Ls, w_prior_cov

    def calculate_objective(theta):

        w_covs, Ls, w_prior_cov = extract_cov_matrices(theta)

        print(np.round(covar_to_corr(w_prior_cov.numpy()), 2))
        print(np.round(theta['lengthscales'].numpy()**2, 2))

        kernel_funs = get_kernel_funs(kernel_fun, theta['lengthscales']**2)

        cur_objective = corr_mogp.compute_default_objective(
            X_tf, y_tf, theta['Z'], theta['mu'], Ls, theta['w_means'], w_covs,
            kernel_funs, bernoulli_probit_lik, theta['w_prior_mean'],
            w_prior_cov)

        # Add prior
        lscale_prior = tfp.distributions.Gamma(3, 1 / 3).log_prob(
            theta['lengthscales']**2)

        return cur_objective + tf.reduce_sum(lscale_prior)

    def to_minimize(flat_theta):

        flat_theta = tf.constant(flat_theta)
        flat_theta = tf.cast(flat_theta, tf.float32)

        with tf.GradientTape() as tape:

            tape.watch(flat_theta)

            theta = reconstruct_tf(flat_theta, summary)

            objective = -calculate_objective(theta)

            grad = tape.gradient(objective, flat_theta)

        print(objective, np.linalg.norm(grad.numpy()))

        return (objective.numpy().astype(np.float64),
                grad.numpy().astype(np.float64))

    result = minimize(to_minimize,
                      flat_start_theta,
                      jac=True,
                      method='L-BFGS-B')

    final_theta = reconstruct_tf(result.x.astype(np.float32), summary)

    w_covs, Ls, w_prior_cov = extract_cov_matrices(final_theta)

    return CorrelatedMOGPResult(
        Ls=Ls,
        mu=final_theta['mu'].numpy(),
        kernel=kernel,
        lengthscales=final_theta['lengthscales'].numpy()**2,
        w_means=final_theta['w_means'].numpy(),
        w_cov=w_covs.numpy(),
        Z=final_theta['Z'].numpy(),
        w_prior_means=final_theta['w_prior_mean'].numpy(),
        w_prior_cov=w_prior_cov.numpy())
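A hedged usage sketch for the correlated MOGP variant above; the shapes are illustrative and fit is assumed to be in scope.

import numpy as np

# Illustrative inputs: 300 scaled observations, 6 covariates, 5 binary outputs.
X = np.random.randn(300, 6)
y = np.random.binomial(1, 0.5, size=(300, 5)).astype(np.float64)

res = fit(X, y, n_inducing=30, n_latent=3)

# The learned between-output prior covariance has shape (n_latent, n_latent).
print(res.w_prior_cov.shape)  # expected: (3, 3)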
Example 4
    return lik - total_kl


# Get the MOGP init values:
ms, lscales, alphas, kerns, w_means, w_vars, init_ls = get_mogp_initial_values(
    n_cov, n_latent, n_inducing)

# Get the latent init values
site_means, site_l_elts, b_mat = get_latent_initial_values(n_latent_site)

start_theta, summary = flatten_and_summarise_tf(**{
    'env_ms': ms,
    'env_l_elts': tf.stack(init_ls),
    'lscales': lscales,
    'w_means': w_means,
    'w_vars': w_vars,
    'site_means': site_means,
    'site_l_elts': site_l_elts,
    'b_mat': b_mat
})


def to_minimize(x):

    theta = reconstruct_tf(x, summary)

    # TODO: Check initial values are still consistent here
    kerns = [partial(matern_kernel_32, alpha=alpha, lengthscales=lscale,
                     jitter=JITTER) for
             alpha, lscale in zip(alphas, theta['lscales']**2)]
Example 5
def fit(
    X: np.ndarray,
    y: np.ndarray,
    n_inducing: int = 100,
    kernel: str = "matern_3/2",
    # Gamma priors (note tfp uses the "concentration, rate" parameterisation):
    kernel_variance_prior: Tuple[float, float] = (3 / 2, 3 / 2),
    kernel_lengthscale_prior: Tuple[float, float] = (3, 1 / 3),
    bias_variance_prior: Tuple[float, float] = (3 / 2, 3 / 2),
    random_seed: int = 2,
    verbose: bool = False,
) -> SOGPResult:

    np.random.seed(random_seed)

    assert kernel in [
        "matern_3/2",
        "matern_1/2",
        "rbf",
    ], "Only these three kernels are currently supported!"

    # Note that input _must_ be scaled. Some way to enforce that?

    kernel_fun = kern_lookup[kernel]

    n_cov = X.shape[1]

    # Set initial values
    start_alpha = np.array(1.0, dtype=np.float32)
    start_lengthscales = np.random.uniform(2.0, 4.0,
                                           size=n_cov).astype(np.float32)
    start_bias_sd = np.array(1.0, dtype=np.float32)

    Z = find_starting_z(X, n_inducing).astype(np.float32)

    start_kernel_fun = get_kernel_fun(kernel_fun, start_alpha,
                                      start_lengthscales, start_bias_sd)

    init_L = get_initial_values_from_kernel(Z, start_kernel_fun)
    init_mu = np.zeros(n_inducing, dtype=np.float32)

    init_theta = {
        "L_elts": init_L,
        "mu": init_mu,
        "alpha": start_alpha,
        "lscales": np.sqrt(start_lengthscales),
        "Z": Z,
        "bias_sd": start_bias_sd,
    }

    flat_theta, summary = flatten_and_summarise_tf(**init_theta)

    X = tf.constant(X.astype(np.float32))
    y = tf.constant(y.astype(np.float32))

    lscale_prior = tfp.distributions.Gamma(*kernel_lengthscale_prior)
    kernel_var_prior = tfp.distributions.Gamma(*kernel_variance_prior)
    bias_var_prior = tfp.distributions.Gamma(*bias_variance_prior)

    def to_minimize_with_grad(x):

        with tf.GradientTape() as tape:

            x_tf = tf.constant(x)
            x_tf = tf.cast(x_tf, tf.float32)

            tape.watch(x_tf)

            theta = reconstruct_tf(x_tf, summary)

            alpha, lscales, bias_sd = (
                theta["alpha"]**2,
                theta["lscales"]**2,
                theta["bias_sd"]**2,
            )

            L_cov = lo_tri_from_elements(theta["L_elts"], n_inducing)

            kern_fun = get_kernel_fun(kernel_fun, alpha, lscales, bias_sd)

            objective = -compute_objective(X, y, theta["mu"], L_cov,
                                           theta["Z"], bernoulli_probit_lik,
                                           kern_fun)

            objective = objective - (
                tf.reduce_sum(lscale_prior.log_prob(lscales))
                + kernel_var_prior.log_prob(alpha**2)
                + bias_var_prior.log_prob(bias_sd**2))

            grad = tape.gradient(objective, x_tf)

        if verbose:
            print(objective, np.linalg.norm(grad.numpy()))

        return (objective.numpy().astype(np.float64),
                grad.numpy().astype(np.float64))

    result = minimize(to_minimize_with_grad,
                      flat_theta,
                      jac=True,
                      method="L-BFGS-B")

    final_theta = reconstruct_tf(result.x, summary)
    final_theta = {
        x: y.numpy().astype(np.float32)
        for x, y in final_theta.items()
    }

    # Build the results
    fit_result = SOGPResult(
        L_elts=final_theta["L_elts"],
        mu=final_theta["mu"],
        kernel=kernel,
        lengthscales=final_theta["lscales"]**2,
        alpha=final_theta["alpha"]**2,
        bias_sd=final_theta["bias_sd"]**2,
        Z=final_theta["Z"],
    )

    return fit_result
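A minimal usage sketch for the single-output GP fit above, assuming fit is in scope; the data and shapes are illustrative only.

import numpy as np

# Illustrative single-output data: 400 scaled points, 4 covariates, binary labels.
X = np.random.randn(400, 4)
y = np.random.binomial(1, 0.5, size=400).astype(np.float64)

res = fit(X, y, n_inducing=50, kernel="rbf")

# Parameters are squared inside the objective, so the result stores variances
# and squared lengthscales directly.
print(res.lengthscales.shape)  # expected: (n_cov,) = (4,)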
Example 6
elts_prior_return = init_elts.copy()

mean_surface_skills_serve = tf.zeros((n_players, n_surfaces))
mean_surface_skills_return = tf.zeros((n_players, n_surfaces))

init_theta = {
        'elts_serve': elts_serve,
        'elts_return': elts_return,
        'elts_prior_serve': elts_prior_serve,
        'elts_prior_return': elts_prior_return,
        'mean_surface_skills_serve': mean_surface_skills_serve,
        'mean_surface_skills_return': mean_surface_skills_return,
        'intercept': tf.constant(0.5)
}

flat_theta, summary = flatten_and_summarise_tf(**init_theta)


def to_optimise(flat_theta):

    flat_theta = tf.cast(tf.constant(flat_theta), tf.float32)

    with tf.GradientTape() as tape:

        tape.watch(flat_theta)

        theta = reconstruct_tf(flat_theta, summary)

        obj = -compute_objective(n=n, p=p, n_surfaces=n_surfaces,
                                 server_ids=server_ids,
                                 returner_ids=returner_ids, surf_ids=surf_ids,
Example 7
def fit_minibatching(
    X: np.ndarray,
    z: np.ndarray,
    weights: np.ndarray,
    n_inducing: int,
    thinning_indices: Optional[np.ndarray] = np.array([]),
    fit_inducing_using_presences_only: bool = False,
    verbose: bool = True,
    log_theta_dir: Optional[str] = None,
    use_berman_turner: bool = False,
    batch_size: int = 1000,
    learning_rate: float = 0.01,
    n_steps: int = 1000,
    sqrt_decay_learning_rate: bool = True,
):

    global STEP
    STEP = 0

    if log_theta_dir is not None:
        makedirs(log_theta_dir, exist_ok=True)

    n_cov = X.shape[1]

    if fit_inducing_using_presences_only:
        X_to_cluster = X[z > 0, :]
    else:
        X_to_cluster = X

    init_Z = find_starting_z(X_to_cluster, n_inducing).astype(np.float32)

    start_theta, init_kernel_spec = initialise_theta(n_cov, thinning_indices,
                                                     init_Z)

    flat_theta, summary = flatten_and_summarise_tf(**start_theta)

    data_dict = {"X": X, "z": z, "weights": weights}

    data_dict = {x: y.astype(np.float32) for x, y in data_dict.items()}

    if sqrt_decay_learning_rate:
        # Decay with sqrt of time
        step_size_fun = lambda t: learning_rate * (1 / np.sqrt(t))  # NOQA
    else:
        # Constant learning rate
        step_size_fun = lambda t: learning_rate  # NOQA

    opt_fun = partial(
        to_optimise,
        use_berman_turner=use_berman_turner,
        summary=summary,
        init_kernel_spec=init_kernel_spec,
        log_theta_dir=log_theta_dir,
        verbose=verbose,
        likelihood_scale_factor=X.shape[0] / batch_size,
    )

    adam_state = initialise_state(flat_theta.shape[0])

    adam_fun = partial(adam_step, step_size_fun=step_size_fun)

    result, loss_log = optimise_minibatching(
        data_dict,
        opt_fun,
        adam_fun,
        flat_theta,
        batch_size,
        n_steps,
        X.shape[0],
        join(log_theta_dir, "loss.txt"),
        False,
        adam_state,
    )

    final_flat_theta = result.numpy().astype(np.float32)
    final_theta = reconstruct_tf(final_flat_theta, summary)
    _, final_spec = update_specs(final_theta, init_kernel_spec)

    return final_spec
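A hedged example of calling fit_minibatching above; the data, batch size and log directory are illustrative assumptions.

import numpy as np

# Illustrative presence/quadrature data: 10000 rows, 5 covariates.
X = np.random.randn(10_000, 5).astype(np.float32)
z = np.random.binomial(1, 0.05, size=10_000).astype(np.float32)
weights = np.ones(10_000, dtype=np.float32)

# log_theta_dir should point at a writable directory, since intermediate
# parameters and the loss log are written there during optimisation.
final_spec = fit_minibatching(X, z, weights,
                              n_inducing=100,
                              log_theta_dir="./theta_logs",
                              batch_size=2000,
                              n_steps=500)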
Example 8
def fit(
    X: np.ndarray,
    z: np.ndarray,
    weights: np.ndarray,
    n_inducing: int,
    thinning_indices: Optional[np.ndarray] = np.array([]),
    fit_inducing_using_presences_only: bool = False,
    verbose: bool = True,
    log_theta_dir: Optional[str] = None,
    use_berman_turner: bool = True,
    test_run: bool = False,
):

    global STEP
    STEP = 0

    n_cov = X.shape[1]

    if fit_inducing_using_presences_only:
        X_to_cluster = X[z > 0, :]
    else:
        X_to_cluster = X

    init_Z = find_starting_z(X_to_cluster, n_inducing).astype(np.float32)

    start_theta, init_kernel_spec = initialise_theta(n_cov, thinning_indices,
                                                     init_Z)

    # Prepare the tensors
    X = tf.cast(tf.constant(X), tf.float32)
    z = tf.cast(tf.constant(z), tf.float32)
    weights = tf.cast(tf.constant(weights), tf.float32)

    flat_theta, summary = flatten_and_summarise_tf(**start_theta)

    opt_fun = partial(
        to_optimise,
        X=X,
        z=z,
        weights=weights,
        use_berman_turner=use_berman_turner,
        summary=summary,
        init_kernel_spec=init_kernel_spec,
        log_theta_dir=log_theta_dir,
        verbose=verbose,
        likelihood_scale_factor=1.0,
    )

    additional_args = {"tol": 1} if test_run else {}

    result = minimize(opt_fun,
                      flat_theta.numpy().astype(np.float64),
                      method="L-BFGS-B",
                      jac=True,
                      **additional_args)

    final_flat_theta = result.x.astype(np.float32)
    final_theta = reconstruct_tf(final_flat_theta, summary)
    _, final_spec = update_specs(final_theta, init_kernel_spec)

    return final_spec
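A short usage sketch for the full-batch variant above, again with made-up data and assuming fit is in scope.

import numpy as np

# Illustrative presence/quadrature data: 3000 rows, 5 covariates.
X = np.random.randn(3_000, 5).astype(np.float32)
z = np.random.binomial(1, 0.05, size=3_000).astype(np.float32)
weights = np.ones(3_000, dtype=np.float32)

# test_run=True relaxes the L-BFGS-B tolerance so the optimisation stops early,
# which is useful for smoke tests.
final_spec = fit(X, z, weights, n_inducing=50, test_run=True)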