def __init__(self, objective, dual_objective, accuracy_fn1, value_fn,
                 concat_states, key_state, compute_min_grad_fn,
                 compute_grad_fn, hparams, delta_s, pred_state,
                 pred_prev_state, counter, dataset_tuple):
        """Set up a constrained corrector and, for MNIST, its batch iterator.

        Args:
          objective: Primal objective; stored on the instance.
          dual_objective: Constrained objective; stored on the instance.
          accuracy_fn1: Accuracy callable; stored as ``self.accuracy_fn1``.
          value_fn: Loss evaluator; stored on the instance.
          concat_states: Sequence consumed by ``self._assign_states()``.
          key_state: Seed material kept as ``self.key_state``.
          compute_min_grad_fn: Gradient callable for the constrained problem.
          compute_grad_fn: Gradient callable for the unconstrained problem.
          hparams: Nested hyper-parameter dict; ``hparams["meta"]`` carries the
            optimizer name and the dataset name.
          delta_s: Step size stored as ``self.delta_s``.
          pred_state: Predictor output stored on the instance.
          pred_prev_state: Previous predictor output stored on the instance.
          counter: Step counter stored on the instance.
          dataset_tuple: Stored as-is; for MNIST it is also unpacked into
            ``(train_images, train_labels, test_images, test_labels)``.
        """
        meta = hparams["meta"]
        self.concat_states = concat_states
        # Placeholders, expected to be overwritten by self._assign_states().
        self._state = self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=meta["optimizer"],
            learning_rate=hparams["descent_lr"],
        ).get_optimizer()
        self.objective, self.dual_objective = objective, dual_objective
        self._lagrange_multiplier = hparams["lagrange_init"]
        self._state_secant_vector = self._state_secant_c2 = None
        self.delta_s = delta_s
        self.descent_period = hparams["descent_period"]
        self.max_norm_state = hparams["max_bounds"]
        self.hparams = hparams
        self.compute_min_grad_fn = compute_min_grad_fn
        self.compute_grad_fn = compute_grad_fn
        self._assign_states()

        self._parc_vec = None
        self.state_stack = {}
        self.key_state = key_state
        self.pred_state, self.pred_prev_state = pred_state, pred_prev_state
        self.sphere_radius = hparams["sphere_radius"]
        self.counter = counter
        self.value_fn = value_fn
        self.accuracy_fn1 = accuracy_fn1
        self.dataset_tuple = dataset_tuple

        if meta["dataset"] != "mnist":
            # No data loader: a single step per descent period.
            self.data_loader = None
            self.num_batches = 1
            return

        (self.train_images, self.train_labels,
         self.test_images, self.test_labels) = dataset_tuple
        is_data_continuation = hparams["continuation_config"] == 'data'
        loader_kwargs = {
            "batch_size": hparams["batch_size"],
            "resize": hparams["resize_to_small"],
            "filter": hparams["filter"],
        }
        if is_data_continuation:
            # Data continuation: the loader receives self._bparam as `alter`.
            batches = get_mnist_batch_alter(self.train_images,
                                            self.train_labels,
                                            self.test_images,
                                            self.test_labels,
                                            alter=self._bparam,
                                            **loader_kwargs)
        else:
            # Model continuation: plain MNIST batches.
            batches = get_mnist_data(**loader_kwargs)
        self.data_loader = iter(batches)
        self.num_batches = meta_mnist(hparams["batch_size"],
                                      hparams["filter"])["num_batches"]
示例#2
0
    def __init__(
        self,
        state,
        bparam,
        state_0,
        bparam_0,
        counter,
        objective,
        dual_objective,
        hparams,
    ):
        """Initialise the continuation driver.

        Args:
          state: Current problem parameters, wrapped in a ``StateVariable``.
          bparam: Current continuation parameter, wrapped in a ``StateVariable``.
          state_0: Previous state, stored unwrapped.
          bparam_0: Previous continuation parameter, stored unwrapped.
          counter: Step counter handed to every ``StateVariable``.
          objective: Primal objective; jitted for evaluation and differentiated
            w.r.t. argument 0 for ``compute_grad_fn``.
          dual_objective: Constrained objective; differentiated w.r.t.
            arguments 0 and 1 (``compute_min_grad_fn``) and argument 2
            (``compute_max_grad_fn``).
          hparams: Nested hyper-parameter dict.
        """
        def wrap(value):
            return StateVariable(value, counter)

        # Current / previous points.
        # TODO: save tree def, always unflatten before compute_grads.
        self._state_wrap = wrap(state)
        self._bparam_wrap = wrap(bparam)
        self._prev_state, self._prev_bparam = state_0, bparam_0

        # Objectives.
        self.objective = objective
        self.dual_objective = dual_objective
        self.value_func = jit(objective)

        self.hparams = hparams

        # TODO: fix with a static batch (test/train).
        self._value_wrap = wrap(1.0)
        self._quality_wrap = wrap(l2_norm(self._state_wrap.state))

        # Optimizers: descent uses natural_lr, ascent uses ascent_lr.
        meta = hparams["meta"]
        self.opt = OptimizerCreator(
            opt_string=meta["optimizer"],
            learning_rate=hparams["natural_lr"],
        ).get_optimizer()
        self.ascent_opt = OptimizerCreator(
            opt_string=meta["ascent_optimizer"],
            learning_rate=hparams["ascent_lr"],
        ).get_optimizer()

        # Per-step hyper-parameters.
        self.continuation_steps = hparams["continuation_steps"]
        self._lagrange_multiplier = hparams["lagrange_init"]
        self._delta_s = hparams["delta_s"]
        self._omega = hparams["omega"]

        # Gradient functions (should stay pure so they can be jitted).
        self.compute_min_grad_fn = jit(grad(dual_objective, [0, 1]))
        self.compute_max_grad_fn = jit(grad(dual_objective, [2]))
        self.compute_grad_fn = jit(grad(objective, [0]))

        # Placeholders and output location.
        self.sw = None
        self.state_tree_def = None
        self.bparam_tree_def = None
        self.output_file = meta["output_dir"]
        self.prev_secant_direction = None
