Example no. 1
0
    def __init__(
        self,
        state,
        bparam,
        state_0,
        bparam_0,
        counter,
        objective,
        dual_objective,
        hparams,
    ):
        # states
        self._state_wrap = StateVariable(state, counter)
        self._bparam_wrap = StateVariable(
            bparam, counter
        )  # TODO: save tree def, always unflatten before compute_grads
        self._prev_state = state_0
        self._prev_bparam = bparam_0

        # objectives
        self.objective = objective
        self.dual_objective = dual_objective
        self.value_func = jit(self.objective)

        self.hparams = hparams

        self._value_wrap = StateVariable(
            1.0, counter
        )  # TODO: fix with a static batch (test/train)
        self._quality_wrap = StateVariable(
            l2_norm(self._state_wrap.state), counter
        )

        # optimizer
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"],
        ).get_optimizer()
        self.ascent_opt = OptimizerCreator(
            opt_string=hparams["meta"]["ascent_optimizer"],
            learning_rate=hparams["ascent_lr"],
        ).get_optimizer()

        # per-step hparams
        self.continuation_steps = hparams["continuation_steps"]
        self._lagrange_multiplier = hparams["lagrange_init"]

        self._delta_s = hparams["delta_s"]
        self._omega = hparams["omega"]

        # grad functions (should be pure functional)
        self.compute_min_grad_fn = jit(grad(self.dual_objective, [0, 1]))
        self.compute_max_grad_fn = jit(grad(self.dual_objective, [2]))
        self.compute_grad_fn = jit(grad(self.objective, [0]))

        # extras
        self.sw = None
        self.state_tree_def = None
        self.bparam_tree_def = None
        self.output_file = hparams["meta"]["output_dir"]
        self.prev_secant_direction = None
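
A minimal sketch of the hparams dictionary that the constructor in Example no. 1 expects, inferred only from the keys it reads; the concrete values below are illustrative assumptions, not defaults taken from the source:

# Hypothetical hparams layout, derived from the keys accessed in __init__ above.
# All values are placeholder assumptions.
hparams = {
    "meta": {
        "optimizer": "adam",          # consumed by OptimizerCreator
        "ascent_optimizer": "sgd",    # consumed by OptimizerCreator
        "output_dir": "outputs",      # becomes self.output_file
    },
    "natural_lr": 1e-3,
    "ascent_lr": 1e-3,
    "continuation_steps": 10,
    "lagrange_init": 0.1,
    "delta_s": 0.02,
    "omega": 1.0,
}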
Example no. 2
0
    def __init__(
        self,
        state,
        bparam,
        state_0,
        bparam_0,
        counter,
        objective,
        dual_objective,
        accuracy_fn,
        hparams,
        key_state,
    ):

        # states
        self._state_wrap = StateVariable(state, counter)
        self._bparam_wrap = StateVariable(
            bparam, counter
        )  # TODO: save tree def, always unflatten before compute_grads
        self._prev_state = state_0
        self._prev_bparam = bparam_0

        # objectives
        self.objective = objective
        self.dual_objective = dual_objective
        self.accuracy_fn1 = jit(accuracy_fn)
        self.value_func = jit(self.objective)

        self.hparams = hparams
        if hparams["meta"]["dataset"] == "mnist":
            if hparams["continuation_config"] == "data":
                self.dataset_tuple = mnist_gamma(
                    resize=hparams["resize_to_small"], filter=hparams["filter"]
                )
            else:
                self.dataset_tuple = mnist(
                    resize=hparams["resize_to_small"], filter=hparams["filter"]
                )

        self._value_wrap = StateVariable(
            0.06, counter
        )  # TODO: fix with a static batch (test/train)
        self._quality_wrap = StateVariable(
            l2_norm(self._state_wrap.state) / 10, counter
        )

        # per-step hparams
        self.continuation_steps = hparams["continuation_steps"]

        self._delta_s = hparams["delta_s"]
        self._prev_delta_s = hparams["delta_s"]
        self._omega = hparams["omega"]

        # grad functions (should be pure functional)
        self.compute_min_grad_fn = jit(grad(self.dual_objective, [0, 1]))
        self.compute_grad_fn = jit(grad(self.objective, [0]))

        # extras
        self.state_tree_def = None
        self.bparam_tree_def = None
        self.output_file = hparams["meta"]["output_dir"]
        self.prev_secant_direction = None
        self.perturb_index = key_state
        self.sw = StateWriter(
            f"{self.output_file}/version_{self.perturb_index}.json"
        )
        self.key_state = key_state + npr.randint(100, 200)
        # element-wise clamping of values into [lambda_min, lambda_max]
        self.clip_lambda_max = lambda g: np.where(
            g > self.hparams["lambda_max"], self.hparams["lambda_max"], g
        )
        self.clip_lambda_min = lambda g: np.where(
            g < self.hparams["lambda_min"], self.hparams["lambda_min"], g
        )
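
The two clipping lambdas at the end of Example no. 2 operate element-wise on a single array; to apply them across an entire gradient pytree they would typically be mapped over the leaves. A minimal sketch under that assumption, where cont stands for an instance of the class above and grads is a hypothetical pytree of gradients:

import jax

# Hypothetical helper: clamp every leaf of a gradient pytree into
# [lambda_min, lambda_max] using the lambdas defined in __init__.
def clip_grads(cont, grads):
    grads = jax.tree_util.tree_map(cont.clip_lambda_max, grads)
    grads = jax.tree_util.tree_map(cont.clip_lambda_min, grads)
    return grads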