def __init__(self, objective, dual_objective, accuracy_fn1, value_fn,
             concat_states, key_state, compute_min_grad_fn, compute_grad_fn,
             hparams, delta_s, pred_state, pred_prev_state, counter,
             dataset_tuple):
    self.concat_states = concat_states
    self._state = None
    self._bparam = None
    self.opt = OptimizerCreator(
        opt_string=hparams["meta"]["optimizer"],
        learning_rate=hparams["descent_lr"]).get_optimizer()
    self.objective = objective
    self.dual_objective = dual_objective
    self._lagrange_multiplier = hparams["lagrange_init"]
    self._state_secant_vector = None
    self._state_secant_c2 = None
    self.delta_s = delta_s
    self.descent_period = hparams["descent_period"]
    self.max_norm_state = hparams["max_bounds"]
    self.hparams = hparams
    self.compute_min_grad_fn = compute_min_grad_fn
    self.compute_grad_fn = compute_grad_fn
    self._assign_states()
    self._parc_vec = None
    self.state_stack = dict()
    self.key_state = key_state
    self.pred_state = pred_state
    self.pred_prev_state = pred_prev_state
    self.sphere_radius = hparams["sphere_radius"]
    self.counter = counter
    self.value_fn = value_fn
    self.accuracy_fn1 = accuracy_fn1
    self.dataset_tuple = dataset_tuple
    if hparams["meta"]["dataset"] == "mnist":
        (self.train_images, self.train_labels,
         self.test_images, self.test_labels) = dataset_tuple
        if hparams["continuation_config"] == 'data':
            # data continuation
            self.data_loader = iter(
                get_mnist_batch_alter(self.train_images,
                                      self.train_labels,
                                      self.test_images,
                                      self.test_labels,
                                      alter=self._bparam,
                                      batch_size=hparams["batch_size"],
                                      resize=hparams["resize_to_small"],
                                      filter=hparams["filter"]))
        else:
            # model continuation
            self.data_loader = iter(
                get_mnist_data(batch_size=hparams["batch_size"],
                               resize=hparams["resize_to_small"],
                               filter=hparams["filter"]))
        self.num_batches = meta_mnist(hparams["batch_size"],
                                      hparams["filter"])["num_batches"]
    else:
        self.data_loader = None
        self.num_batches = 1
def __init__(
    self,
    state,
    bparam,
    state_0,
    bparam_0,
    counter,
    objective,
    dual_objective,
    hparams,
):
    # states
    self._state_wrap = StateVariable(state, counter)
    self._bparam_wrap = StateVariable(
        bparam, counter
    )  # Todo: save tree def, always unflatten before compute_grads
    self._prev_state = state_0
    self._prev_bparam = bparam_0

    # objectives
    self.objective = objective
    self.dual_objective = dual_objective
    self.value_func = jit(self.objective)
    self.hparams = hparams

    self._value_wrap = StateVariable(
        1.0, counter)  # TODO: fix with a static batch (test/train)
    self._quality_wrap = StateVariable(l2_norm(self._state_wrap.state), counter)

    # optimizer
    self.opt = OptimizerCreator(
        opt_string=hparams["meta"]["optimizer"],
        learning_rate=hparams["natural_lr"]).get_optimizer()
    self.ascent_opt = OptimizerCreator(
        opt_string=hparams["meta"]["ascent_optimizer"],
        learning_rate=hparams["ascent_lr"],
    ).get_optimizer()

    # every-step hparams
    self.continuation_steps = hparams["continuation_steps"]
    self._lagrange_multiplier = hparams["lagrange_init"]
    self._delta_s = hparams["delta_s"]
    self._omega = hparams["omega"]

    # grad functions; should be pure functional
    self.compute_min_grad_fn = jit(grad(self.dual_objective, [0, 1]))
    self.compute_max_grad_fn = jit(grad(self.dual_objective, [2]))
    self.compute_grad_fn = jit(grad(self.objective, [0]))

    # extras
    self.sw = None
    self.state_tree_def = None
    self.bparam_tree_def = None
    self.output_file = hparams["meta"]["output_dir"]
    self.prev_secant_direction = None
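# `StateVariable` above pairs a value (a pytree of parameters, a bparam, a loss, ...)
# with the continuation-step counter at which it was produced. Its real definition
# lives elsewhere in the repository; the sketch below is only an assumption of the
# minimal interface these constructors rely on (a `.state` attribute plus the counter),
# not the project's actual implementation.
class StateVariable:
    """Minimal sketch: a value tagged with the continuation-step counter."""

    def __init__(self, state, counter):
        self.state = state      # the wrapped value (pytree, scalar, ...)
        self.counter = counter  # continuation step at which it was recorded

    def __repr__(self):
        return f"StateVariable(counter={self.counter}, state={self.state!r})"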
def __init__(
    self,
    objective,
    dual_objective,
    value_fn,
    concat_states,
    key_state,
    compute_min_grad_fn,
    compute_grad_fn,
    hparams,
    pred_state,
    pred_prev_state,
    counter,
):
    self.concat_states = concat_states
    self._state = None
    self._bparam = None
    self.opt = OptimizerCreator(
        opt_string=hparams["meta"]["optimizer"], learning_rate=hparams["descent_lr"]
    ).get_optimizer()
    self.objective = objective
    self.dual_objective = dual_objective
    self._lagrange_multiplier = hparams["lagrange_init"]
    self._state_secant_vector = None
    self._state_secant_c2 = None
    self.delta_s = hparams["delta_s"]
    self.descent_period = hparams["descent_period"]
    self.max_norm_state = hparams["max_bounds"]
    self.hparams = hparams
    self.compute_min_grad_fn = compute_min_grad_fn
    self.compute_grad_fn = compute_grad_fn
    self._assign_states()
    self._parc_vec = None
    self.state_stack = dict()
    self.key_state = key_state
    self.pred_state = pred_state
    self.pred_prev_state = pred_prev_state
    self.sphere_radius = hparams["sphere_radius"]
    self.counter = counter
    self.value_fn = value_fn
    # self.data_loader = iter(get_data(dataset=hparams["meta"]['dataset'],
    #                                  batch_size=hparams['batch_size'],
    #                                  num_workers=hparams['data_workers'],
    #                                  train_only=True, test_only=False))
    if hparams["meta"]["dataset"] == "mnist":
        self.data_loader = iter(
            get_mnist_data(
                batch_size=hparams["batch_size"], resize=hparams["resize_to_small"]
            )
        )
        self.num_batches = meta_mnist(hparams["batch_size"])["num_batches"]
    else:
        self.data_loader = None
        self.num_batches = 1
def __init__(self, objective, concat_states, grad_fn, value_fn, accuracy_fn,
             hparams, dataset_tuple):
    self.concat_states = concat_states
    self._state = None
    self._bparam = None
    self.opt = OptimizerCreator(
        opt_string=hparams["meta"]["optimizer"],
        learning_rate=hparams["natural_lr"]).get_optimizer()
    self.objective = objective
    self.accuracy_fn = accuracy_fn
    self.warmup_period = hparams["warmup_period"]
    self.hparams = hparams
    self.grad_fn = grad_fn
    self.value_fn = value_fn
    self._assign_states()
    if hparams["meta"]["dataset"] == "mnist":
        (self.train_images, self.train_labels,
         self.test_images, self.test_labels) = dataset_tuple
        if hparams["continuation_config"] == 'data':
            # data continuation
            self.data_loader = iter(
                get_mnist_batch_alter(self.train_images,
                                      self.train_labels,
                                      self.test_images,
                                      self.test_labels,
                                      alter=self._bparam,
                                      batch_size=hparams["batch_size"],
                                      resize=hparams["resize_to_small"],
                                      filter=hparams["filter"]))
        else:
            # model continuation
            self.data_loader = iter(
                get_mnist_data(batch_size=hparams["batch_size"],
                               resize=hparams["resize_to_small"],
                               filter=hparams["filter"])
                # get_preload_mnist_data(self.train_images,  # TODO: better way to prefetch mnist
                #                        self.train_labels,
                #                        self.test_images,
                #                        self.test_labels,
                #                        batch_size=hparams["batch_size"],
                #                        resize=hparams["resize_to_small"],
                #                        filter=hparams["filter"])
            )
        self.num_batches = meta_mnist(hparams["batch_size"],
                                      hparams["filter"])["num_batches"]
    else:
        self.data_loader = None
        self.num_batches = 1
def __init__(self, state, bparam, state_0, bparam_0, counter, objective,
             accuracy_fn, hparams):
    self._state_wrap = StateVariable(state, counter)
    self._bparam_wrap = StateVariable(bparam, counter)
    self._prev_state = state_0
    self._prev_bparam = bparam_0
    self.objective = objective
    self.accuracy_fn = accuracy_fn
    self.value_func = jit(self.objective)
    self._value_wrap = StateVariable(0.005, counter)
    self._quality_wrap = StateVariable(0.005, counter)
    self.sw = None
    self.hparams = hparams
    self.opt = OptimizerCreator(
        opt_string=hparams["meta"]["optimizer"],
        learning_rate=hparams["natural_lr"]).get_optimizer()
    if hparams["meta"]["dataset"] == "mnist":
        if hparams["continuation_config"] == 'data':
            self.dataset_tuple = mnist_gamma(
                resize=hparams["resize_to_small"],
                filter=hparams["filter"])
        else:
            self.dataset_tuple = mnist(resize=hparams["resize_to_small"],
                                       filter=hparams["filter"])
    self.continuation_steps = hparams["continuation_steps"]
    self.output_file = hparams["meta"]["output_dir"]
    self._delta_s = hparams["delta_s"]
    self._prev_delta_s = hparams["delta_s"]
    self._omega = hparams["omega"]
    self.grad_fn = jit(grad(self.objective, argnums=[0]))
    self.prev_secant_direction = None
hparams["meta"]["output_dir"] = artifact_uri2 file_name = f"{artifact_uri2}/version.jsonl" sw = StateWriter(file_name=file_name) data_loader = iter( get_mnist_data(batch_size=hparams["batch_size"], resize=True, filter=hparams['filter'])) num_batches = meta_mnist(batch_size=hparams["batch_size"], filter=hparams['filter'])["num_batches"] print(f"num of bathces: {num_batches}") compute_grad_fn = jit(grad(problem.objective, [0])) opt = OptimizerCreator( hparams["meta"]["optimizer"], learning_rate=hparams["natural_lr"]).get_optimizer() ma_loss = [] for epoch in range(hparams["warmup_period"]): for b_j in range(num_batches): batch = next(data_loader) ae_grads = compute_grad_fn(ae_params, batch) ae_params = opt.update_params(ae_params, ae_grads[0], step_index=epoch) loss = problem.objective(ae_params, batch) ma_loss.append(loss) print(f"loss:{loss} norm:{l2_norm(ae_grads)}") #opt.lr = exp_decay(epoch, hparams["natural_lr"]) mlflow.log_metrics( {
hparams["meta"]["output_dir"] = artifact_uri2 file_name = f"{artifact_uri2}/version.jsonl" sw = StateWriter(file_name=file_name) data_loader = iter( get_mnist_data(batch_size=hparams["batch_size"], resize=True, filter=hparams['filter'])) num_batches = meta_mnist(batch_size=hparams["batch_size"], filter=hparams['filter'])["num_batches"] print(f"num of bathces: {num_batches}") compute_grad_fn = jit(grad(problem.objective, [0, 1])) opt = OptimizerCreator( hparams["meta"]["optimizer"], learning_rate=hparams["natural_lr"]).get_optimizer() ma_loss = [] for epoch in range(hparams["warmup_period"]): for b_j in range(num_batches): batch = next(data_loader) ae_grads, b_grads = compute_grad_fn(ae_params, bparam, batch) grads = ae_grads ae_params = opt.update_params(ae_params, ae_grads, step_index=epoch) bparam = opt.update_params(bparam, b_grads, step_index=epoch) loss = problem.objective(ae_params, bparam, batch) ma_loss.append(loss) print(f"loss:{loss} norm:{l2_norm(grads)}") opt.lr = exp_decay(epoch, hparams["natural_lr"])
class PerturbedFixedCorrecter(Corrector):
    """Minimize the objective using a gradient-based method along with some constraint and noise."""

    def __init__(
        self,
        objective,
        dual_objective,
        value_fn,
        concat_states,
        key_state,
        compute_min_grad_fn,
        compute_grad_fn,
        hparams,
        delta_s,
        pred_state,
        pred_prev_state,
        counter,
    ):
        self.concat_states = concat_states
        self._state = None
        self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["descent_lr"]).get_optimizer()
        self.objective = objective
        self.dual_objective = dual_objective
        self._lagrange_multiplier = hparams["lagrange_init"]
        self._state_secant_vector = None
        self._state_secant_c2 = None
        self.delta_s = delta_s
        self.descent_period = hparams["descent_period"]
        self.max_norm_state = hparams["max_bounds"]
        self.hparams = hparams
        self.compute_min_grad_fn = compute_min_grad_fn
        self.compute_grad_fn = compute_grad_fn
        self._assign_states()
        self._parc_vec = None
        self.state_stack = dict()
        self.key_state = key_state
        self.pred_state = pred_state
        self.pred_prev_state = pred_prev_state
        self.sphere_radius = hparams["sphere_radius"]
        self.counter = counter
        self.value_fn = value_fn
        if hparams["meta"]["dataset"] == "mnist":
            self.data_loader = iter(
                get_mnist_data(batch_size=hparams["batch_size"],
                               resize=hparams["resize_to_small"],
                               filter=hparams["filter"]))
            self.num_batches = meta_mnist(hparams["batch_size"],
                                          hparams["filter"])["num_batches"]
        else:
            self.num_batches = 1

    def _assign_states(self):
        self._state = self.concat_states[0]
        self._bparam = self.concat_states[1]
        self._state_secant_vector = self.concat_states[2]
        self._state_secant_c2 = self.concat_states[3]

    @staticmethod
    @jit
    def exp_decay(epoch, initial_lrate):
        k = 0.02
        lrate = initial_lrate * np.exp(-k * epoch)
        return lrate

    @staticmethod
    @jit
    def _perform_perturb_by_projection(
        _state_secant_vector,
        _state_secant_c2,
        key,
        pred_prev_state,
        _state,
        _bparam,
        counter,
        sphere_radius,
        batch_data,
    ):
        ### Secant normal
        n, sample_unravel = pytree_to_vec(
            [_state_secant_vector["state"], _state_secant_vector["bparam"]])
        n = pytree_normalized(n)
        ### sample a random point in R^n
        # u = tree_map(
        #     lambda a: a + random.uniform(key, a.shape),
        #     pytree_zeros_like(n),
        # )
        u = tree_map(
            lambda a: a + random.normal(key, a.shape),
            pytree_ones_like(n),
        )
        tmp, _ = pytree_to_vec(
            [_state_secant_c2["state"], _state_secant_c2["bparam"]])
        # select a point on the secant normal
        u_0, _ = pytree_to_vec(pred_prev_state)
        # compute projection
        proj_of_u_on_n = projection_affine(len(n), u, n, u_0)

        point_on_plane = u + pytree_sub(tmp, proj_of_u_on_n)  ## state = pred_state + n
        # noise = random.uniform(key, [1], minval=-0.003, maxval=0.03)
        inv_vec = np.array([-1.0, 1.0])
        parc = pytree_element_mul(
            pytree_normalized(pytree_sub(point_on_plane, tmp)),
            inv_vec[(counter % 2)],
        )
        point_on_plane_2 = tmp + sphere_radius * parc
        new_sample = sample_unravel(point_on_plane_2)
        state_stack = {}
        state_stack.update({"state": new_sample[0]})
        state_stack.update({"bparam": new_sample[1]})
        _parc_vec = pytree_sub(state_stack, _state_secant_c2)
        return _parc_vec, state_stack

    def _evaluate_perturb(self):
        """Evaluate whether the perturbed vector is orthogonal to the secant vector."""
        dot = pytree_dot(
            pytree_normalized(self._parc_vec),
            pytree_normalized(self._state_secant_vector),
        )
        if math.isclose(dot, 0.0, abs_tol=0.15):
            print(f"Perturb was near arc-plane. {dot}")
        else:
            print(f"Perturb was not on arc-plane. {dot}")

    def correction_step(self) -> Tuple:
        """Given the current state, optimize to the correct state.

        Returns:
            (state: problem parameters, bparam: continuation parameter) Tuple
        """
        quality = 1.0
        if self.hparams["meta"]["dataset"] == "mnist":  # TODO: make it generic
            batch_data = next(self.data_loader)
        else:
            batch_data = None
        ants_norm_grads = [5.0 for _ in range(self.hparams["n_wall_ants"])]
        ants_loss_values = [5.0 for _ in range(self.hparams["n_wall_ants"])]
        ants_state = [self._state for _ in range(self.hparams["n_wall_ants"])]
        ants_bparam = [self._bparam for _ in range(self.hparams["n_wall_ants"])]
        for i_n in range(self.hparams["n_wall_ants"]):
            corrector_omega = 1.0
            stop = False
            _, key = random.split(
                random.PRNGKey(self.key_state + i_n +
                               npr.randint(1, (i_n + 1) * 10)))
            del _
            self._parc_vec, self.state_stack = self._perform_perturb_by_projection(
                self._state_secant_vector,
                self._state_secant_c2,
                key,
                self.pred_prev_state,
                self._state,
                self._bparam,
                i_n,
                self.sphere_radius,
                batch_data,
            )
            if self.hparams["_evaluate_perturb"]:
                self._evaluate_perturb()  # does every time

            ants_state[i_n] = self.state_stack["state"]
            ants_bparam[i_n] = self.state_stack["bparam"]
            D_values = []
            print("num_batches", self.num_batches)
            for j_epoch in range(self.descent_period):
                for b_j in range(self.num_batches):
                    # alternate
                    # grads = self.compute_grad_fn(self._state, self._bparam, batch_data)
                    # self._state = self.opt.update_params(self._state, grads[0])
                    state_grads, bparam_grads = self.compute_min_grad_fn(
                        ants_state[i_n],
                        ants_bparam[i_n],
                        self._lagrange_multiplier,
                        self._state_secant_c2,
                        self._state_secant_vector,
                        batch_data,
                        self.delta_s,
                    )
                    if self.hparams["adaptive"]:
                        self.opt.lr = self.exp_decay(j_epoch,
                                                     self.hparams["natural_lr"])
                    quality = l2_norm(state_grads)  # l2_norm(bparam_grads)
                    if self.hparams["local_test_measure"] == "norm_gradients":
                        if quality > self.hparams["quality_thresh"]:
                            pass
                            print(
                                f"quality {quality}, {self.opt.lr}, {bparam_grads} ,{j_epoch}"
                            )
                        else:
                            stop = True
                            print(
                                f"quality {quality} stopping at , {j_epoch}th step"
                            )
                    else:
                        print(f"quality {quality}, {bparam_grads} ,{j_epoch}")
                        if len(D_values) >= 20:
                            tmp_means = running_mean(D_values, 10)
                            if math.isclose(
                                    tmp_means[-1],
                                    tmp_means[-2],
                                    abs_tol=self.hparams["loss_tol"]):
                                print(
                                    f"stopping at , {j_epoch}th step, {ants_bparam[i_n]} bparam"
                                )
                                stop = True

                    state_grads = clip_grads(state_grads,
                                             self.hparams["max_clip_grad"])
                    bparam_grads = clip_grads(bparam_grads,
                                              self.hparams["max_clip_grad"])

                    if self.hparams["guess_ant_steps"] >= (
                            j_epoch + 1):  # To get around folds slowly
                        corrector_omega = min(
                            self.hparams["guess_ant_steps"] / (j_epoch + 1), 1.5)
                    else:
                        corrector_omega = max(
                            self.hparams["guess_ant_steps"] / (j_epoch + 1), 0.05)

                    ants_state[i_n] = self.opt.update_params(
                        ants_state[i_n], state_grads, j_epoch)
                    ants_bparam[i_n] = self.opt.update_params(
                        ants_bparam[i_n], bparam_grads, j_epoch)
                    ants_loss_values[i_n] = self.value_fn(
                        ants_state[i_n], ants_bparam[i_n], batch_data)
                    D_values.append(ants_loss_values[i_n])
                    ants_norm_grads[i_n] = quality
                    # if stop:
                    #     break
                    if (self.hparams["meta"]["dataset"] == "mnist"
                            ):  # TODO: make it generic
                        batch_data = next(self.data_loader)
                if stop:
                    break

        # ants_group = dict(enumerate(grouper(ants_state, tolerance), 1))
        # print(f"Number of groups: {len(ants_group)}")
        cheapest_index = get_cheapest_ant(
            ants_norm_grads,
            ants_loss_values,
            local_test=self.hparams["local_test_measure"])
        self._state = ants_state[cheapest_index]
        self._bparam = ants_bparam[cheapest_index]
        value = self.value_fn(self._state, self._bparam,
                              batch_data)  # Todo: why only final batch data
        _, _, test_images, test_labels = mnist(
            permute_train=False, resize=True, filter=self.hparams["filter"])
        del _
        val_loss = self.value_fn(self._state, self._bparam,
                                 (test_images, test_labels))
        print(f"val loss: {val_loss}")
        return self._state, self._bparam, quality, value, val_loss, corrector_omega
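# `get_cheapest_ant` above picks which wall-ant to keep. Its definition is not part
# of this excerpt; the sketch below is an assumption consistent with how it is
# called (gradient norms, loss values, and the `local_test` measure), not the
# repository's actual implementation.
def get_cheapest_ant(norm_grads, loss_values, local_test="norm_gradients"):
    """Return the index of the ant with the smallest gradient norm or loss."""
    scores = norm_grads if local_test == "norm_gradients" else loss_values
    return min(range(len(scores)), key=lambda i: float(scores[i]))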
class PerturbedFixedCorrecter(Corrector):
    """Minimize the objective using a gradient-based method along with some constraint and noise."""

    def __init__(
        self,
        objective,
        dual_objective,
        value_fn,
        concat_states,
        key_state,
        compute_min_grad_fn,
        compute_grad_fn,
        hparams,
        pred_state,
        pred_prev_state,
        counter,
    ):
        self.concat_states = concat_states
        self._state = None
        self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"], learning_rate=hparams["descent_lr"]
        ).get_optimizer()
        self.objective = objective
        self.dual_objective = dual_objective
        self._lagrange_multiplier = hparams["lagrange_init"]
        self._state_secant_vector = None
        self._state_secant_c2 = None
        self.delta_s = hparams["delta_s"]
        self.descent_period = hparams["descent_period"]
        self.max_norm_state = hparams["max_bounds"]
        self.hparams = hparams
        self.compute_min_grad_fn = compute_min_grad_fn
        self.compute_grad_fn = compute_grad_fn
        self._assign_states()
        self._parc_vec = None
        self.state_stack = dict()
        self.key_state = key_state
        self.pred_state = pred_state
        self.pred_prev_state = pred_prev_state
        self.sphere_radius = hparams["sphere_radius"]
        self.counter = counter
        self.value_fn = value_fn
        # self.data_loader = iter(get_data(dataset=hparams["meta"]['dataset'],
        #                                  batch_size=hparams['batch_size'],
        #                                  num_workers=hparams['data_workers'],
        #                                  train_only=True, test_only=False))
        if hparams["meta"]["dataset"] == "mnist":
            self.data_loader = iter(
                get_mnist_data(
                    batch_size=hparams["batch_size"], resize=hparams["resize_to_small"]
                )
            )
            self.num_batches = meta_mnist(hparams["batch_size"])["num_batches"]
        else:
            self.data_loader = None
            self.num_batches = 1

    def _assign_states(self):
        self._state = self.concat_states[0]
        self._bparam = self.concat_states[1]
        self._state_secant_vector = self.concat_states[2]
        self._state_secant_c2 = self.concat_states[3]

    @staticmethod
    @jit
    def exp_decay(epoch, initial_lrate):
        k = 0.02
        lrate = initial_lrate * np.exp(-k * epoch)
        return lrate

    @staticmethod
    def _perform_perturb_by_projection(
        _state_secant_vector,
        _state_secant_c2,
        key,
        pred_prev_state,
        _state,
        _bparam,
        sphere_radius,
    ):
        ### Secant normal
        n, sample_unravel = pytree_to_vec(
            [_state_secant_vector["state"], _state_secant_vector["bparam"]]
        )
        n = pytree_normalized(n)
        ### sample a random point in R^n
        # u = tree_map(
        #     lambda a: a + random.uniform(key, a.shape),
        #     pytree_zeros_like(n),
        # )
        print(key)
        u = tree_map(
            lambda a: a + random.normal(key, a.shape),
            pytree_ones_like(n),
        )
        tmp, _ = pytree_to_vec([_state_secant_c2["state"], _state_secant_c2["bparam"]])
        # select a point on the secant normal
        u_0, _ = pytree_to_vec(pred_prev_state)
        # compute projection
        proj_of_u_on_n = projection_affine(len(n), u, n, u_0)

        point_on_plane = u + pytree_sub(tmp, proj_of_u_on_n)  ## state = pred_state + n
        # inv_vec = np.array([-1.0, 1.0])
        parc = pytree_element_mul(
            pytree_normalized(pytree_sub(point_on_plane, tmp)),
            1.0,  # inv_vec[(counter % 2)],
        )
        point_on_plane_2 = tmp + sphere_radius * parc
        print("point on plane ", point_on_plane_2)
        new_sample = sample_unravel(point_on_plane_2)
        state_stack = {}
        state_stack.update({"state": new_sample[0]})
        state_stack.update({"bparam": new_sample[1]})
        _parc_vec = pytree_sub(state_stack, _state_secant_c2)
        return _parc_vec, state_stack

    def _evaluate_perturb(self):
        """Evaluate whether the perturbed vector is orthogonal to the secant vector."""
        dot = pytree_dot(
            pytree_normalized(self._parc_vec),
            pytree_normalized(self._state_secant_vector),
        )
        if math.isclose(dot, 0.0, abs_tol=0.25):
            print(f"Perturb was near arc-plane. {dot}")
            self._state = self.state_stack["state"]
            self._bparam = self.state_stack["bparam"]
        else:
            print(f"Perturb was not on arc-plane. {dot}")

    def correction_step(self) -> Tuple:
        """Given the current state, optimize to the correct state.

        Returns:
            (state: problem parameters, bparam: continuation parameter) Tuple
        """
        _, key = random.split(random.PRNGKey(self.key_state + npr.randint(1, 100)))
        del _
        quality = 1.0
        N_opt = 10
        stop = False
        corrector_omega = 1.0
        # bparam_grads = pytree_zeros_like(self._bparam)
        print("the radius", self.sphere_radius)
        self._parc_vec, self.state_stack = self._perform_perturb_by_projection(
            self._state_secant_vector,
            self._state_secant_c2,
            key,
            self.pred_prev_state,
            self._state,
            self._bparam,
            self.sphere_radius,
        )
        if self.hparams["_evaluate_perturb"]:
            self._evaluate_perturb()  # does every time

        for j in range(self.descent_period):
            for b_j in range(self.num_batches):
                if self.hparams["meta"]["dataset"] == "mnist":  # TODO: make it generic
                    batch_data = next(self.data_loader)
                else:
                    batch_data = None
                # grads = self.compute_grad_fn(self._state, self._bparam, batch_data)
                # self._state = self.opt.update_params(self._state, grads[0])
                state_grads, bparam_grads = self.compute_min_grad_fn(
                    self._state,
                    self._bparam,
                    self._lagrange_multiplier,
                    self._state_secant_c2,
                    self._state_secant_vector,
                    batch_data,
                    self.delta_s,
                )
                if self.hparams["adaptive"]:
                    self.opt.lr = self.exp_decay(j, self.hparams["natural_lr"])
                quality = l2_norm(state_grads)  # + l2_norm(bparam_grads)
                if quality > self.hparams["quality_thresh"]:
                    pass
                    # print(f"quality {quality}, {self.opt.lr}, {bparam_grads} ,{j}")
                else:
                    if N_opt > (j + 1):  # To get around folds slowly
                        corrector_omega = min(N_opt / (j + 1), 2.0)
                    else:
                        corrector_omega = max(N_opt / (j + 1), 0.5)
                    stop = True
                    print(f"quality {quality} stopping at , {j}th step")
                state_grads = clip_grads(state_grads, self.hparams["max_clip_grad"])
                bparam_grads = clip_grads(
                    bparam_grads, self.hparams["max_clip_grad"]
                )
                self._bparam = self.opt.update_params(self._bparam, bparam_grads, j)
                self._state = self.opt.update_params(self._state, state_grads, j)
                if stop:
                    break
            if stop:
                break

        value = self.value_fn(
            self._state, self._bparam, batch_data
        )  # Todo: why only final batch data
        return self._state, self._bparam, quality, value, corrector_omega
class UnconstrainedCorrector(Corrector):
    """Minimize the objective using a gradient-based method."""

    def __init__(self, objective, concat_states, grad_fn, value_fn, accuracy_fn,
                 hparams, dataset_tuple):
        self.concat_states = concat_states
        self._state = None
        self._bparam = None
        self.opt = OptimizerCreator(
            opt_string=hparams["meta"]["optimizer"],
            learning_rate=hparams["natural_lr"]).get_optimizer()
        self.objective = objective
        self.accuracy_fn = accuracy_fn
        self.warmup_period = hparams["warmup_period"]
        self.hparams = hparams
        self.grad_fn = grad_fn
        self.value_fn = value_fn
        self._assign_states()
        if hparams["meta"]["dataset"] == "mnist":
            (self.train_images, self.train_labels,
             self.test_images, self.test_labels) = dataset_tuple
            if hparams["continuation_config"] == 'data':
                # data continuation
                self.data_loader = iter(
                    get_mnist_batch_alter(self.train_images,
                                          self.train_labels,
                                          self.test_images,
                                          self.test_labels,
                                          alter=self._bparam,
                                          batch_size=hparams["batch_size"],
                                          resize=hparams["resize_to_small"],
                                          filter=hparams["filter"]))
            else:
                # model continuation
                self.data_loader = iter(
                    get_mnist_data(batch_size=hparams["batch_size"],
                                   resize=hparams["resize_to_small"],
                                   filter=hparams["filter"])
                    # get_preload_mnist_data(self.train_images,  # TODO: better way to prefetch mnist
                    #                        self.train_labels,
                    #                        self.test_images,
                    #                        self.test_labels,
                    #                        batch_size=hparams["batch_size"],
                    #                        resize=hparams["resize_to_small"],
                    #                        filter=hparams["filter"])
                )
            self.num_batches = meta_mnist(hparams["batch_size"],
                                          hparams["filter"])["num_batches"]
        else:
            self.data_loader = None
            self.num_batches = 1

    def _assign_states(self):
        self._state, self._bparam = self.concat_states

    def correction_step(self) -> Tuple:
        """Given the current state optimize to the correct state.

        Returns:
            (state: problem parameters, bparam: continuation parameter) Tuple
        """
        quality = 1.0
        ma_loss = []
        stop = False
        print("learn_rate", self.opt.lr)
        for k in range(self.warmup_period):
            for b_j in range(self.num_batches):
                batch = next(self.data_loader)
                grads = self.grad_fn(self._state, self._bparam, batch)
                self._state = self.opt.update_params(self._state, grads[0])
                quality = l2_norm(grads)
                value = self.value_fn(self._state, self._bparam, batch)
                ma_loss.append(value)
                self.opt.lr = exp_decay(k, self.hparams["natural_lr"])
                if self.hparams["local_test_measure"] == "norm_gradients":
                    if quality > self.hparams["quality_thresh"]:
                        pass
                        print(f"quality {quality}, {self.opt.lr} ,{k}")
                    else:
                        stop = True
                        print(f"quality {quality} stopping at , {k}th step")
                else:
                    if len(ma_loss) >= 20:
                        tmp_means = running_mean(ma_loss, 10)
                        if math.isclose(
                                tmp_means[-1],
                                tmp_means[-2],
                                abs_tol=self.hparams["loss_tol"],
                        ):
                            print(f"stopping at , {k}th step")
                            stop = True
            if stop:
                print("breaking")
                break

        val_loss = self.value_fn(self._state, self._bparam,
                                 (self.test_images, self.test_labels))
        val_acc = self.accuracy_fn(self._state, self._bparam,
                                   (self.test_images, self.test_labels))
        return self._state, self._bparam, quality, value, val_loss, val_acc
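# `running_mean` is used in the correctors above to smooth the recent loss history
# before the convergence check (the last two smoothed values are compared against
# `loss_tol`). Its definition is not part of this excerpt; the sketch below is an
# assumption of a simple trailing moving average with the same call signature.
import numpy as onp  # plain NumPy suffices for this host-side helper


def running_mean(values, window):
    """Trailing moving average of `values` over the given window length."""
    x = onp.asarray(values, dtype=float)
    kernel = onp.ones(window) / window
    return onp.convolve(x, kernel, mode="valid")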