示例#3
0
 def __init__(
     self,
     objective,
     dual_objective,
     value_fn,
     concat_states,
     key_state,
     compute_min_grad_fn,
     compute_grad_fn,
     hparams,
     pred_state,
     pred_prev_state,
     counter,
 ):
     """Set up a constrained corrector whose step size is ``hparams["delta_s"]``.

     Args:
       objective: Primal objective; stored on the instance.
       dual_objective: Constrained objective; stored on the instance.
       value_fn: Loss evaluator; stored on the instance.
       concat_states: Sequence consumed by ``self._assign_states()``.
       key_state: Seed material kept as ``self.key_state``.
       compute_min_grad_fn: Gradient callable for the constrained problem.
       compute_grad_fn: Gradient callable for the unconstrained problem.
       hparams: Nested hyper-parameter dict.
       pred_state: Predictor output stored on the instance.
       pred_prev_state: Previous predictor output stored on the instance.
       counter: Step counter stored on the instance.
     """
     meta = hparams["meta"]
     self.concat_states = concat_states
     # Placeholders, expected to be overwritten by self._assign_states().
     self._state = self._bparam = None
     self.opt = OptimizerCreator(
         opt_string=meta["optimizer"], learning_rate=hparams["descent_lr"]
     ).get_optimizer()
     self.objective = objective
     self.dual_objective = dual_objective
     self._lagrange_multiplier = hparams["lagrange_init"]
     self._state_secant_vector = self._state_secant_c2 = None
     self.delta_s = hparams["delta_s"]
     self.descent_period = hparams["descent_period"]
     self.max_norm_state = hparams["max_bounds"]
     self.hparams = hparams
     self.compute_min_grad_fn = compute_min_grad_fn
     self.compute_grad_fn = compute_grad_fn
     self._assign_states()

     self._parc_vec = None
     self.state_stack = {}
     self.key_state = key_state
     self.pred_state = pred_state
     self.pred_prev_state = pred_prev_state
     self.sphere_radius = hparams["sphere_radius"]
     self.counter = counter
     self.value_fn = value_fn

     if meta["dataset"] != "mnist":
         # No data loader: a single step per descent period.
         self.data_loader = None
         self.num_batches = 1
         return
     batch_size = hparams["batch_size"]
     self.data_loader = iter(
         get_mnist_data(batch_size=batch_size, resize=hparams["resize_to_small"])
     )
     self.num_batches = meta_mnist(batch_size)["num_batches"]
    def __init__(self, objective, concat_states, grad_fn, value_fn,
                 accuracy_fn, hparams, dataset_tuple):
        """Set up a corrector driven by ``grad_fn`` (no constraint inputs).

        Args:
          objective: Objective; stored on the instance.
          concat_states: Sequence consumed by ``self._assign_states()``.
          grad_fn: Gradient callable; stored on the instance.
          value_fn: Loss evaluator; stored on the instance.
          accuracy_fn: Accuracy callable; stored on the instance.
          hparams: Nested hyper-parameter dict.
          dataset_tuple: For MNIST, unpacked into
            ``(train_images, train_labels, test_images, test_labels)``.
        """
        meta = hparams["meta"]
        self.concat_states = concat_states
        # Placeholders, expected to be overwritten by self._assign_states().
        self._state = self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=meta["optimizer"],
            learning_rate=hparams["natural_lr"],
        ).get_optimizer()
        self.objective = objective
        self.accuracy_fn = accuracy_fn
        self.warmup_period = hparams["warmup_period"]
        self.hparams = hparams
        self.grad_fn, self.value_fn = grad_fn, value_fn
        self._assign_states()

        if meta["dataset"] != "mnist":
            # No data loader: a single step per period.
            self.data_loader = None
            self.num_batches = 1
            return

        (self.train_images, self.train_labels,
         self.test_images, self.test_labels) = dataset_tuple
        is_data_continuation = hparams["continuation_config"] == 'data'
        loader_kwargs = {
            "batch_size": hparams["batch_size"],
            "resize": hparams["resize_to_small"],
            "filter": hparams["filter"],
        }
        if is_data_continuation:
            # Data continuation: the loader receives self._bparam as `alter`.
            batches = get_mnist_batch_alter(self.train_images,
                                            self.train_labels,
                                            self.test_images,
                                            self.test_labels,
                                            alter=self._bparam,
                                            **loader_kwargs)
        else:
            # Model continuation.
            # TODO: better way to prefetch mnist (e.g. a loader over the
            # preloaded train/test arrays unpacked above).
            batches = get_mnist_data(**loader_kwargs)
        self.data_loader = iter(batches)
        self.num_batches = meta_mnist(hparams["batch_size"],
                                      hparams["filter"])["num_batches"]
    def __init__(self, state, bparam, state_0, bparam_0, counter, objective,
                 accuracy_fn, hparams):
        """Initialise the continuation driver.

        Args:
          state: Current problem parameters, wrapped in a ``StateVariable``.
          bparam: Current continuation parameter, wrapped in a ``StateVariable``.
          state_0: Previous state, stored unwrapped.
          bparam_0: Previous continuation parameter, stored unwrapped.
          counter: Step counter handed to every ``StateVariable``.
          objective: Objective; jitted for evaluation and differentiated
            w.r.t. argument 0 for ``grad_fn``.
          accuracy_fn: Accuracy callable; stored on the instance.
          hparams: Nested hyper-parameter dict.
        """
        self._state_wrap = StateVariable(state, counter)
        self._bparam_wrap = StateVariable(bparam, counter)
        self._prev_state = state_0
        self._prev_bparam = bparam_0
        self.objective = objective
        self.accuracy_fn = accuracy_fn
        self.value_func = jit(self.objective)
        # Initial value/quality are fixed constants.
        self._value_wrap = StateVariable(0.005, counter)
        self._quality_wrap = StateVariable(0.005, counter)
        self.sw = None
        self.hparams = hparams
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"]).get_optimizer()
        # Only MNIST runs load a dataset. Always define the attribute so that
        # non-MNIST runs read None instead of raising AttributeError.
        self.dataset_tuple = None
        if hparams["meta"]["dataset"] == "mnist":
            if hparams["continuation_config"] == 'data':
                self.dataset_tuple = mnist_gamma(
                    resize=hparams["resize_to_small"],
                    filter=hparams["filter"])
            else:
                self.dataset_tuple = mnist(resize=hparams["resize_to_small"],
                                           filter=hparams["filter"])
        self.continuation_steps = hparams["continuation_steps"]

        self.output_file = hparams["meta"]["output_dir"]
        self._delta_s = hparams["delta_s"]
        self._prev_delta_s = hparams["delta_s"]
        self._omega = hparams["omega"]
        self.grad_fn = jit(grad(self.objective, argnums=[0]))
        self.prev_secant_direction = None
