class PseudoArcLenContinuation(Continuation):
    """Pseudo arc-length continuation strategy.

    Composed of a secant predictor and a constrained corrector.
    TODO: maybe refactor into a single continuation class.
    """

    def __init__(
        self,
        state,
        bparam,
        state_0,
        bparam_0,
        counter,
        objective,
        dual_objective,
        hparams,
    ):
        # states
        self._state_wrap = StateVariable(state, counter)
        self._bparam_wrap = StateVariable(
            bparam, counter
        )  # TODO: save tree def, always unflatten before compute_grads
        self._prev_state = state_0
        self._prev_bparam = bparam_0

        # objectives
        self.objective = objective
        self.dual_objective = dual_objective
        self.value_func = jit(self.objective)
        self.hparams = hparams
        self._value_wrap = StateVariable(
            1.0, counter
        )  # TODO: fix with a static batch (test/train)
        self._quality_wrap = StateVariable(l2_norm(self._state_wrap.state), counter)

        # optimizers
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"],
        ).get_optimizer()
        self.ascent_opt = OptimizerCreator(
            opt_string=hparams["meta"]["ascent_optimizer"],
            learning_rate=hparams["ascent_lr"],
        ).get_optimizer()

        # per-step hyperparameters
        self.continuation_steps = hparams["continuation_steps"]
        self._lagrange_multiplier = hparams["lagrange_init"]
        self._delta_s = hparams["delta_s"]
        self._omega = hparams["omega"]

        # grad functions; should be pure functional
        self.compute_min_grad_fn = jit(grad(self.dual_objective, [0, 1]))
        self.compute_max_grad_fn = jit(grad(self.dual_objective, [2]))
        self.compute_grad_fn = jit(grad(self.objective, [0]))

        # extras
        self.sw = None
        self.state_tree_def = None
        self.bparam_tree_def = None
        self.output_file = hparams["meta"]["output_dir"]
        self.prev_secant_direction = None

    @profile(sort_by="cumulative", lines_to_print=10, strip_dirs=True)
    def run(self):
        """Runs the continuation strategy.

        A continuation strategy defines how the predictor and corrector
        components of the algorithm interact with the states of the
        mathematical system.
        """
        self.sw = StateWriter(f"{self.output_file}/version.json")

        for i in range(self.continuation_steps):
            self._state_wrap.counter = i
            self._bparam_wrap.counter = i
            self._value_wrap.counter = i
            self.sw.write([
                self._state_wrap.get_record(),
                self._bparam_wrap.get_record(),
                self._value_wrap.get_record(),
            ])

            concat_states = [
                (self._state_wrap.state, self._bparam_wrap.state),
                (self._prev_state, self._prev_bparam),
                self.prev_secant_direction,
            ]
            predictor = SecantPredictor(
                concat_states=concat_states,
                delta_s=self._delta_s,
                omega=self._omega,
                net_spacing_param=self.hparams["net_spacing_param"],
                net_spacing_bparam=self.hparams["net_spacing_bparam"],
                hparams=self.hparams,
            )
            predictor.prediction_step()

            self.prev_secant_direction = predictor.secant_direction
            self._prev_state = self._state_wrap.state
            self._prev_bparam = self._bparam_wrap.state

            concat_states = [
                predictor.state,
                predictor.bparam,
                predictor.secant_direction,
                predictor.get_secant_concat(),
            ]
            del predictor
            gc.collect()

            corrector = ConstrainedCorrector(
                optimizer=self.opt,
                objective=self.objective,
                dual_objective=self.dual_objective,
                lagrange_multiplier=self._lagrange_multiplier,
                concat_states=concat_states,
                delta_s=self._delta_s,
                ascent_opt=self.ascent_opt,
                compute_min_grad_fn=self.compute_min_grad_fn,
                compute_max_grad_fn=self.compute_max_grad_fn,
                compute_grad_fn=self.compute_grad_fn,
                hparams=self.hparams,
            )
            state, bparam = corrector.correction_step()
            value = self.value_func(state, bparam)

            self._state_wrap.state = state
            self._bparam_wrap.state = bparam
            self._value_wrap.state = value
            del corrector
            gc.collect()
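# Illustrative sketch, not part of the library: the constrained corrector above
# enforces a pseudo arc-length condition, i.e. the corrected point must stay on
# the hyperplane orthogonal to the secant direction at distance delta_s from
# the previous solution. Assuming flat jax.numpy vectors (the real code works
# on pytrees through the jitted dual objective), the residual would look like:
def _arclen_residual_sketch(state, bparam, prev_state, prev_bparam,
                            secant_state, secant_bparam, delta_s):
    """Hypothetical pseudo arc-length constraint residual for flat vectors."""
    import jax.numpy as jnp
    inner = (jnp.vdot(state - prev_state, secant_state) +
             jnp.vdot(bparam - prev_bparam, secant_bparam))
    return inner - delta_s  # zero when the point sits on the constraint plane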
class SecantContinuation(Continuation):
    """Secant continuation strategy.

    Composed of a secant predictor and an unconstrained corrector.
    """

    def __init__(self, state, bparam, state_0, bparam_0, counter, objective,
                 accuracy_fn, hparams):
        self._state_wrap = StateVariable(state, counter)
        self._bparam_wrap = StateVariable(bparam, counter)
        self._prev_state = state_0
        self._prev_bparam = bparam_0
        self.objective = objective
        self.accuracy_fn = accuracy_fn
        self.value_func = jit(self.objective)
        self._value_wrap = StateVariable(0.005, counter)
        self._quality_wrap = StateVariable(0.005, counter)
        self.sw = None
        self.hparams = hparams
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"],
        ).get_optimizer()
        if hparams["meta"]["dataset"] == "mnist":
            if hparams["continuation_config"] == "data":
                self.dataset_tuple = mnist_gamma(
                    resize=hparams["resize_to_small"],
                    filter=hparams["filter"],
                )
            else:
                self.dataset_tuple = mnist(
                    resize=hparams["resize_to_small"],
                    filter=hparams["filter"],
                )
        self.continuation_steps = hparams["continuation_steps"]
        self.output_file = hparams["meta"]["output_dir"]
        self._delta_s = hparams["delta_s"]
        self._prev_delta_s = hparams["delta_s"]
        self._omega = hparams["omega"]
        self.grad_fn = jit(grad(self.objective, argnums=[0]))
        self.prev_secant_direction = None

    @profile(sort_by="cumulative", lines_to_print=10, strip_dirs=True)
    def run(self):
        """Runs the continuation strategy.

        A continuation strategy defines how the predictor and corrector
        components of the algorithm interact with the states of the
        mathematical system.
        """
        self.sw = StateWriter(f"{self.output_file}/version.json")

        for i in range(self.continuation_steps):
            if i == 0 and self.hparams["natural_start"]:
                print("unconstrained solver for 1st step")
                concat_states = [
                    self._prev_state,
                    pytree_element_add(self._prev_bparam, 0.05),
                ]
                corrector = UnconstrainedCorrector(
                    objective=self.objective,
                    concat_states=concat_states,
                    grad_fn=self.grad_fn,
                    value_fn=self.value_func,
                    accuracy_fn=self.accuracy_fn,
                    hparams=self.hparams,
                    dataset_tuple=self.dataset_tuple,
                )
                state, bparam, quality, value, val_loss, val_acc = (
                    corrector.correction_step())
                self._state_wrap.state = state
                self._bparam_wrap.state = bparam
                del corrector, state, bparam, quality, value, concat_states

            print(self._value_wrap.get_record(), self._bparam_wrap.get_record())
            self._state_wrap.counter = i
            self._bparam_wrap.counter = i
            self._value_wrap.counter = i
            self._quality_wrap.counter = i
            self.sw.write([
                self._state_wrap.get_record(),
                self._bparam_wrap.get_record(),
                self._value_wrap.get_record(),
                self._quality_wrap.get_record(),
            ])

            concat_states = [
                (self._prev_state, self._prev_bparam),
                (self._state_wrap.state, self._bparam_wrap.state),
                self.prev_secant_direction,
            ]
            predictor = SecantPredictor(
                concat_states=concat_states,
                delta_s=self._delta_s,
                prev_delta_s=self._prev_delta_s,
                omega=self._omega,
                net_spacing_param=self.hparams["net_spacing_param"],
                net_spacing_bparam=self.hparams["net_spacing_bparam"],
                hparams=self.hparams,
            )
            predictor.prediction_step()
            self.prev_secant_direction = predictor.secant_direction

            concat_states = [predictor.state, predictor.bparam]
            del predictor
            gc.collect()

            corrector = UnconstrainedCorrector(
                objective=self.objective,
                concat_states=concat_states,
                grad_fn=self.grad_fn,
                value_fn=self.value_func,
                accuracy_fn=self.accuracy_fn,
                hparams=self.hparams,
                dataset_tuple=self.dataset_tuple,
            )
            state, bparam, quality, value, val_loss, val_acc = (
                corrector.correction_step())

            corrector_omega = 0.005  # TODO: why fixed? check
            self._prev_delta_s = self._delta_s
            self._delta_s = corrector_omega * self._delta_s
            self._delta_s = min(self._delta_s, self.hparams["max_arc_len"])
            self._delta_s = max(self._delta_s, self.hparams["min_arc_len"])

            self._state_wrap.state = state
            self._bparam_wrap.state = bparam
            self._value_wrap.state = value
            self._quality_wrap.state = quality
            del corrector
            del concat_states
            gc.collect()

            if self._bparam_wrap.state[0] >= self.hparams["lambda_max"]:
                self.sw.write([
                    self._state_wrap.get_record(),
                    self._bparam_wrap.get_record(),
                    self._value_wrap.get_record(),
                    self._quality_wrap.get_record(),
                ])
                break

            mlflow.log_metrics(
                {
                    "train_loss": float(self._value_wrap.state),
                    "delta_s": float(self._delta_s),
                    "norm grads": float(self._quality_wrap.state),
                    "val_loss": float(val_loss),
                },
                i,
            )
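# Illustrative sketch of the secant prediction step (an assumption about what
# SecantPredictor computes, not its actual implementation): extrapolate along
# the normalized difference of the two most recent solutions by delta_s.
def _secant_prediction_sketch(prev, curr, delta_s):
    """Hypothetical flat-vector secant predictor: curr + delta_s * tangent."""
    import jax.numpy as jnp
    secant = curr - prev
    secant = secant / jnp.linalg.norm(secant)  # unit tangent along the path
    return curr + delta_s * secant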
class PerturbedPseudoArcLenFixedContinuation(Continuation):
    """Noisy pseudo arc-length continuation strategy.

    Composed of a secant predictor and a noisy constrained corrector.
    """

    def __init__(
        self,
        state,
        bparam,
        state_0,
        bparam_0,
        counter,
        objective,
        dual_objective,
        accuracy_fn,
        hparams,
        key_state,
    ):
        # states
        self._state_wrap = StateVariable(state, counter)
        self._bparam_wrap = StateVariable(
            bparam, counter
        )  # TODO: save tree def, always unflatten before compute_grads
        self._prev_state = state_0
        self._prev_bparam = bparam_0

        # objectives
        self.objective = objective
        self.dual_objective = dual_objective
        self.accuracy_fn1 = jit(accuracy_fn)
        self.value_func = jit(self.objective)
        self.hparams = hparams

        if hparams["meta"]["dataset"] == "mnist":
            if hparams["continuation_config"] == "data":
                self.dataset_tuple = mnist_gamma(
                    resize=hparams["resize_to_small"],
                    filter=hparams["filter"],
                )
            else:
                self.dataset_tuple = mnist(
                    resize=hparams["resize_to_small"],
                    filter=hparams["filter"],
                )

        self._value_wrap = StateVariable(
            0.06, counter
        )  # TODO: fix with a static batch (test/train)
        self._quality_wrap = StateVariable(
            l2_norm(self._state_wrap.state) / 10, counter)

        # per-step hyperparameters
        self.continuation_steps = hparams["continuation_steps"]
        self._delta_s = hparams["delta_s"]
        self._prev_delta_s = hparams["delta_s"]
        self._omega = hparams["omega"]

        # grad functions; should be pure functional
        self.compute_min_grad_fn = jit(grad(self.dual_objective, [0, 1]))
        self.compute_grad_fn = jit(grad(self.objective, [0]))

        # extras
        self.state_tree_def = None
        self.bparam_tree_def = None
        self.output_file = hparams["meta"]["output_dir"]
        self.prev_secant_direction = None
        self.perturb_index = key_state
        self.sw = StateWriter(
            f"{self.output_file}/version_{self.perturb_index}.json")
        self.key_state = key_state + npr.randint(100, 200)
        self.clip_lambda_max = lambda g: np.where(
            (g > self.hparams["lambda_max"]), self.hparams["lambda_max"], g)
        self.clip_lambda_min = lambda g: np.where(
            (g < self.hparams["lambda_min"]), self.hparams["lambda_min"], g)

    @profile(sort_by="cumulative", lines_to_print=10, strip_dirs=True)
    def run(self):
        """Runs the continuation strategy.

        A continuation strategy defines how the predictor and corrector
        components of the algorithm interact with the states of the
        mathematical system.
        """
        for i in range(self.continuation_steps):
            self._state_wrap.counter = i
            self._bparam_wrap.counter = i
            self._value_wrap.counter = i
            self._quality_wrap.counter = i

            if i == 0 and self.hparams["natural_start"]:
                print("unconstrained solver for 1st step")
                concat_states = [
                    self._prev_state,
                    pytree_element_add(self._prev_bparam, 0.03),
                ]
                corrector = UnconstrainedCorrector(
                    objective=self.objective,
                    concat_states=concat_states,
                    grad_fn=self.compute_grad_fn,
                    value_fn=self.value_func,
                    accuracy_fn=self.accuracy_fn1,
                    hparams=self.hparams,
                    dataset_tuple=self.dataset_tuple,
                )
                state, bparam, quality, value, val_loss, val_acc = (
                    corrector.correction_step())

                if self.hparams["double_natural_start"]:
                    # TODO: refactor natural and double natural start
                    self._prev_state = state
                    self._prev_bparam = bparam
                    print("unconstrained solver for 2nd step")
                    concat_states = [
                        self._prev_state,
                        pytree_element_add(self._prev_bparam, 0.07),
                    ]
                    corrector = UnconstrainedCorrector(
                        objective=self.objective,
                        concat_states=concat_states,
                        grad_fn=self.compute_grad_fn,
                        value_fn=self.value_func,
                        accuracy_fn=self.accuracy_fn1,
                        hparams=self.hparams,
                        dataset_tuple=self.dataset_tuple,
                    )
                    state, bparam, quality, value, val_loss, val_acc = (
                        corrector.correction_step())

                self._state_wrap.state = state
                self._bparam_wrap.state = bparam

            print(
                "delta_s",
                self._value_wrap.get_record(),
                self._bparam_wrap.get_record(),
                self._delta_s,
            )
            self.sw.write([
                self._state_wrap.get_record(),
                self._bparam_wrap.get_record(),
                self._value_wrap.get_record(),
                self._quality_wrap.get_record(),
            ])

            concat_states = [
                (self._prev_state, self._prev_bparam),
                (self._state_wrap.state, self._bparam_wrap.state),
                self.prev_secant_direction,
            ]
            predictor = SecantPredictor(
                concat_states=concat_states,
                delta_s=self._delta_s,
                prev_delta_s=self._prev_delta_s,
                omega=self._omega,
                net_spacing_param=self.hparams["net_spacing_param"],
                net_spacing_bparam=self.hparams["net_spacing_bparam"],
                hparams=self.hparams,
            )
            predictor.prediction_step()
            self.prev_secant_direction = predictor.secant_direction

            self.hparams["sphere_radius"] = (
                self.hparams["sphere_radius_m"] * self._delta_s
            )  # l2_norm(predictor.secant_direction)
            mlflow.log_metric(f"sphere_radius{self.perturb_index}",
                              self.hparams["sphere_radius"], i)
            mlflow.log_metric(f"delta_s{self.perturb_index}", self._delta_s, i)

            concat_states = [
                predictor.state,
                predictor.bparam,
                predictor.secant_direction,
                {"state": predictor.state, "bparam": predictor.bparam},
            ]
            corrector = PerturbedFixedCorrecter(
                objective=self.objective,
                dual_objective=self.dual_objective,
                accuracy_fn1=self.accuracy_fn1,
                value_fn=self.value_func,
                concat_states=concat_states,
                key_state=self.key_state,
                compute_min_grad_fn=self.compute_min_grad_fn,
                compute_grad_fn=self.compute_grad_fn,
                hparams=self.hparams,
                delta_s=self._delta_s,
                pred_state=[self._state_wrap.state, self._bparam_wrap.state],
                pred_prev_state=[self._state_wrap.state, self._bparam_wrap.state],
                counter=self.continuation_steps,
                dataset_tuple=self.dataset_tuple,
            )
            self._prev_state = copy.deepcopy(self._state_wrap.state)
            self._prev_bparam = copy.deepcopy(self._bparam_wrap.state)

            (
                state,
                bparam,
                quality,
                value,
                val_loss,
                val_acc,
                corrector_omega,
            ) = corrector.correction_step()
            # TODO: make predictor and corrector APIs similar
            # TODO: enable MLflow
            bparam = tree_map(self.clip_lambda_max, bparam)
            bparam = tree_map(self.clip_lambda_min, bparam)

            self._state_wrap.state = state
            self._bparam_wrap.state = bparam
            self._value_wrap.state = value
            self._quality_wrap.state = quality
            # self._omega = corrector_omega
            self._prev_delta_s = self._delta_s
            self._delta_s = corrector_omega * self._delta_s
            self._delta_s = min(self._delta_s, self.hparams["max_arc_len"])
            self._delta_s = max(self._delta_s, self.hparams["min_arc_len"])

            if (bparam[0] >= self.hparams["lambda_max"]) or (
                    bparam[0] <= self.hparams["lambda_min"]):
                self.sw.write([
                    self._state_wrap.get_record(),
                    self._bparam_wrap.get_record(),
                    self._value_wrap.get_record(),
                    self._quality_wrap.get_record(),
                ])
                break

            mlflow.log_metrics(
                {
                    f"train_loss{self.perturb_index}":
                        float(self._value_wrap.state),
                    f"delta_s{self.perturb_index}": float(self._delta_s),
                    f"norm grads{self.perturb_index}":
                        float(self._quality_wrap.state),
                    f"val_loss{self.perturb_index}": float(val_loss),
                    f"val_acc{self.perturb_index}": float(val_acc),
                    f"corrector_omega{self.perturb_index}":
                        float(corrector_omega),
                },
                i,
            )
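# Illustrative sketch mirroring the clip_lambda_max / clip_lambda_min lambdas
# above: the bifurcation parameter pytree is clamped leaf-wise into
# [lambda_min, lambda_max] with tree_map. `tree_map` and `np` (jax.numpy) are
# the module-level imports already in use; the helper name is hypothetical.
def _clip_bparam_sketch(bparam, lambda_min, lambda_max):
    """Hypothetical helper: leaf-wise clamp of a bparam pytree."""
    return tree_map(lambda g: np.clip(g, lambda_min, lambda_max), bparam)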
class NaturalContinuation(Continuation):
    """Natural continuation strategy.

    Composed of a natural predictor and an unconstrained corrector.
    """

    def __init__(self, state, bparam, counter, objective, accuracy_fn, hparams):
        self._state_wrap = StateVariable(state, counter)
        self._bparam_wrap = StateVariable(bparam, counter)
        self.objective = objective
        self.value_func = jit(self.objective)
        self.accuracy_fn = jit(accuracy_fn)
        self._value_wrap = StateVariable(2.0, counter)
        self._quality_wrap = StateVariable(0.25, counter)
        self.sw = None
        self.hparams = hparams
        if hparams["meta"]["dataset"] == "mnist":
            if hparams["continuation_config"] == "data":
                self.dataset_tuple = mnist_gamma(
                    resize=hparams["resize_to_small"],
                    filter=hparams["filter"],
                )
            else:
                print("model continuation")
                self.dataset_tuple = mnist(
                    resize=hparams["resize_to_small"],
                    filter=hparams["filter"],
                )
        self.continuation_steps = hparams["continuation_steps"]
        self.output_file = hparams["meta"]["output_dir"]
        self._delta_s = hparams["delta_bparams"]
        self.grad_fn = jit(
            grad(self.objective, argnums=[0])
        )  # TODO: vmap is not fully supported with stax

    @profile(sort_by="cumulative", lines_to_print=10, strip_dirs=True)
    def run(self):
        """Runs the continuation strategy.

        A continuation strategy defines how the predictor and corrector
        components of the algorithm interact with the states of the
        mathematical system.
        """
        self.sw = StateWriter(f"{self.output_file}/version.json")

        for i in range(self.continuation_steps):
            print(self._value_wrap.get_record(), self._bparam_wrap.get_record())
            self._state_wrap.counter = i
            self._bparam_wrap.counter = i
            self._value_wrap.counter = i
            self._quality_wrap.counter = i
            self.sw.write([
                self._state_wrap.get_record(),
                self._bparam_wrap.get_record(),
                self._value_wrap.get_record(),
                self._quality_wrap.get_record(),
            ])

            concat_states = [self._state_wrap.state, self._bparam_wrap.state]
            predictor = NaturalPredictor(concat_states=concat_states,
                                         delta_s=self._delta_s)
            predictor.prediction_step()

            concat_states = [predictor.state, predictor.bparam]
            del predictor
            gc.collect()

            corrector = UnconstrainedCorrector(
                objective=self.objective,
                concat_states=concat_states,
                grad_fn=self.grad_fn,
                value_fn=self.value_func,
                accuracy_fn=self.accuracy_fn,
                hparams=self.hparams,
                dataset_tuple=self.dataset_tuple,
            )
            state, bparam, quality, value, val_loss, val_acc = (
                corrector.correction_step())

            # clamp the bifurcation parameter to [lambda_min, lambda_max]
            clip_lambda = lambda g: np.where(
                (g > self.hparams["lambda_max"]), self.hparams["lambda_max"], g)
            bparam = tree_map(clip_lambda, bparam)
            clip_lambda = lambda g: np.where(
                (g < self.hparams["lambda_min"]), self.hparams["lambda_min"], g)
            bparam = tree_map(clip_lambda, bparam)

            self._state_wrap.state = state
            self._bparam_wrap.state = bparam
            self._value_wrap.state = value
            self._quality_wrap.state = quality
            del corrector
            gc.collect()

            if self._bparam_wrap.state[0] >= self.hparams["lambda_max"]:
                self.sw.write([
                    self._state_wrap.get_record(),
                    self._bparam_wrap.get_record(),
                    self._value_wrap.get_record(),
                    self._quality_wrap.get_record(),
                ])
                break

            mlflow.log_metrics(
                {
                    "train_loss": float(self._value_wrap.state),
                    "delta_s": float(self._delta_s),
                    "norm grads": float(self._quality_wrap.state),
                    "val_loss": float(val_loss),
                    "val_acc": float(val_acc),
                },
                i,
            )
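# Usage sketch (hypothetical values; the hparams keys shown are the ones read
# above, and state/bparam/loss_fn/acc_fn stand in for caller-provided objects;
# the corrector may read further keys not listed here):
#
#   hparams = {
#       "meta": {"dataset": "mnist", "output_dir": "outputs"},
#       "continuation_config": "model",
#       "resize_to_small": True,
#       "filter": None,
#       "continuation_steps": 20,
#       "delta_bparams": 0.1,
#       "lambda_min": 0.0,
#       "lambda_max": 1.0,
#   }
#   nc = NaturalContinuation(state, bparam, counter=0, objective=loss_fn,
#                            accuracy_fn=acc_fn, hparams=hparams)
#   nc.run()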