def __init__(self, state, bparam, state_0, bparam_0, counter, objective,
                 accuracy_fn, hparams):
        """Build the solver from the current point, the previous point and hparams.

        Args:
          state: current problem parameters; wrapped in ``StateVariable``
            together with ``counter``.
          bparam: current continuation parameter; wrapped likewise.
          state_0: previous problem parameters, stored unwrapped.
          bparam_0: previous continuation parameter, stored unwrapped.
          counter: passed to every ``StateVariable`` created here.
          objective: loss function. It is jitted as ``value_func``, and the
            jitted gradient w.r.t. its first argument is ``grad_fn``.
          accuracy_fn: accuracy metric, stored as given (not jitted).
          hparams: configuration dict. Keys read here are "meta"
            ("optimizer", "dataset", "output_dir"), "natural_lr",
            "continuation_config", "resize_to_small", "filter",
            "continuation_steps", "delta_s" and "omega".
        """
        # Current and previous points of the continuation.
        self._state_wrap, self._bparam_wrap = (StateVariable(state, counter),
                                               StateVariable(bparam, counter))
        self._prev_state, self._prev_bparam = state_0, bparam_0

        # Objective / metric callables.
        self.objective = objective
        self.accuracy_fn = accuracy_fn
        self.value_func = jit(self.objective)

        # Placeholder value/quality until the first step computes real ones.
        self._value_wrap = StateVariable(0.005, counter)
        self._quality_wrap = StateVariable(0.005, counter)
        self.sw = None
        self.hparams = hparams

        meta = hparams["meta"]
        self.opt = OptimizerCreator(
            opt_string=meta["optimizer"],
            learning_rate=hparams["natural_lr"]).get_optimizer()

        if meta["dataset"] == "mnist":
            # "data" continuation uses the mnist_gamma variant of the loader.
            loader = (mnist_gamma
                      if hparams["continuation_config"] == 'data' else mnist)
            self.dataset_tuple = loader(resize=hparams["resize_to_small"],
                                        filter=hparams["filter"])
        self.continuation_steps = hparams["continuation_steps"]

        self.output_file = meta["output_dir"]
        # Current and previous step size both start at the configured delta_s.
        step_size = hparams["delta_s"]
        self._delta_s = step_size
        self._prev_delta_s = step_size
        self._omega = hparams["omega"]
        self.grad_fn = jit(grad(self.objective, argnums=[0]))
        self.prev_secant_direction = None
    def __init__(self, state, bparam, counter, objective, accuracy_fn,
                 hparams):
        """Build the solver from the current point and hparams.

        Args:
          state: current problem parameters; wrapped in ``StateVariable``
            together with ``counter``.
          bparam: current continuation parameter; wrapped likewise.
          counter: passed to every ``StateVariable`` created here.
          objective: loss function. It is jitted as ``value_func``, and the
            jitted gradient w.r.t. its first argument is ``grad_fn``.
          accuracy_fn: accuracy metric, stored jitted.
          hparams: configuration dict. Keys read here are "meta"
            ("dataset", "output_dir"), "continuation_config",
            "resize_to_small", "filter", "continuation_steps" and
            "delta_bparams".
        """
        self._state_wrap, self._bparam_wrap = (StateVariable(state, counter),
                                               StateVariable(bparam, counter))
        self.objective = objective
        self.value_func = jit(self.objective)
        self.accuracy_fn = jit(accuracy_fn)

        # Placeholder value/quality until the first step computes real ones.
        self._value_wrap = StateVariable(2.0, counter)
        self._quality_wrap = StateVariable(0.25, counter)
        self.sw = None
        self.hparams = hparams

        meta = hparams["meta"]
        if meta["dataset"] == "mnist":
            data_continuation = hparams["continuation_config"] == 'data'
            if not data_continuation:
                print("model continuation")
            loader = mnist_gamma if data_continuation else mnist
            self.dataset_tuple = loader(resize=hparams["resize_to_small"],
                                        filter=hparams["filter"])
        self.continuation_steps = hparams["continuation_steps"]

        self.output_file = meta["output_dir"]
        self._delta_s = hparams["delta_bparams"]
        # TODO: vmap is not fully supported with stax
        self.grad_fn = jit(grad(self.objective, argnums=[0]))