示例#6
0
        hparams["meta"]["output_dir"] = artifact_uri2
        file_name = f"{artifact_uri2}/version.jsonl"

        sw = StateWriter(file_name=file_name)

        data_loader = iter(
            get_mnist_data(batch_size=hparams["batch_size"],
                           resize=True,
                           filter=hparams['filter']))
        num_batches = meta_mnist(batch_size=hparams["batch_size"],
                                 filter=hparams['filter'])["num_batches"]
        print(f"num of bathces: {num_batches}")
        compute_grad_fn = jit(grad(problem.objective, [0]))

        opt = OptimizerCreator(
            hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"]).get_optimizer()
        ma_loss = []
        for epoch in range(hparams["warmup_period"]):
            for b_j in range(num_batches):
                batch = next(data_loader)
                ae_grads = compute_grad_fn(ae_params, batch)
                ae_params = opt.update_params(ae_params,
                                              ae_grads[0],
                                              step_index=epoch)
                loss = problem.objective(ae_params, batch)
                ma_loss.append(loss)
                print(f"loss:{loss}  norm:{l2_norm(ae_grads)}")
            #opt.lr = exp_decay(epoch, hparams["natural_lr"])
            mlflow.log_metrics(
                {
示例#7
0
        hparams["meta"]["output_dir"] = artifact_uri2
        file_name = f"{artifact_uri2}/version.jsonl"

        sw = StateWriter(file_name=file_name)

        data_loader = iter(
            get_mnist_data(batch_size=hparams["batch_size"],
                           resize=True,
                           filter=hparams['filter']))
        num_batches = meta_mnist(batch_size=hparams["batch_size"],
                                 filter=hparams['filter'])["num_batches"]
        print(f"num of bathces: {num_batches}")
        compute_grad_fn = jit(grad(problem.objective, [0, 1]))

        opt = OptimizerCreator(
            hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"]).get_optimizer()
        ma_loss = []
        for epoch in range(hparams["warmup_period"]):
            for b_j in range(num_batches):
                batch = next(data_loader)
                ae_grads, b_grads = compute_grad_fn(ae_params, bparam, batch)
                grads = ae_grads
                ae_params = opt.update_params(ae_params,
                                              ae_grads,
                                              step_index=epoch)
                bparam = opt.update_params(bparam, b_grads, step_index=epoch)
                loss = problem.objective(ae_params, bparam, batch)
                ma_loss.append(loss)
                print(f"loss:{loss}  norm:{l2_norm(grads)}")
            opt.lr = exp_decay(epoch, hparams["natural_lr"])
class PerturbedFixedCorrecter(Corrector):
    """Minimize the objective with gradient steps under a constraint, starting from perturbed points.

    ``correction_step`` runs ``hparams["n_wall_ants"]`` independent descents
    ("ants"). Each ant starts at a random point at distance ``sphere_radius``
    from ``_state_secant_c2`` and descends on ``compute_min_grad_fn``. The ant
    picked by ``get_cheapest_ant`` becomes the corrected state.
    """

    def __init__(
        self,
        objective,
        dual_objective,
        value_fn,
        concat_states,
        key_state,
        compute_min_grad_fn,
        compute_grad_fn,
        hparams,
        delta_s,
        pred_state,
        pred_prev_state,
        counter,
    ):
        """Store the problem, the optimizer and the predictor outputs.

        Args:
          objective: Primal objective; stored on the instance.
          dual_objective: Constrained objective; stored on the instance.
          value_fn: ``value_fn(state, bparam, batch)`` loss evaluator.
          concat_states: Sequence whose items 0..3 become ``_state``,
            ``_bparam``, ``_state_secant_vector`` and ``_state_secant_c2``.
          key_state: Integer seed material for the perturbation PRNG keys.
          compute_min_grad_fn: Returns ``(state_grads, bparam_grads)``.
          compute_grad_fn: Unconstrained gradient callable (stored only).
          hparams: Nested hyper-parameter dict.
          delta_s: Step size forwarded to ``compute_min_grad_fn``.
          pred_state: Predictor output (stored only).
          pred_prev_state: Passed as ``u_0`` to ``projection_affine``.
          counter: Step counter (stored only).
        """
        self.concat_states = concat_states
        self._state = None
        self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["descent_lr"]).get_optimizer()
        self.objective = objective
        self.dual_objective = dual_objective
        self._lagrange_multiplier = hparams["lagrange_init"]
        self._state_secant_vector = None
        self._state_secant_c2 = None
        self.delta_s = delta_s
        self.descent_period = hparams["descent_period"]
        self.max_norm_state = hparams["max_bounds"]
        self.hparams = hparams
        self.compute_min_grad_fn = compute_min_grad_fn
        self.compute_grad_fn = compute_grad_fn
        self._assign_states()
        self._parc_vec = None
        self.state_stack = dict()
        self.key_state = key_state
        self.pred_state = pred_state
        self.pred_prev_state = pred_prev_state
        self.sphere_radius = hparams["sphere_radius"]
        self.counter = counter
        self.value_fn = value_fn
        if hparams["meta"]["dataset"] == "mnist":
            self.data_loader = iter(
                get_mnist_data(batch_size=hparams["batch_size"],
                               resize=hparams["resize_to_small"],
                               filter=hparams["filter"]))
            self.num_batches = meta_mnist(hparams["batch_size"],
                                          hparams["filter"])["num_batches"]
        else:
            # Always define data_loader so the attribute exists for every dataset.
            self.data_loader = None
            self.num_batches = 1

    def _assign_states(self):
        """Unpack ``concat_states`` into ``_state``, ``_bparam``, ``_state_secant_vector``, ``_state_secant_c2``."""
        self._state = self.concat_states[0]
        self._bparam = self.concat_states[1]
        self._state_secant_vector = self.concat_states[2]
        self._state_secant_c2 = self.concat_states[3]

    @staticmethod
    @jit
    def exp_decay(epoch, initial_lrate):
        """Return ``initial_lrate * exp(-0.02 * epoch)``."""
        k = 0.02
        lrate = initial_lrate * np.exp(-k * epoch)
        return lrate

    @staticmethod
    @jit
    def _perform_perturb_by_projection(
        _state_secant_vector,
        _state_secant_c2,
        key,
        pred_prev_state,
        _state,
        _bparam,
        counter,
        sphere_radius,
        batch_data,
    ):
        """Sample a starting point at distance ``sphere_radius`` from ``_state_secant_c2``.

        A random vector ``u`` (ones plus standard-normal noise) is projected
        with ``projection_affine`` using the normalised secant vector ``n`` and
        base point ``pred_prev_state``. The direction from the secant point to
        the shifted projection is normalised and sign-flipped on alternate
        ``counter`` parities.

        ``_state``, ``_bparam`` and ``batch_data`` are accepted but not used.

        Returns:
          ``(parc_vec, state_stack)``: ``state_stack`` is a
          ``{"state", "bparam"}`` dict for the new point, and ``parc_vec`` is
          ``state_stack - _state_secant_c2``.
        """
        # Normalised secant direction, flattened to a single vector.
        n, sample_unravel = pytree_to_vec(
            [_state_secant_vector["state"], _state_secant_vector["bparam"]])
        n = pytree_normalized(n)
        # Random point: ones plus N(0, 1) noise in every coordinate.
        u = tree_map(
            lambda a: a + random.normal(key, a.shape),
            pytree_ones_like(n),
        )
        tmp, _ = pytree_to_vec(
            [_state_secant_c2["state"], _state_secant_c2["bparam"]])

        # Base point for the affine projection.
        u_0, _ = pytree_to_vec(pred_prev_state)
        proj_of_u_on_n = projection_affine(len(n), u, n, u_0)

        point_on_plane = u + pytree_sub(tmp, proj_of_u_on_n)
        # Alternate the direction sign between even and odd ants.
        inv_vec = np.array([-1.0, 1.0])
        parc = pytree_element_mul(
            pytree_normalized(pytree_sub(point_on_plane, tmp)),
            inv_vec[(counter % 2)],
        )
        point_on_plane_2 = tmp + sphere_radius * parc
        new_sample = sample_unravel(point_on_plane_2)
        state_stack = {}
        state_stack.update({"state": new_sample[0]})
        state_stack.update({"bparam": new_sample[1]})
        _parc_vec = pytree_sub(state_stack, _state_secant_c2)
        return _parc_vec, state_stack

    def _evaluate_perturb(self):
        """Print whether the perturbation is (nearly) orthogonal to the secant vector.

        Diagnostic only: compares the normalised dot product against 0 with
        ``abs_tol=0.15`` and prints the result.
        """
        dot = pytree_dot(
            pytree_normalized(self._parc_vec),
            pytree_normalized(self._state_secant_vector),
        )
        if math.isclose(dot, 0.0, abs_tol=0.15):
            print(f"Perturb was near arc-plane. {dot}")
        else:
            print(f"Perturb was not on arc-plane.{dot}")

    def _ant_corrector_omega(self, j_epoch):
        """Return the step multiplier for epoch ``j_epoch``.

        The multiplier is ``guess_ant_steps / (j_epoch + 1)``, capped at 1.5
        while ``j_epoch + 1 <= guess_ant_steps`` and floored at 0.05 afterwards.
        """
        guess = self.hparams["guess_ant_steps"]
        if guess >= (j_epoch + 1):  # To get around folds slowly
            return min(guess / (j_epoch + 1), 1.5)
        return max(guess / (j_epoch + 1), 0.05)

    def _mnist_validation_loss(self):
        """Return ``value_fn`` on the MNIST test split at the current state.

        The test split uses the same ``resize``/``filter`` settings as the
        training loader built in ``__init__``, so the model receives inputs of
        the size it was trained on.
        """
        _, _, test_images, test_labels = mnist(
            permute_train=False,
            resize=self.hparams["resize_to_small"],
            filter=self.hparams["filter"])
        return self.value_fn(self._state, self._bparam,
                             (test_images, test_labels))

    def correction_step(self) -> Tuple:
        """Given the current state, optimize towards the corrected state.

        Returns:
          ``(state, bparam, quality, value, val_loss, corrector_omega)``.
          ``state``/``bparam`` come from the ant chosen by ``get_cheapest_ant``.
          ``value`` is ``value_fn`` on the last batch drawn. ``val_loss`` is
          ``value_fn`` on the MNIST test split, or ``None`` for other datasets.
          ``quality`` and ``corrector_omega`` are taken from the last descent
          step executed (i.e. the last ant).
        """
        is_mnist = self.hparams["meta"]["dataset"] == "mnist"  # TODO: make it generic
        quality = 1.0
        corrector_omega = 1.0
        batch_data = next(self.data_loader) if is_mnist else None

        n_ants = self.hparams["n_wall_ants"]
        # Initial per-ant grad-norm / loss placeholders.
        ants_norm_grads = [5.0] * n_ants
        ants_loss_values = [5.0] * n_ants
        ants_state = [self._state] * n_ants
        ants_bparam = [self._bparam] * n_ants

        for i_n in range(n_ants):
            corrector_omega = 1.0
            stop = False
            # Per-ant PRNG key, mixed with global numpy randomness.
            _, key = random.split(
                random.PRNGKey(self.key_state + i_n +
                               npr.randint(1, (i_n + 1) * 10)))
            self._parc_vec, self.state_stack = self._perform_perturb_by_projection(
                self._state_secant_vector,
                self._state_secant_c2,
                key,
                self.pred_prev_state,
                self._state,
                self._bparam,
                i_n,
                self.sphere_radius,
                batch_data,
            )
            if self.hparams["_evaluate_perturb"]:
                self._evaluate_perturb()  # diagnostic print only

            ants_state[i_n] = self.state_stack["state"]
            ants_bparam[i_n] = self.state_stack["bparam"]
            D_values = []
            print("num_batches", self.num_batches)
            for j_epoch in range(self.descent_period):
                for _b_j in range(self.num_batches):
                    state_grads, bparam_grads = self.compute_min_grad_fn(
                        ants_state[i_n],
                        ants_bparam[i_n],
                        self._lagrange_multiplier,
                        self._state_secant_c2,
                        self._state_secant_vector,
                        batch_data,
                        self.delta_s,
                    )

                    if self.hparams["adaptive"]:
                        self.opt.lr = self.exp_decay(
                            j_epoch, self.hparams["natural_lr"])
                        quality = l2_norm(state_grads)
                        if self.hparams[
                                "local_test_measure"] == "norm_gradients":
                            if quality > self.hparams["quality_thresh"]:
                                print(
                                    f"quality {quality}, {self.opt.lr}, {bparam_grads} ,{j_epoch}"
                                )
                            else:
                                stop = True
                                print(
                                    f"quality {quality} stopping at , {j_epoch}th step"
                                )
                        else:
                            print(
                                f"quality {quality}, {bparam_grads} ,{j_epoch}"
                            )
                            # Stop once the running mean of the loss plateaus.
                            if len(D_values) >= 20:
                                tmp_means = running_mean(D_values, 10)
                                if math.isclose(
                                        tmp_means[-1],
                                        tmp_means[-2],
                                        abs_tol=self.hparams["loss_tol"]):
                                    print(
                                        f"stopping at , {j_epoch}th step, {ants_bparam[i_n]} bparam"
                                    )
                                    stop = True

                        state_grads = clip_grads(state_grads,
                                                 self.hparams["max_clip_grad"])
                        bparam_grads = clip_grads(
                            bparam_grads, self.hparams["max_clip_grad"])

                    corrector_omega = self._ant_corrector_omega(j_epoch)

                    ants_state[i_n] = self.opt.update_params(
                        ants_state[i_n], state_grads, j_epoch)
                    ants_bparam[i_n] = self.opt.update_params(
                        ants_bparam[i_n], bparam_grads, j_epoch)
                    ants_loss_values[i_n] = self.value_fn(
                        ants_state[i_n], ants_bparam[i_n], batch_data)
                    D_values.append(ants_loss_values[i_n])
                    ants_norm_grads[i_n] = quality
                    if is_mnist:  # TODO: make it generic
                        batch_data = next(self.data_loader)
                # `stop` is only honoured after a full pass over the batches.
                if stop:
                    break

        cheapest_index = get_cheapest_ant(
            ants_norm_grads,
            ants_loss_values,
            local_test=self.hparams["local_test_measure"])
        self._state = ants_state[cheapest_index]
        self._bparam = ants_bparam[cheapest_index]
        # TODO: why only final batch data
        value = self.value_fn(self._state, self._bparam, batch_data)

        # Validation on the MNIST test split only makes sense for MNIST runs.
        val_loss = self._mnist_validation_loss() if is_mnist else None
        print(f"val loss: {val_loss}")

        return self._state, self._bparam, quality, value, val_loss, corrector_omega
示例#9
0
class PerturbedFixedCorrecter(Corrector):
    """Minimize the objective with gradient steps under a constraint, after a random perturbation."""

    def __init__(
        self,
        objective,
        dual_objective,
        value_fn,
        concat_states,
        key_state,
        compute_min_grad_fn,
        compute_grad_fn,
        hparams,
        pred_state,
        pred_prev_state,
        counter,
    ):
        """Store the problem, the optimizer and the predictor outputs.

        Args:
          objective: Primal objective; stored on the instance.
          dual_objective: Constrained objective; stored on the instance.
          value_fn: ``value_fn(state, bparam, batch)`` loss evaluator.
          concat_states: Sequence whose items 0..3 become ``_state``,
            ``_bparam``, ``_state_secant_vector`` and ``_state_secant_c2``.
          key_state: Integer seed material for the perturbation PRNG key.
          compute_min_grad_fn: Returns ``(state_grads, bparam_grads)``.
          compute_grad_fn: Unconstrained gradient callable (stored only).
          hparams: Nested hyper-parameter dict; step size is
            ``hparams["delta_s"]``.
          pred_state: Predictor output (stored only).
          pred_prev_state: Passed as ``u_0`` to ``projection_affine``.
          counter: Step counter (stored only).
        """
        self.concat_states = concat_states
        self._state = None
        self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"], learning_rate=hparams["descent_lr"]
        ).get_optimizer()
        self.objective = objective
        self.dual_objective = dual_objective
        self._lagrange_multiplier = hparams["lagrange_init"]
        self._state_secant_vector = None
        self._state_secant_c2 = None
        self.delta_s = hparams["delta_s"]
        self.descent_period = hparams["descent_period"]
        self.max_norm_state = hparams["max_bounds"]
        self.hparams = hparams
        self.compute_min_grad_fn = compute_min_grad_fn
        self.compute_grad_fn = compute_grad_fn
        self._assign_states()
        self._parc_vec = None
        self.state_stack = dict()
        self.key_state = key_state
        self.pred_state = pred_state
        self.pred_prev_state = pred_prev_state
        self.sphere_radius = hparams["sphere_radius"]
        self.counter = counter
        self.value_fn = value_fn
        if hparams["meta"]["dataset"] == "mnist":
            self.data_loader = iter(
                get_mnist_data(
                    batch_size=hparams["batch_size"], resize=hparams["resize_to_small"]
                )
            )
            self.num_batches = meta_mnist(hparams["batch_size"])["num_batches"]
        else:
            self.data_loader = None
            self.num_batches = 1

    def _assign_states(self):
        """Unpack ``concat_states`` into ``_state``, ``_bparam``, ``_state_secant_vector``, ``_state_secant_c2``."""
        self._state = self.concat_states[0]
        self._bparam = self.concat_states[1]
        self._state_secant_vector = self.concat_states[2]
        self._state_secant_c2 = self.concat_states[3]

    @staticmethod
    @jit
    def exp_decay(epoch, initial_lrate):
        """Return ``initial_lrate * exp(-0.02 * epoch)``."""
        k = 0.02
        lrate = initial_lrate * np.exp(-k * epoch)
        return lrate

    @staticmethod
    def _perform_perturb_by_projection(
        _state_secant_vector,
        _state_secant_c2,
        key,
        pred_prev_state,
        _state,
        _bparam,
        sphere_radius,
    ):
        """Sample a point at distance ``sphere_radius`` from ``_state_secant_c2``.

        A random vector ``u`` (ones plus standard-normal noise) is projected
        with ``projection_affine`` using the normalised secant vector ``n`` and
        base point ``pred_prev_state``. The new point lies along the
        normalised direction from the secant point to the shifted projection.

        ``_state`` and ``_bparam`` are accepted but not used.

        Returns:
          ``(parc_vec, state_stack)``: ``state_stack`` is a
          ``{"state", "bparam"}`` dict for the new point, and ``parc_vec`` is
          ``state_stack - _state_secant_c2``.
        """
        # Normalised secant direction, flattened to a single vector.
        n, sample_unravel = pytree_to_vec(
            [_state_secant_vector["state"], _state_secant_vector["bparam"]]
        )
        n = pytree_normalized(n)
        # Random point: ones plus N(0, 1) noise in every coordinate.
        u = tree_map(
            lambda a: a + random.normal(key, a.shape),
            pytree_ones_like(n),
        )
        tmp, _ = pytree_to_vec([_state_secant_c2["state"], _state_secant_c2["bparam"]])

        # Base point for the affine projection.
        u_0, _ = pytree_to_vec(pred_prev_state)
        proj_of_u_on_n = projection_affine(len(n), u, n, u_0)

        point_on_plane = u + pytree_sub(tmp, proj_of_u_on_n)
        parc = pytree_element_mul(
            pytree_normalized(pytree_sub(point_on_plane, tmp)),
            1.0,
        )
        point_on_plane_2 = tmp + sphere_radius * parc
        new_sample = sample_unravel(point_on_plane_2)
        state_stack = {}
        state_stack.update({"state": new_sample[0]})
        state_stack.update({"bparam": new_sample[1]})
        _parc_vec = pytree_sub(state_stack, _state_secant_c2)
        return _parc_vec, state_stack

    def _evaluate_perturb(self):
        """Check whether the perturbation is (nearly) orthogonal to the secant vector.

        If the normalised dot product is within ``abs_tol=0.25`` of 0, the
        perturbed point in ``self.state_stack`` becomes the current state.
        Otherwise the current state is kept.

        NOTE(review): ``correction_step`` only calls this when
        ``hparams["_evaluate_perturb"]`` is true, so with that flag off the
        perturbed point is never adopted. Confirm that this is intended.
        """
        dot = pytree_dot(
            pytree_normalized(self._parc_vec),
            pytree_normalized(self._state_secant_vector),
        )
        if math.isclose(dot, 0.0, abs_tol=0.25):
            print(f"Perturb was near arc-plane. {dot}")
            self._state = self.state_stack["state"]
            self._bparam = self.state_stack["bparam"]
        else:
            print(f"Perturb was not on arc-plane.{dot}")

    def correction_step(self) -> Tuple:
        """Given the current state, optimize towards the corrected state.

        Returns:
          ``(state, bparam, quality, value, corrector_omega)``.
          ``value`` is ``value_fn`` on the last batch used (``None`` for
          non-MNIST data). ``corrector_omega`` stays 1.0 unless the adaptive
          quality test stops the descent.
        """
        _, key = random.split(random.PRNGKey(self.key_state + npr.randint(1, 100)))
        quality = 1.0
        n_opt = 10  # reference step count used to scale corrector_omega
        stop = False
        corrector_omega = 1.0
        # Defined up front so the final value_fn call does not hit an unbound
        # name when descent_period or num_batches is 0.
        batch_data = None
        print("the radius", self.sphere_radius)
        self._parc_vec, self.state_stack = self._perform_perturb_by_projection(
            self._state_secant_vector,
            self._state_secant_c2,
            key,
            self.pred_prev_state,
            self._state,
            self._bparam,
            self.sphere_radius,
        )
        if self.hparams["_evaluate_perturb"]:
            self._evaluate_perturb()  # may adopt the perturbed point

        for j in range(self.descent_period):
            for _b_j in range(self.num_batches):
                if self.hparams["meta"]["dataset"] == "mnist":  # TODO: make it generic
                    batch_data = next(self.data_loader)
                else:
                    batch_data = None
                state_grads, bparam_grads = self.compute_min_grad_fn(
                    self._state,
                    self._bparam,
                    self._lagrange_multiplier,
                    self._state_secant_c2,
                    self._state_secant_vector,
                    batch_data,
                    self.delta_s,
                )

                if self.hparams["adaptive"]:
                    self.opt.lr = self.exp_decay(j, self.hparams["natural_lr"])
                    quality = l2_norm(state_grads)
                    # Written as `not >` so a NaN quality also stops the descent.
                    if not quality > self.hparams["quality_thresh"]:
                        if n_opt > (j + 1):  # To get around folds slowly
                            corrector_omega = min(n_opt / (j + 1), 2.0)
                        else:
                            corrector_omega = max(n_opt / (j + 1), 0.5)
                        stop = True
                        print(f"quality {quality} stopping at , {j}th step")
                    state_grads = clip_grads(state_grads, self.hparams["max_clip_grad"])
                    bparam_grads = clip_grads(
                        bparam_grads, self.hparams["max_clip_grad"]
                    )

                self._bparam = self.opt.update_params(self._bparam, bparam_grads, j)
                self._state = self.opt.update_params(self._state, state_grads, j)
                if stop:
                    break
            if stop:
                break

        value = self.value_fn(
            self._state, self._bparam, batch_data
        )  # TODO: why only final batch data
        return self._state, self._bparam, quality, value, corrector_omega
class UnconstrainedCorrector(Corrector):
    """Minimize the objective using gradient based method.

    Runs plain gradient descent on the state while leaving the continuation
    parameter ``bparam`` untouched (only ``self._state`` is updated in
    ``correction_step``).
    """

    def __init__(self, objective, concat_states, grad_fn, value_fn,
                 accuracy_fn, hparams, dataset_tuple):
        """Set up the optimizer, the starting point and the data pipeline.

        Args:
          objective: objective function; stored on the instance.
          concat_states: ``(state, bparam)`` pair used as the starting point.
          grad_fn: callable ``(state, bparam, batch) -> grads``; ``grads[0]``
            is used to update the state.
          value_fn: callable ``(state, bparam, batch) -> loss value``.
          accuracy_fn: callable ``(state, bparam, (images, labels))``
            evaluated on the test split.
          hparams: hyper-parameter dict. Reads ``meta.optimizer``,
            ``meta.dataset``, ``natural_lr``, ``warmup_period``,
            ``continuation_config``, ``batch_size``, ``resize_to_small``,
            ``filter``, and in ``correction_step`` also
            ``local_test_measure``, ``quality_thresh`` and ``loss_tol``.
          dataset_tuple: ``(train_images, train_labels, test_images,
            test_labels)``; only unpacked when the dataset is ``"mnist"``.
        """
        self.concat_states = concat_states
        self._state = None
        self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"]).get_optimizer()
        self.objective = objective
        self.accuracy_fn = accuracy_fn
        self.warmup_period = hparams["warmup_period"]
        self.hparams = hparams
        self.grad_fn = grad_fn
        self.value_fn = value_fn
        # Must run before the data loader is built: data continuation passes
        # the current bparam as ``alter``.
        self._assign_states()
        if hparams["meta"]["dataset"] == "mnist":
            (self.train_images, self.train_labels, self.test_images,
             self.test_labels) = dataset_tuple
            if hparams["continuation_config"] == 'data':
                # data continuation
                self.data_loader = iter(
                    get_mnist_batch_alter(self.train_images,
                                          self.train_labels,
                                          self.test_images,
                                          self.test_labels,
                                          alter=self._bparam,
                                          batch_size=hparams["batch_size"],
                                          resize=hparams["resize_to_small"],
                                          filter=hparams["filter"]))
            else:
                # model continuation
                # TODO: better way to prefetch mnist (e.g. reuse the already
                # loaded arrays via get_preload_mnist_data).
                self.data_loader = iter(
                    get_mnist_data(batch_size=hparams["batch_size"],
                                   resize=hparams["resize_to_small"],
                                   filter=hparams["filter"])
                )

            self.num_batches = meta_mnist(hparams["batch_size"],
                                          hparams["filter"])["num_batches"]
        else:
            # NOTE(review): with no data loader and no test split stored,
            # correction_step cannot run on this path; confirm it is only
            # used with mnist.
            self.data_loader = None
            self.num_batches = 1

    def _assign_states(self):
        """Unpack ``concat_states`` into the working state and bparam."""
        self._state, self._bparam = self.concat_states

    def correction_step(self) -> Tuple:
        """Given the current state optimize to the correct state.

        Runs up to ``warmup_period`` epochs of ``num_batches`` gradient steps
        on the state (``bparam`` is not updated). Training stops immediately
        once the convergence test chosen by ``hparams["local_test_measure"]``
        fires:

        * ``"norm_gradients"``: gradient norm drops to ``quality_thresh``
          or below;
        * otherwise: the last two values of a 10-wide running mean of the
          training loss agree within ``loss_tol`` (checked once at least 20
          losses have been recorded).

        Returns:
          ``(state, bparam, quality, value, val_loss, val_acc)`` where
          ``quality`` is the gradient norm of the last step taken, ``value``
          the training loss on the last batch (``None`` if no step ran) and
          ``val_loss``/``val_acc`` are evaluated on the test split.
        """
        quality = 1.0
        value = None  # remains None only if no optimization step is taken
        ma_loss = []
        stop = False
        print("learn_rate", self.opt.lr)
        for k in range(self.warmup_period):
            for _ in range(self.num_batches):
                batch = next(self.data_loader)
                grads = self.grad_fn(self._state, self._bparam, batch)
                self._state = self.opt.update_params(self._state, grads[0])
                # Norm of the full grad_fn output, not just grads[0].
                quality = l2_norm(grads)
                value = self.value_fn(self._state, self._bparam, batch)
                ma_loss.append(value)
                # Decay is indexed by epoch k, so lr is constant within an epoch.
                self.opt.lr = exp_decay(k, self.hparams["natural_lr"])
                if self.hparams["local_test_measure"] == "norm_gradients":
                    if quality > self.hparams["quality_thresh"]:
                        print(f"quality {quality}, {self.opt.lr} ,{k}")
                    else:
                        stop = True
                        print(f"quality {quality} stopping at , {k}th step")
                else:
                    if len(ma_loss) >= 20:
                        tmp_means = running_mean(ma_loss, 10)
                        if math.isclose(
                                tmp_means[-1],
                                tmp_means[-2],
                                abs_tol=self.hparams["loss_tol"],
                        ):
                            print(f"stopping at , {k}th step")
                            stop = True
                # Leave the batch loop as soon as we have converged; otherwise
                # the remaining batches of the epoch would keep updating the
                # state after the stopping criterion was met.
                if stop:
                    break
            if stop:
                print("breaking")
                break

        val_loss = self.value_fn(self._state, self._bparam,
                                 (self.test_images, self.test_labels))
        val_acc = self.accuracy_fn(self._state, self._bparam,
                                   (self.test_images, self.test_labels))
        return self._state, self._bparam, quality, value, val_loss, val_acc