Example #3
0
import jax.numpy as np
from jax import random
from jax.experimental.optimizers import l2_norm
from jax.tree_util import *
from flax import linen as nn  # The Linen API
import jax
import numpy.random as npr
from cjax.utils import datasets

#
# num_classes = 10
# inputs = random.normal(random.PRNGKey(1), (1, 10, 10))
# outputs = np.ones(shape=(num_classes, 10))

# Mini-batch size consumed by ``data_stream``.
batch_size = 20

# The full MNIST split, loaded once when the module is imported.
(train_images, train_labels,
 test_images, test_labels) = datasets.mnist()

# Batches per pass over the training set; a trailing partial batch counts.
num_train = train_images.shape[0]
num_complete_batches = num_train // batch_size
leftover = num_train % batch_size
num_batches = num_complete_batches + (1 if leftover else 0)


def data_stream(images=None, labels=None, batch_size=None, seed=0):
    """Yield shuffled ``(images, labels)`` mini-batches forever.

    A ``numpy.random.RandomState`` seeded once with *seed* draws a fresh
    permutation of the example indices at the start of every pass. The pass
    is then served as consecutive slices of that permutation. The last batch
    of a pass is smaller when the number of examples is not a multiple of
    *batch_size*.

    Called with no arguments, it reproduces the original stream: the
    module-level ``train_images``, ``train_labels`` and ``batch_size``, with
    seed 0.

    Args:
      images: array indexed by example along axis 0 (default
        ``train_images``).
      labels: array aligned with *images* along axis 0 (default
        ``train_labels``).
      batch_size: examples per batch (default: the module-level
        ``batch_size``).
      seed: seed of the shuffling RNG.

    Yields:
      ``(images[idx], labels[idx])`` for each batch of indices ``idx``.

    Raises:
      ValueError: on the first ``next()`` if there are no examples, the
        image and label counts differ, or *batch_size* is not positive.
        Without the empty-dataset check, the loop would spin forever without
        ever yielding.
    """
    # Defaults are resolved here, i.e. on the first next(), so the module
    # globals are read lazily as before. ``globals()`` is needed for
    # batch_size because the parameter shadows the module-level name.
    if images is None:
        images = train_images
    if labels is None:
        labels = train_labels
    if batch_size is None:
        batch_size = globals()["batch_size"]

    num_examples = images.shape[0]
    if num_examples == 0:
        raise ValueError("data_stream needs at least one example")
    if labels.shape[0] != num_examples:
        raise ValueError(
            f"images and labels differ in length: {num_examples} vs "
            f"{labels.shape[0]}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    rng = npr.RandomState(seed)
    while True:
        perm = rng.permutation(num_examples)
        # ceil(num_examples / batch_size) batches per pass, the same count as
        # num_complete_batches + bool(leftover).
        for start in range(0, num_examples, batch_size):
            batch_idx = perm[start:start + batch_size]
            yield images[batch_idx], labels[batch_idx]


# Module-level iterator over (train_images, train_labels) mini-batches.
# data_stream loops with ``while True``, so this never raises StopIteration.
batches = data_stream()
Example #4
0
                {
                    "train_loss": float(loss),
                    "ma_loss": float(ma_loss[-1]),
                    "learning_rate": float(opt.lr),
                    "norm grads": float(l2_norm(ae_grads))
                }, epoch)

            if len(ma_loss) > 100:
                loss_check = running_mean(ma_loss, 50)
                if math.isclose(loss_check[-1],
                                loss_check[-2],
                                abs_tol=hparams["loss_tol"]):
                    print(f"stopping at {epoch}")
                    break

        train_images, train_labels, test_images, test_labels = mnist(
            permute_train=False, resize=True, filter=hparams["filter"])

        val_loss = problem.objective(ae_params, (test_images, test_labels))
        print(f"val loss: {val_loss, type(ae_params)}")
        val_acc = accuracy(ae_params, (test_images, test_labels))
        print(f"val acc: {val_acc}")
        mlflow.log_metric("val_acc", float(val_acc))
        mlflow.log_metric("val_loss", float(val_loss))

        q = float(l2_norm(ae_grads[0]))
        if sw:
            sw.write([
                {
                    'u': ae_params
                },
                {
Example #5
0
    with open(problem.HPARAMS_PATH, "r") as hfile:
        hparams = json.load(hfile)
    opt = AdamOptimizer(learning_rate=hparams["descent_lr"])
    ma_loss = []
    for epoch in range(500):
        for b_j in range(num_batches):
            batch = next(data_loader)
            grads = compute_grad_fn(ae_params, bparam, batch)
            ae_params = opt.update_params(ae_params,
                                          grads[0],
                                          step_index=epoch)
            loss = problem.objective(ae_params, bparam, batch)
            ma_loss.append(loss)
            print(f"loss:{loss}  norm:{l2_norm(grads)}")
        opt.lr = exp_decay(epoch, hparams["descent_lr"])
        if len(ma_loss) > 40:
            loss_check = running_mean(ma_loss, 30)
            if math.isclose(loss_check[-1],
                            loss_check[-2],
                            abs_tol=hparams["loss_tol"]):
                print(f"stopping at {epoch}")
                break

    train_images, train_labels, test_images, test_labels = mnist(
        permute_train=False, resize=True)
    val_loss = problem.objective(ae_params, bparam, (test_images, test_labels))
    print(f"val loss: {val_loss}")

    # init_c = constant_2d(I)
    # print(init_c(key=0, shape=(8,8)))
Example #6
0
                    "ma_loss": float(ma_loss[-1]),
                    "learning_rate": float(opt.lr),
                    "bparam": float(bparam[0]),
                    "norm grads": float(l2_norm(ae_grads))
                }, epoch)

            if len(ma_loss) > 100:
                loss_check = running_mean(ma_loss, 50)
                if math.isclose(loss_check[-1],
                                loss_check[-2],
                                abs_tol=hparams["loss_tol"]):
                    print(f"stopping at {epoch}")
                    break

        train_images, train_labels, test_images, test_labels = mnist(
            permute_train=False,
            resize=hparams["resize_to_small"],
            filter=hparams["filter"])

        val_loss = problem.objective(ae_params, bparam,
                                     (test_images, test_labels))
        print(f"val loss: {val_loss, type(ae_params)}")
        val_acc = accuracy(ae_params, bparam, (test_images, test_labels))
        print(f"val acc: {val_acc}")
        mlflow.log_metric("val_acc", float(val_acc))
        mlflow.log_metric("val_loss", float(val_loss))

        q = float(l2_norm(ae_grads[0]))
        if sw:
            sw.write([
                {
                    'u': ae_params
    def correction_step(self) -> Tuple:
        """Given the current state, optimize perturbed copies to the correct state.

        For each of ``hparams["n_wall_ants"]`` "ants", the current state is
        perturbed with ``self._perform_perturb_by_projection``. The ant is
        then updated with ``self.opt``, using the gradients returned by
        ``self.compute_min_grad_fn``, for up to ``self.descent_period``
        epochs of ``self.num_batches`` steps each. An ant may stop early at
        an epoch boundary when ``hparams["adaptive"]`` is set and its local
        test fires. ``get_cheapest_ant`` then picks the ant that becomes the
        new ``self._state`` / ``self._bparam``.

        Returns:
          Tuple ``(state, bparam, quality, value, val_loss, corrector_omega)``:
          - state, bparam: the selected problem parameters and continuation
            parameter.
          - quality: the last ``l2_norm(state_grads)`` computed; it stays
            1.0 unless ``hparams["adaptive"]``.
          - value: ``self.value_fn`` on the last training batch drawn.
          - val_loss: ``self.value_fn`` on the MNIST test split.
          - corrector_omega: the step factor from the final update step.
        """
        quality = 1.0
        # Defined up front so the return is valid even if no ant/step runs.
        corrector_omega = 1.0
        n_ants = self.hparams["n_wall_ants"]
        is_mnist = self.hparams["meta"]["dataset"] == "mnist"  # TODO: make it generic

        batch_data = next(self.data_loader) if is_mnist else None

        # Per-ant bookkeeping; 5.0 is a placeholder until each ant is updated.
        ants_norm_grads = [5.0 for _ in range(n_ants)]
        ants_loss_values = [5.0 for _ in range(n_ants)]
        ants_state = [self._state for _ in range(n_ants)]
        ants_bparam = [self._bparam for _ in range(n_ants)]

        for i_n in range(n_ants):
            corrector_omega = 1.0
            stop = False
            # Per-ant PRNG key; the npr.randint term adds jitter that is not
            # controlled by key_state.
            _, key = random.split(
                random.PRNGKey(self.key_state + i_n +
                               npr.randint(1, (i_n + 1) * 10)))
            del _
            self._parc_vec, self.state_stack = self._perform_perturb_by_projection(
                self._state_secant_vector,
                self._state_secant_c2,
                key,
                self.pred_prev_state,
                self._state,
                self._bparam,
                i_n,
                self.sphere_radius,
                batch_data,
            )
            if self.hparams["_evaluate_perturb"]:
                self._evaluate_perturb()  # does every time

            ants_state[i_n] = self.state_stack["state"]
            ants_bparam[i_n] = self.state_stack["bparam"]
            D_values = []
            print("num_batches", self.num_batches)
            for j_epoch in range(self.descent_period):
                for _ in range(self.num_batches):
                    state_grads, bparam_grads = self.compute_min_grad_fn(
                        ants_state[i_n],
                        ants_bparam[i_n],
                        self._lagrange_multiplier,
                        self._state_secant_c2,
                        self._state_secant_vector,
                        batch_data,
                        self.delta_s,
                    )

                    if self.hparams["adaptive"]:
                        self.opt.lr = self.exp_decay(
                            j_epoch, self.hparams["natural_lr"])
                        quality = l2_norm(state_grads)
                        if self.hparams[
                                "local_test_measure"] == "norm_gradients":
                            if quality > self.hparams["quality_thresh"]:
                                print(
                                    f"quality {quality}, {self.opt.lr}, {bparam_grads} ,{j_epoch}"
                                )
                            else:
                                stop = True
                                print(
                                    f"quality {quality} stopping at , {j_epoch}th step"
                                )
                        else:
                            print(
                                f"quality {quality}, {bparam_grads} ,{j_epoch}"
                            )
                            # Stop once the running mean of the loss plateaus.
                            if len(D_values) >= 20:
                                tmp_means = running_mean(D_values, 10)
                                if math.isclose(
                                        tmp_means[-1],
                                        tmp_means[-2],
                                        abs_tol=self.hparams["loss_tol"]):
                                    print(
                                        f"stopping at , {j_epoch}th step, {ants_bparam[i_n]} bparam"
                                    )
                                    stop = True

                        state_grads = clip_grads(state_grads,
                                                 self.hparams["max_clip_grad"])
                        bparam_grads = clip_grads(
                            bparam_grads, self.hparams["max_clip_grad"])

                    # omega = guess_ant_steps / (j_epoch + 1): capped at 1.5
                    # while within guess_ant_steps epochs (to get around folds
                    # slowly), floored at 0.05 afterwards.
                    if self.hparams["guess_ant_steps"] >= (j_epoch + 1):
                        corrector_omega = min(
                            self.hparams["guess_ant_steps"] / (j_epoch + 1),
                            1.5)
                    else:
                        corrector_omega = max(
                            self.hparams["guess_ant_steps"] / (j_epoch + 1),
                            0.05)

                    ants_state[i_n] = self.opt.update_params(
                        ants_state[i_n], state_grads, j_epoch)
                    ants_bparam[i_n] = self.opt.update_params(
                        ants_bparam[i_n], bparam_grads, j_epoch)
                    ants_loss_values[i_n] = self.value_fn(
                        ants_state[i_n], ants_bparam[i_n], batch_data)
                    D_values.append(ants_loss_values[i_n])
                    ants_norm_grads[i_n] = quality
                    if is_mnist:
                        batch_data = next(self.data_loader)
                # `stop` is honoured only at epoch boundaries, so the
                # current epoch's batches are always finished first.
                if stop:
                    break

        cheapest_index = get_cheapest_ant(
            ants_norm_grads,
            ants_loss_values,
            local_test=self.hparams["local_test_measure"])
        self._state = ants_state[cheapest_index]
        self._bparam = ants_bparam[cheapest_index]
        value = self.value_fn(self._state, self._bparam,
                              batch_data)  # Todo: why only final batch data

        # Validation data must use the same resize setting as the training
        # data. Every training-data load in this file passes
        # hparams["resize_to_small"]; a hard-coded resize=True mismatched
        # shapes whenever that flag is False.
        _, _, test_images, test_labels = mnist(
            permute_train=False,
            resize=self.hparams["resize_to_small"],
            filter=self.hparams["filter"])
        val_loss = self.value_fn(self._state, self._bparam,
                                 (test_images, test_labels))
        print(f"val loss: {val_loss}")

        return self._state, self._bparam, quality, value, val_loss, corrector_omega
Example #8
0
    def __init__(
        self,
        state,
        bparam,
        state_0,
        bparam_0,
        counter,
        objective,
        dual_objective,
        accuracy_fn,
        hparams,
        key_state,
    ):
        """Initialise solver state, objectives, gradient functions and helpers.

        Args:
          state: current problem parameters; wrapped in ``StateVariable``
            together with ``counter``.
          bparam: current continuation parameter; wrapped likewise.
          state_0: previous problem parameters, stored unwrapped.
          bparam_0: previous continuation parameter, stored unwrapped.
          counter: passed to every ``StateVariable`` created here.
          objective: loss function. It is jitted as ``value_func``, and the
            jitted gradient w.r.t. its first argument is ``compute_grad_fn``.
          dual_objective: the jitted gradient w.r.t. its first two arguments
            is ``compute_min_grad_fn``.
          accuracy_fn: stored jitted as ``accuracy_fn1``.
          hparams: configuration dict. Keys read here are "meta"
            ("dataset", "output_dir"), "continuation_config",
            "resize_to_small", "filter", "continuation_steps", "delta_s" and
            "omega". "lambda_max" and "lambda_min" are read later, when the
            clip lambdas are called.
          key_state: stored as ``perturb_index`` and used in the StateWriter
            file name. ``self.key_state`` is this value plus
            ``npr.randint(100, 200)``.

        NOTE(review): when hparams["meta"]["dataset"] != "mnist",
        ``dataset_tuple`` is not assigned in the lines shown here.
        """

        # states
        self._state_wrap = StateVariable(state, counter)
        self._bparam_wrap = StateVariable(
            bparam, counter
        )  # Todo : save tree def, always unflatten before compute_grads
        self._prev_state = state_0
        self._prev_bparam = bparam_0

        # objectives
        self.objective = objective
        self.dual_objective = dual_objective
        self.accuracy_fn1 = jit(accuracy_fn)
        self.value_func = jit(self.objective)

        self.hparams = hparams
        # "data" continuation loads via mnist_gamma; otherwise plain mnist.
        if hparams["meta"]["dataset"] == "mnist":
            if hparams["continuation_config"] == 'data':
                self.dataset_tuple = mnist_gamma(
                    resize=hparams["resize_to_small"],
                    filter=hparams["filter"])
            else:
                self.dataset_tuple = mnist(resize=hparams["resize_to_small"],
                                           filter=hparams["filter"])

        self._value_wrap = StateVariable(
            0.06, counter)  # TODO: fix with a static batch (test/train)
        # Initial quality: l2_norm of the wrapped state divided by 10
        # (presumably `.state` returns the wrapped state — TODO confirm).
        self._quality_wrap = StateVariable(
            l2_norm(self._state_wrap.state) / 10, counter)

        # every step hparams
        self.continuation_steps = hparams["continuation_steps"]

        # Current and previous step size both start at hparams["delta_s"].
        self._delta_s = hparams["delta_s"]
        self._prev_delta_s = hparams["delta_s"]
        self._omega = hparams["omega"]

        # grad functions # should be pure functional
        # compute_min_grad_fn: gradient of dual_objective w.r.t. args 0 and 1.
        self.compute_min_grad_fn = jit(grad(self.dual_objective, [0, 1]))
        # compute_grad_fn: gradient of objective w.r.t. arg 0 only.
        self.compute_grad_fn = jit(grad(self.objective, [0]))

        # extras
        self.state_tree_def = None
        self.bparam_tree_def = None
        self.output_file = hparams["meta"]["output_dir"]
        self.prev_secant_direction = None
        self.perturb_index = key_state
        self.sw = StateWriter(
            f"{self.output_file}/version_{self.perturb_index}.json")
        # Randomly offset seed; if npr is numpy.random the offset is in
        # [100, 200).
        self.key_state = key_state + npr.randint(100, 200)
        # Element-wise clamps to lambda_max / lambda_min. The bounds are read
        # from self.hparams each time the lambda is called, not captured now.
        self.clip_lambda_max = lambda g: np.where(
            (g > self.hparams["lambda_max"]), self.hparams["lambda_max"], g)
        self.clip_lambda_min = lambda g: np.where(
            (g < self.hparams["lambda_min"]), self.hparams["lambda_min"], g)