def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True):
  data = data_cls(training=should_train)
  model = C(model_cls())
  optimizer_states = C(
      OptimizerStates.initial_zeros(
          (self.n_layers, len(model.params), self.hidden_sz)))

  if should_train:
    self.train()
  else:
    self.eval()

  result_dict = ResultDict()
  unroll_losses = 0
  update = C(torch.zeros(model.params.size().flat()))

  iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
  iter_watch = utils.StopWatch('optim_iteration')

  batch_arg = BatchManagerArgument(
      params=model.params,
      states=StatesSlicingWrapper(optimizer_states),
      updates=update,
  )

  for iteration in iter_pbar:
    data_ = data.sample()
    loss = model(*data_)
    unroll_losses += loss

    model_dt = C(model_cls(params=batch_arg.params.detach()))
    model_dt_loss = model_dt(*data_)
    assert loss == model_dt_loss
    model_dt_loss.backward()
    assert model_dt.params.flat.grad is not None
    grad = model_dt.params.flat.grad
    batch_arg.update(
        BatchManagerArgument(grad=grad).immutable().volatile())
    iter_pbar.set_description(f'optim_iteration[loss:{loss}]')
    ##########################################################################
    batch_manager = OptimizerBatchManager(batch_arg=batch_arg)
    for get, set in batch_manager.iterator():
      updates, new_states = self(get.grad.detach(), get.states)
      set.states(new_states)
      set.params(get.params + updates * out_mul)
      if not should_train:
        set.updates(updates)
    ##########################################################################
    batch_arg = batch_manager.batch_arg_to_set

    if should_train and iteration % unroll == 0:
      meta_optimizer.zero_grad()
      unroll_losses.backward()
      meta_optimizer.step()

    if not should_train or iteration % unroll == 0:
      unroll_losses = 0
      batch_arg.detach_()
      # model.params = model.params.detach()
      # optimizer_states = optimizer_states.detach()

    model = C(model_cls(params=batch_arg.params))
    result_dict.append(loss=loss)
    if not should_train:
      result_dict.append(
          walltime=iter_watch.touch(),
          **self.params_tracker(
              grad=grad,
              update=batch_arg.updates,
          ))
  return result_dict
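# All inner loops in this file read inputs and write outputs through the (get, set)
# pairs yielded by OptimizerBatchManager: `get` exposes per-coordinate slices of the
# registered tensors and `set` stages values that become `batch_arg_to_set` after the
# loop. The class below is a minimal, hypothetical sketch of that contract (chunked
# slicing plus staged writes), not the project's actual implementation.
from types import SimpleNamespace

import torch


class TinyBatchManager:
  """Iterate over coordinate chunks of named tensors and stage written values."""

  def __init__(self, batch_size=30000, **tensors):
    self._in = tensors                                   # name -> per-coordinate tensor
    self._out = {k: v.clone() for k, v in tensors.items()}
    self._batch_size = batch_size

  def _setter(self, name, sl):
    def write(value):
      self._out[name][sl] = value
    return write

  def iterator(self):
    n = next(iter(self._in.values())).size(0)
    for start in range(0, n, self._batch_size):
      sl = slice(start, min(start + self._batch_size, n))
      get = SimpleNamespace(**{k: v[sl] for k, v in self._in.items()})
      set_ = SimpleNamespace(**{k: self._setter(k, sl) for k in self._out})
      yield get, set_

  @property
  def batch_arg_to_set(self):
    return SimpleNamespace(**self._out)


# usage sketch:
#   bm = TinyBatchManager(grad=grad, params=params)
#   for get, set in bm.iterator():
#     set.params(get.params + get.grad * -0.1)
#   new_params = bm.batch_arg_to_set.params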
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll, out_mul, tf_writer=None, mode='train'):
  assert mode in ['train', 'valid', 'test']
  if mode == 'train':
    self.train()
  else:
    self.eval()

  result_dict = ResultDict()
  unroll_losses = 0
  walltime = 0
  params = C(model_cls()).params

  # lr_sgd = 0.2
  # lr_adam = 0.001
  # sgd = torch.optim.SGD(model.parameters(), lr=lr_sgd)
  # adam = torch.optim.Adam(model.parameters())
  # adagrad = torch.optim.Adagrad(model.parameters())

  iter_pbar = tqdm(range(1, optim_it + 1), 'inner_train')
  iter_watch = utils.StopWatch('inner_train')

  lr = 0.1
  topk = False
  n_sample = 10
  topk_decay = True
  recompute_grad = False
  self.topk_ratio = 0.5
  # self.topk_ratio = 1.0
  set_size = {'layer_0': 500, 'layer_1': 10}  # NOTE: do it smarter

  for iteration in iter_pbar:
    # # conventional optimizers with API
    # model.zero_grad()
    # loss = model(*inner_data.load())
    # loss.backward()
    # adam.step()
    # loss_detached = loss

    iter_watch.touch()
    model_train = C(model_cls(params=params.detach()))
    train_nll = model_train(*data['in_train'].load())
    train_nll.backward()

    params = model_train.params.detach()
    grad = model_train.params.grad.detach()
    act = model_train.activations.detach()

    # conventional optimizers with manual SGD
    params_good = params + grad * (-lr)
    grad_good = grad

    best_loss_sparse = 999999
    for i in range(n_sample):
      mask_gen = MaskGenerator.topk if topk else MaskGenerator.randk
      mask = mask_gen(grad=grad, set_size=set_size, topk=self.topk_ratio)

      if recompute_grad:
        # recompute gradients again using the pruned network
        params_sparse = params.prune(mask)
        model_sparse = C(model_cls(params=params_sparse.detach()))
        loss_sparse = model_sparse(*data['in_train'].load())
        loss_sparse.backward()
        grad_sparse = model_sparse.params.grad
      else:
        # use full gradients but sparsified
        params_sparse = params.prune(mask)
        grad_sparse = grad.prune(mask)

      params_sparse_ = params_sparse + grad_sparse * (-lr)  # SGD
      model_sparse = C(model_cls(params=params_sparse_.detach()))
      loss_sparse = model_sparse(*data['in_train'].load())

      if loss_sparse < best_loss_sparse:
        best_loss_sparse = loss_sparse
        best_params_sparse = params_sparse_
        best_grads = grad_sparse
        best_mask = mask

    # below lines are just for evaluation purposes (excluded from time cost)
    walltime += iter_watch.touch('interval')

    # decay topk_ratio
    if topk_decay:
      self.topk_ratio *= 0.999

    # update!
    params = params.sparse_update(best_mask, best_grads * (-lr))

    # generalization performance
    model_test = C(model_cls(params=params.detach()))
    test_nll = model_test(*data['in_test'].load())

    # result dict
    result = dict(
        train_nll=train_nll.tolist(),
        test_nll=test_nll.tolist(),
        sparse_0=self.topk_ratio,
        sparse_1=self.topk_ratio,
    )
    result_dict.append(result)
    log_pbar(result, iter_pbar)
    # desc = [f'{k}: {v:5.5}' for k, v in result.items()]
    # iter_pbar.set_description(f"inner_train [ {' / '.join(desc)} ]")

  return result_dict
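# The sampling loop above assumes MaskGenerator.topk / MaskGenerator.randk return a
# per-layer binary mask over units, keeping roughly a `topk` fraction of each layer
# either by gradient magnitude or at random. The two helpers below are a rough,
# hypothetical sketch of those semantics (names, shapes, and the saliency rule are
# assumptions, not the project's actual MaskGenerator).
import torch


def topk_mask(grad: dict, set_size: dict, topk: float) -> dict:
  """Keep the round(topk * set_size[layer]) units with largest |grad| per layer."""
  mask = {}
  for layer, g in grad.items():
    k = max(1, int(round(topk * set_size[layer])))
    score = g.abs().view(set_size[layer], -1).sum(dim=1)  # per-unit saliency
    idx = score.topk(k).indices
    m = torch.zeros(set_size[layer])
    m[idx] = 1.0
    mask[layer] = m
  return mask


def randk_mask(grad: dict, set_size: dict, topk: float) -> dict:
  """Keep the same number of units per layer, chosen uniformly at random."""
  mask = {}
  for layer, g in grad.items():
    k = max(1, int(round(topk * set_size[layer])))
    idx = torch.randperm(set_size[layer])[:k]
    m = torch.zeros(set_size[layer])
    m[idx] = 1.0
    mask[layer] = m
  return mask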
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True): if should_train: self.train() else: self.eval() unroll = 1 target = data_cls(training=should_train) optimizee = C(model_cls()) optimizer_states = OptimizerStates.zero_states( self.n_layers, len(optimizee.params), self.hidden_sz) result_dict = ResultDict() # n_params = 0 # named_shape = {} # for name, p in optimizee.all_named_parameters(): # n_params += int(np.prod(p.size())) # named_shape[name] = p.size() # batch_size = 30000 # batch_sizes = [] # if batch_size > n_params: # batch_sizes = [n_params] # else: # batch_sizes = [batch_size for _ in range(n_params // batch_size)] # batch_sizes.append(n_params % batch_size) coordinate_tracker = ParamsIndexTracker( n_params=len(optimizee.params), n_tracks=30) # update_track_num = 10 # n_params = len(optimizee.params) # if update_track_num > n_params: # update_track_num = n_params # update_track_id = np.random.randint(0, n_params, update_track_num) # updates_tracks = [] # hidden_states = [ # C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2)] # cell_states = [ # C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2)] all_losses_ever = [] timestamps = [] if should_train: meta_optimizer.zero_grad() unroll_losses = None ############################################################################ iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') for iteration in iter_pbar: loss = optimizee(target) iter_pbar.set_description(f'optim_iteration[loss:{loss}]') if unroll_losses is None: unroll_losses = loss else: unroll_losses += loss result_dict.add('losses', loss) try: loss.backward(retain_graph=should_train) except Exception as ex: print(ex) import pdb pdb.set_trace() # all_losses_ever.append(loss.data.cpu().numpy()) # hidden_states2 = [ # C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2)] # cell_states2 = [ # C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2)] ########################################################################## # params_flat = [] # grads_flat = [] # # for name, p in optimizee.all_named_parameters(): # named_shape[name] = p.size() # params_flat.append(p.view(-1, 1)) # grads_flat.append(C.detach(p.grad.view(-1, 1))) # params_flat = torch.cat(params_flat, 0) # grads_flat = torch.cat(grads_flat, 0) # # offset = 0 # params_batches = [] # grads_batches = [] # for i in range(len(batch_sizes)): # params_batches.append(params_flat[offset:offset + batch_sizes[i], :]) # grads_batches.append(grads_flat[offset:offset + batch_sizes[i], :]) # offset += batch_sizes[i] # assert(offset == params_flat.size(0)) # batches = zip(batch_sizes, params_batches, grads_batches) ########################################################################## offset = 0 result_params_flat = [] result_params = {} updates_track = [] batch_manager = OptimizerBatchManager( params=optimizee.params, states=optimizer_states, getters=['params', 'states'], setters=['params', 'states', 'updates'], batch_size=30000) # for batch_sz, params, grads in batches: for get, set in batch_manager.iterator(): # We do this so the gradients are disconnected from the graph # but we still get gradients from the rest updates, new_states = self( get.params.grad.detach(), get.states.clone()) set.states(new_states.clone()) set.params(get.params.clone() + updates * out_mul) set.updates(updates) # grads, # [h[offset:offset + batch_sz] for h in hidden_states], # [c[offset:offset + batch_sz] for c in 
cell_states] # ) # for i in range(len(new_hidden)): # hidden_states2[i][offset:offset + batch_sz] = new_hidden[i] # cell_states2[i][offset:offset + batch_sz] = new_cell[i] # result_params_flat.append(params + updates * out_mul) # updates_track.append(updates) # offset += batch_sz ########################################################################## # result_params_flat = torch.cat(result_params_flat, 0) # updates_track = torch.cat(updates_track, 0) # updates_track = updates_track.data.cpu().numpy() # updates_track = np.take(updates_track, update_track_id) # updates_tracks.append(updates_track) # offset = 0 # for name, shape in named_shape.items(): # n_unit_params = int(np.prod(shape)) # result_params[name] = result_params_flat[ # offset:offset + n_unit_params, :].view(shape) # result_params[name].retain_grad() # offset += n_unit_params ########################################################################## # NOTE: optimizee.states.states <-- looks a bit weird if iteration % unroll == 0: if should_train: meta_optimizer.zero_grad() unroll_losses.backward() meta_optimizer.step() unroll_losses = None optimizee.params.detach_() optimizer_states = optimizer_states.detach() # detached_params = {k: C.detach(v) for k, v in result_params.items()} # optimizee = C(model_cls(**detached_params)) # hidden_states = [C.detach(s) for s in hidden_states2] # cell_states = [C.detach(s) for s in cell_states2] # # else: # optimizee = C(model_cls(**result_params)) # assert len(list(optimizee.all_named_parameters())) # hidden_states = hidden_states2 # cell_states = cell_states2 optimizee = C(model_cls(optimizee.params)) # timestamps.append(iter_watch.touch()) result_dict.add('updates', coordinate_tracker(batch_manager.updates)) result_dict.add('timestamps', iter_watch.touch()) return all_losses_ever, updates_tracks, timestamps
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll, out_mul, mode):
  assert mode in ['train', 'valid', 'test']
  if mode == 'train':
    self.train()
    # data.new_train_data()
    inner_data = data.loaders['inner_train']
    outer_data = data.loaders['inner_valid']
    # inner_data = data.loaders['train']
    # outer_data = inner_data
    drop_mode = 'soft_drop'
  elif mode == 'valid':
    self.eval()
    inner_data = data.loaders['valid']
    outer_data = None
    drop_mode = 'hard_drop'
  elif mode == 'test':
    self.eval()
    inner_data = data.loaders['test']
    outer_data = None
    drop_mode = 'hard_drop'

  # data = dataset(mode=mode)
  model = C(model_cls(sb_mode=self.sb_mode))
  model(*data.pseudo_sample())
  # layer_sz = sum([v for v in model.params.size().unflat(1, 'mat').values()])
  mask_states = C(
      OptimizerStates.initial_zeros(
          (self.n_layers, len(model.activations.mean(0)), self.hidden_sz)))
  update_states = C(
      OptimizerStates.initial_zeros(
          (self.n_layers, len(model.params), self.hidden_sz)))

  result_dict = ResultDict()
  unroll_losses = 0
  walltime = 0
  update = C(torch.zeros(model.params.size().flat()))

  batch_arg = BatchManagerArgument(
      params=model.params,
      states=StatesSlicingWrapper(update_states),
      updates=model.params.new_zeros(model.params.size()),
  )
  params = model.params

  lr_sgd = 0.2
  lr_adam = 0.001
  sgd = torch.optim.SGD(model.parameters(), lr=lr_sgd)
  adam = torch.optim.Adam(model.parameters())
  adagrad = torch.optim.Adagrad(model.parameters())
  # print('###', lr_adam, '####')

  iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
  iter_watch = utils.StopWatch('optim_iteration')
  bp_watch = utils.StopWatch('bp')
  loss_decay = 1.0
  dist_list = []
  dense_rank = 0
  sparse_rank = 0
  dense_rank_list = []
  sparse_rank_list = []

  for iteration in iter_pbar:
    iter_watch.touch()
    # # conventional optimizers with API
    # model.zero_grad()
    # loss = model(*inner_data.load())
    # loss.backward()
    # adam.step()
    # loss_detached = loss
    model_detached = C(model_cls(params=params.detach()))
    loss_detached = model_detached(*inner_data.load())
    loss_detached.backward()
    # # conventional optimizers with manual SGD
    # params = model_detached.params + model_detached.params.grad * (-0.2)
    # iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]')

    masks = []
    # sparse_old_loss = []
    # sparse_new_loss = []
    sparse_loss = []
    dense_loss = []
    candidates = []

    # grad_mode selects how candidate gradients are computed
    # (renamed from `mode` to avoid shadowing the train/valid/test mode above)
    grad_mode = 2
    n_sample = 10
    if grad_mode == 1:
      lr = 1.0
    elif grad_mode == 2:
      lr = 0.2
    else:
      raise Exception()

    for i in range(n_sample):
      if i == 9:
        mask = MaskGenerator.topk(model_detached.activations.grad.detach())
      else:
        mask = MaskGenerator.randk(model_detached.activations.grad.detach())

      if grad_mode == 1:
        # recompute gradients again using the pruned network
        sparse_params = model_detached.params.prune(mask)
        model_pruned = C(model_cls(params=sparse_params.detach()))
        loss_pruned = model_pruned(*inner_data.load())
        loss_pruned.backward()
        grads = model_pruned.params.grad
      elif grad_mode == 2:
        # use full gradients but sparsified
        sparse_params = model_detached.params.prune(mask)
        grads = model_detached.params.grad.prune(mask)
      else:
        raise Exception('unknown grad_mode.')

      sparse_params = sparse_params + grads * (-lr)
      model_pruned = C(model_cls(params=sparse_params.detach()))
      loss_pruned = model_pruned(*inner_data.load())
      # params = model_detached.params
      params = model_detached.params.sparse_update(mask, grads * (-lr))
      candidates.append(params)
      model_dense = C(model_cls(params=params))
      loss_dense = model_dense(*inner_data.load())
      sparse_loss.append(loss_pruned)
      dense_loss.append(loss_dense)

    sparse_loss = torch.stack(sparse_loss, 0)
    dense_loss = torch.stack(dense_loss, 0)
    min_id = (sparse_loss.min() == sparse_loss).nonzero().squeeze().tolist()
    params = candidates[min_id]
    sparse_order = sparse_loss.sort()[1].float()
    dense_order = dense_loss.sort()[1].float()
    dist = (sparse_order - dense_order).abs().sum()
    dist_list.append(dist)
    dense_rank += (dense_order == 9).nonzero().squeeze().tolist()
    sparse_rank += (sparse_order == 9).nonzero().squeeze().tolist()
    dense_rank_list.append((dense_order == 9).nonzero().squeeze().tolist())
    sparse_rank_list.append((sparse_order == 9).nonzero().squeeze().tolist())

    iter_pbar.set_description(
        f'optim_iteration[dense_loss:{loss_dense.tolist():5.5}'
        f'/sparse_loss:{loss_pruned.tolist():5.5}]')
    # iter_pbar.set_description(f'optim_iteration[loss:{loss_dense.tolist()}/dist:{dist.tolist()}]')
    # torch.optim.SGD(sparse_params, lr=0.1).step()
    result_dict.append(loss=loss_detached)
    if not mode == 'train':
      result_dict.append(
          walltime=walltime,
          # **self.params_tracker(
          #   grad=grad,
          #   update=batch_arg.updates,
          # )
      )

  dist_mean = torch.stack(dist_list, 0).mean()
  # import pdb; pdb.set_trace()
  return result_dict
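# The bookkeeping at the end of the loop above measures how well the pruned-model loss
# ranks the candidates compared with the dense loss: `dist` is the element-wise L1
# distance between the two argsort permutations, and `*_rank` records at which position
# a given candidate (index 9, the top-k mask) appears in each ordering. A tiny worked
# example of the same operations:
import torch

sparse_loss = torch.tensor([0.9, 0.2, 0.5])   # losses of 3 candidates on pruned models
dense_loss = torch.tensor([0.8, 0.3, 0.1])    # losses of the same candidates, dense

sparse_order = sparse_loss.sort()[1].float()  # tensor([1., 2., 0.])
dense_order = dense_loss.sort()[1].float()    # tensor([2., 1., 0.])
dist = (sparse_order - dense_order).abs().sum()  # 2.0 -> the orderings disagree

# rank (position in the sorted order) of candidate 0 under each criterion
dense_rank_of_0 = (dense_order == 0).nonzero().squeeze().tolist()    # 2
sparse_rank_of_0 = (sparse_order == 0).nonzero().squeeze().tolist()  # 2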
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True): data = data_cls(training=should_train) model = C(model_cls()) g_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.params), self.hidden_sz))) u_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.params), self.hidden_sz))) if should_train: self.train() else: self.eval() grad_pred_tracker = ParamsIndexTracker(name='grad_pred', n_tracks=10, n_params=len(model.params)) grad_real_tracker = grad_pred_tracker.clone_new_name('grad_real') update_tracker = grad_pred_tracker.clone_new_name('update') result_dict = ResultDict() model_dt = None model_losses = 0 grad_losses = 0 train_with_lazy_tf = False eval_test_grad_loss = True grad_prev = C(torch.zeros(model.params.size().flat())) update_prev = C(torch.zeros(model.params.size().flat())) #grad_real_prev = None grad_real_cur = None grad_real_prev = C(torch.zeros(model.params.size().flat())) iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') """grad_real_prev: teacher-forcing current input of RNN. required during all of training steps + every 'k' test steps. grad_real_cur: MSE loss target for current output of RNN. required during all the training steps only. """ batch_arg = BatchManagerArgument( params=model.params, grad_prev=grad_prev, update_prev=update_prev, g_states=StatesSlicingWrapper(g_states), u_states=StatesSlicingWrapper(u_states), mse_loss=C(torch.zeros(1)), #updates=OptimizerBatchManager.placeholder(model.params.flat.size()), ) for iteration in iter_pbar: inp, out = data.sample() model_loss = model(inp, out) model_losses += model_loss # if grad_real_cur is not None: # if iteration % self.k == 0 # grad_cur_prev = grad_real_cur # back up previous grad if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss: # get grad every step while training # get grad one step earlier than teacher-forcing steps while testing model_dt = C(model_cls(params=model.params.detach())) model_loss_dt = model_dt(inp, out) # model.params.register_grad_cut_hook_() model_loss_dt.backward() # model.params.remove_grad_cut_hook_() grad_real_cur = model_dt.params.flat.grad assert grad_real_cur is not None del model_dt, model_loss_dt if should_train or eval_test_grad_loss: # mse loss target # grad_real_cur is not used as a mse target while testing assert grad_real_cur is not None batch_arg.update( BatchManagerArgument( grad_real_cur=grad_real_cur).immutable().volatile()) if iteration % self.k == 0 or (should_train and not train_with_lazy_tf): # teacher-forcing assert grad_real_prev is not None batch_arg.update( BatchManagerArgument( grad_real_prev=grad_real_prev).immutable().volatile()) grad_real_prev = None ######################################################################## batch_manager = OptimizerBatchManager(batch_arg=batch_arg) for get, set in batch_manager.iterator(): if iteration % self.k == 0 or (should_train and not train_with_lazy_tf): grad_prev = get.grad_real_prev.detach() # teacher-forcing else: grad_prev = get.grad_prev.detach() # free-running update_prev = get.update_prev.detach() grad_cur, update_cur, g_states, u_states = self( grad_prev, update_prev, get.g_states, get.u_states) set.grad_prev(grad_cur) set.update_prev(update_cur) set.g_states(g_states) set.u_states(u_states) set.params(get.params + update_cur * out_mul) if should_train or eval_test_grad_loss: set.mse_loss( self.mse_loss(grad_cur, get.grad_real_cur.detach())) 
########################################################################## batch_arg = batch_manager.batch_arg_to_set grad_loss = batch_arg.mse_loss * 100 grad_losses += grad_loss #to_float = lambda x: float(x.data.cpu().numpy()) iter_pbar.set_description( 'optim_iteration[loss(model):{:10.6f}/(grad):{:10.6f}]'.format( model_loss.tolist(), grad_loss.tolist()[0])) if should_train and iteration % unroll == 0: meta_optimizer.zero_grad() losses = model_losses + grad_losses losses.backward() meta_optimizer.step() if not should_train or iteration % unroll == 0: model_losses = 0 grad_losses = 0 batch_arg.detach_() # batch_arg.params.detach_() # batch_arg.g_states.detach_() # batch_arg.u_states.detach_() model = C(model_cls(params=batch_arg.params)) result_dict.append( step_num=iteration, loss=model_loss, grad_loss=grad_loss, ) if not should_train: result_dict.append( walltime=iter_watch.touch(), **grad_pred_tracker.track_num, **grad_pred_tracker(batch_arg.grad_prev), **grad_real_tracker(grad_real_cur), **update_tracker(batch_arg.update_prev), ) if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss: grad_real_prev = grad_real_cur grad_real_cur = None return result_dict
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True): data = data_cls(training=should_train) model = C(model_cls()) g_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.params), self.hidden_sz))) if should_train: self.train() else: self.eval() result_dict = ResultDict() model_dt = None grad_losses = 0 train_with_lazy_tf = True eval_test_grad_loss = True grad_prev = C(torch.zeros(model.params.size().flat())) update_prev = C(torch.zeros(model.params.size().flat())) #grad_real_prev = None grad_real_cur = None grad_real_prev = C(torch.zeros(model.params.size().flat())) iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') """grad_real_prev: teacher-forcing current input of RNN. required during all of training steps + every 'k' test steps. grad_real_cur: MSE loss target for current output of RNN. required during all the training steps only. """ batch_arg = BatchManagerArgument( params=model.params, grad_prev=grad_prev, update_prev=update_prev, g_states=StatesSlicingWrapper(g_states), mse_loss=C(torch.zeros(1)), sigma=C(torch.zeros(model.params.size().flat())), #updates=OptimizerBatchManager.placeholder(model.params.flat.size()), ) for iteration in iter_pbar: data_ = data.sample() # if grad_real_cur is not None: # if iteration % self.k == 0 # grad_cur_prev = grad_real_cur # back up previous grad if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss: # get grad every step while training # get grad one step earlier than teacher-forcing steps while testing model_dt = C(model_cls(params=model.params.detach())) model_loss_dt = model_dt(*data_) # model.params.register_grad_cut_hook_() model_loss_dt.backward() # model.params.remove_grad_cut_hook_() grad_real_cur = model_dt.params.flat.grad assert grad_real_cur is not None if should_train or eval_test_grad_loss: # mse loss target # grad_real_cur is not used as a mse target while testing assert grad_real_cur is not None batch_arg.update( BatchManagerArgument( grad_real_cur=grad_real_cur).immutable()) if iteration % self.k == 0 or (should_train and not train_with_lazy_tf): # teacher-forcing assert grad_real_prev is not None batch_arg.update( BatchManagerArgument( grad_real_prev=grad_real_prev).immutable().volatile()) grad_real_prev = None ######################################################################## batch_manager = OptimizerBatchManager(batch_arg=batch_arg) for get, set in batch_manager.iterator(): if iteration % self.k == 0 or (should_train and not train_with_lazy_tf): grad_prev = get.grad_real_prev # teacher-forcing else: grad_prev = get.grad_prev # free-running grad_cur, g_states = self(grad_prev.detach(), get.update_prev.detach(), get.params.detach(), get.g_states) set.grad_prev(grad_cur) set.g_states(g_states) set.params(get.params - 1.0 * grad_cur.detach()) set.sigma(self.sigma) if should_train or eval_test_grad_loss: set.mse_loss( self.mse_loss(grad_cur, get.grad_real_cur.detach())) ########################################################################## batch_arg = batch_manager.batch_arg_to_set grad_real = batch_manager.batch_arg_to_get.grad_real_cur.detach() cosim_loss = nn.functional.cosine_similarity(batch_arg.grad_prev, grad_real, dim=0) grad_loss = batch_arg.mse_loss * 100 cosim_loss = -torch.log((cosim_loss + 1) * 0.5) * 100 grad_losses += (grad_loss + cosim_loss) #to_float = lambda x: float(x.data.cpu().numpy()) sigma = batch_arg.sigma.detach().mean() iter_pbar.set_description( 
'optim_iteration[loss(model):{:10.6f}/(mse):{:10.6f}/(cosim):{:10.6f}/(sigma):{:10.6f}]' ''.format(model_loss_dt.tolist(), grad_loss.tolist()[0], cosim_loss.tolist()[0], sigma.tolist())) if should_train and iteration % unroll == 0: w_decay_loss = 0 for p in self.parameters(): w_decay_loss += torch.sum(p * p) w_decay_loss = (w_decay_loss * 0.5) * 1e-4 total_loss = grad_losses + w_decay_loss meta_optimizer.zero_grad() total_loss.backward() meta_optimizer.step() if not should_train or iteration % unroll == 0: model_losses = 0 grad_losses = 0 batch_arg.detach_() # batch_arg.params.detach_() # batch_arg.g_states.detach_() # batch_arg.u_states.detach_() result_dict.append( step_num=iteration, loss=model_loss_dt, grad_loss=grad_loss, ) if not should_train: result_dict.append(walltime=iter_watch.touch(), **self.params_tracker( grad_real=grad_real_cur, grad_pred=batch_arg.grad_prev, update=batch_arg.update_prev, sigma=batch_arg.sigma, )) if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss: grad_real_prev = grad_real_cur grad_real_cur = None return result_dict
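# The two functions above switch the RNN input between the true gradient ("teacher
# forcing") and the RNN's own previous prediction ("free running"), controlled by
# self.k and the lazy-teacher-forcing flag. The selection rule, isolated as a small
# helper (a hypothetical restatement of the branches above, not project code):
def choose_rnn_input(iteration, k, should_train, train_with_lazy_tf,
                     grad_real_prev, grad_pred_prev):
  """Return the gradient fed to the RNN at this step.

  Teacher-force on every k-th step, and on every step during training unless lazy
  teacher forcing is enabled; otherwise run free on the last prediction.
  """
  teacher_force = (iteration % k == 0) or (should_train and not train_with_lazy_tf)
  return grad_real_prev if teacher_force else grad_pred_prev


# e.g. with k=5 at evaluation time, the model only receives the true gradient every
# fifth step and must otherwise propagate its own prediction:
#   inp = choose_rnn_input(iteration, self.k, False, True, grad_real_prev, grad_prev)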
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll, out_mul, mode): assert mode in ['train', 'valid', 'test'] if mode == 'train': self.train() # data.new_train_data() inner_data = data.loaders['inner_train'] outer_data = data.loaders['inner_valid'] # inner_data = data.loaders['train'] # outer_data = inner_data drop_mode = 'soft_drop' elif mode == 'valid': self.eval() inner_data = data.loaders['valid'] outer_data = None drop_mode = 'hard_drop' elif mode == 'eval': self.eval() inner_data = data.loaders['test'] outer_data = None drop_mode = 'hard_drop' # data = dataset(mode=mode) model = C(model_cls(sb_mode=self.sb_mode)) model(*data.pseudo_sample()) # layer_sz = sum([v for v in model.params.size().unflat(1, 'mat').values()]) mask_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.activations.mean(0)), self.hidden_sz))) update_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.params), self.hidden_sz))) result_dict = ResultDict() unroll_losses = 0 walltime = 0 update = C(torch.zeros(model.params.size().flat())) batch_arg = BatchManagerArgument( params=model.params, states=StatesSlicingWrapper(update_states), updates=model.params.new_zeros(model.params.size()), ) iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') bp_watch = utils.StopWatch('bp') loss_decay = 1.0 for iteration in iter_pbar: iter_watch.touch() model_detached = C( model_cls(params=batch_arg.params.detach(), sb_mode=self.sb_mode)) loss_detached = model_detached(*inner_data.load()) loss_detached.backward() # import pdb; pdb.set_trace() iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]') # np.set_printoptions(suppress=True, precision=3, linewidth=200, edgeitems=20) if debug_sigint.signal_on: debug = iteration == 1 or iteration % 10 == 0 # debug = True else: debug = False backprop_mask, mask_states, sparsity_loss = self.mask_generator( activations=model_detached.activations.grad.detach(), debug=debug, states=mask_states) # if debug_sigstp.signal_on and (iteration == 1 or iteration % 10 == 0): # mask_list.append(backprop_mask) # import pdb; pdb.set_trace() # min = 0.001 # max = 0.01 unroll_losses += sparsity_loss * 0.01 * \ loss_decay # * ((max - min) * lambda_ + min) # import pdb; pdb.set_trace() # model_detached.apply_backprop_mask(backprop_mask.unflat, drop_mode) # loss_detached.backward() assert model_detached.params.grad is not None if mode == 'train': model = C( model_cls(params=batch_arg.params, sb_mode=self.sb_mode)) loss = model(*outer_data.load()) # assert loss == loss_detached if iteration == 1: initial_loss = loss loss_decay = 1.0 else: loss_decay = (loss / initial_loss).detach() unroll_losses += loss # model_detached.apply_backprop_mask(backprop_mask.unflat, drop_mode) #model_detached.params.apply_grad_mask(mask, 'soft_drop') coord_mask = backprop_mask.expand_as_params(model_detached.params) if drop_mode == 'soft_drop': # grad = model_detached.params.grad.mul(coord_mask).flat # model_detached.params = model_detached.params.mul(params_mask) index = None # updates, states = self.updater(grad, states, mask=coord_mask) # updates = updates * out_mul * coord_mask # import pdb; pdb.set_trace() elif drop_mode == 'hard_drop': index = (coord_mask.flat.squeeze() > 0.5).nonzero().squeeze() grad = model_detached.params.grad.flat batch_arg.update(BatchManagerArgument(grad=grad, mask=coord_mask)) ########################################################################## indexer = DefaultIndexer(input=grad, 
index=index, shuffle=False) batch_manager = OptimizerBatchManager(batch_arg=batch_arg, indexer=indexer) for get, set in batch_manager.iterator(): inp = get.grad().detach() #inp = [get.grad().detach(), get.tag().detach()] if drop_mode == 'soft_drop': updates, new_states = self.updater( inp, get.states()) # , mask=get.mask()) # if iteration == 100: # import pdb; pdb.set_trace() # aaa = get.states() *2 updates = updates * out_mul * get.mask() set.states(get.states() * (1 - get.mask()) + new_states * get.mask()) # set.states(new_states) elif drop_mode == 'hard_drop': updates, new_states = self.updater(inp, get.states()) updates = updates * out_mul set.states(new_states) else: raise RuntimeError('Unknown drop mode!') set.params(get.params() + updates) # if not mode != 'train': # set.updates(updates) ########################################################################## batch_arg = batch_manager.batch_arg_to_set # updates = model.params.new_from_flat(updates) if mode == 'train' and iteration % unroll == 0: meta_optimizer.zero_grad() unroll_losses.backward() nn.utils.clip_grad_norm_(self.parameters(), 5) meta_optimizer.step() unroll_losses = 0 # if iteration == 1 or iteration % 10 == 0: # import pdb; pdb.set_trace() if not mode == 'train' or iteration % unroll == 0: mask_states.detach_() batch_arg.detach_() # model.params = model.params.detach() # update_states = update_states.detach() walltime += iter_watch.touch('interval') result_dict.append(loss=loss_detached) if not mode == 'train': result_dict.append( walltime=walltime, # **self.params_tracker( # grad=grad, # update=batch_arg.updates, # ) ) return result_dict
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True): data = data_cls(training=should_train) model = C(model_cls()) optimizer_states = C(OptimizerStates.initial_zeros( size=(self.n_layers, len(model.params), self.hidden_sz), rnn_cell=self.rnn_cell)) if should_train: self.train() else: self.eval() result_dict = ResultDict() walltime = 0 unroll_losses = 0 update = C(torch.zeros(model.params.size().flat())) batch_arg = BatchManagerArgument( params=model.params, states=StatesSlicingWrapper(optimizer_states), updates=update, ) iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') for iteration in iter_pbar: iter_watch.touch() # entire_loop.touch() # interval.touch() data_ = data.sample() if should_train: model = C(model_cls(params=batch_arg.params)) loss = model(*data_) unroll_losses += loss # print("\n[1] : "+str(interval.touch('interval', False))) # model_ff.touch() model_detached = C(model_cls(params=batch_arg.params.detach())) loss_detached = model_detached(*data_) # model_ff.touch('interval', True) if should_train: assert loss == loss_detached # model_bp.touch() loss_detached.backward() # model_bp.touch('interval', True) # interval.touch() assert model_detached.params.flat.grad is not None grad = model_detached.params.flat.grad batch_arg.update(BatchManagerArgument(grad=grad).immutable().volatile()) iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]') # print("\n[2-1] : "+str(interval.touch('interval', False))) ########################################################################## # indexer = DefaultIndexer(input=grad, shuffle=False) indexer = TopKIndexer(input=grad, k_ratio=0.1, random_k_mode=False)# bound=model_dt.params.size().bound_layerwise()) #indexer = DefaultIndexer(input=grad) batch_manager = OptimizerBatchManager( batch_arg=batch_arg, indexer=indexer) # print("\n[2-2] : "+str(interval.touch('interval', False))) # optim_ff.touch() # optim_ff_1.touch() for get, set in batch_manager.iterator(): # optim_ff_1.touch('interval', True) # optim_ff_2.touch() # optim_ff_2.touch('interval', True) # optim_ff_3.touch() updates, new_states = self(get.grad().detach(), get.states()) # optim_ff_3.touch('interval', True) # optim_ff_4.touch() updates = updates * out_mul set.states(new_states) set.params(get.params() + updates) if not should_train: set.updates(updates) # optim_ff_4.touch('interval', True) # optim_ff.touch('interval', True) ########################################################################## # interval.touch() batch_arg = batch_manager.batch_arg_to_set # print("\n[3] : "+str(interval.touch('interval', False))) if should_train and iteration % unroll == 0: # optim_bp.touch() meta_optimizer.zero_grad() unroll_losses.backward() nn.utils.clip_grad_norm_(self.parameters(), 10) meta_optimizer.step() unroll_losses = 0 # optim_bp.touch('interval', True) # interval.touch() if not should_train or iteration % unroll == 0: batch_arg.detach_() # print("\n[4] : "+str(interval.touch('interval', False))) # interval.touch() # model.params = model.params.detach() # optimizer_states = optimizer_states.detach() walltime += iter_watch.touch('interval') result_dict.append(loss=loss_detached) # print("\n[5] : "+str(interval.touch('interval', False))) if not should_train: result_dict.append( walltime=iter_watch.touch(), **self.params_tracker( grad=grad, update=batch_arg.updates, ) ) # entire_loop.touch('interval', True) return result_dict
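# The function above restricts the learned update to a fraction of coordinates via
# TopKIndexer(input=grad, k_ratio=0.1, ...). The helper below is a minimal sketch of
# what such an indexer could compute (assumed behaviour: keep the k_ratio fraction of
# coordinates with largest |grad|, or a random subset of the same size); the real
# indexer may differ, e.g. by bounding the selection per layer.
import torch


def topk_index(grad: torch.Tensor, k_ratio: float = 0.1,
               random_k: bool = False) -> torch.Tensor:
  """Return indices of the coordinates to be passed to the learned optimizer."""
  flat = grad.view(-1)
  k = max(1, int(flat.numel() * k_ratio))
  if random_k:
    return torch.randperm(flat.numel())[:k]
  return flat.abs().topk(k).indices


# usage sketch: index = topk_index(grad, k_ratio=0.1); the batch manager would then
# slice params/states/grads with this index instead of iterating over all coordinates.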
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True):
  target = data_cls(training=should_train)
  optimizee = C(model_cls())
  optimizer_states = OptimizerStates.zero_states(
      self.n_layers, len(optimizee.params), self.hidden_sz)

  if should_train:
    self.train()
  else:
    self.eval()

  coordinate_tracker = ParamsIndexTracker(
      n_params=len(optimizee.params), n_tracks=30)
  result_dict = ResultDict()
  unroll_losses = 0
  iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
  iter_watch = utils.StopWatch('optim_iteration')

  for iteration in iter_pbar:
    loss = optimizee(target)
    unroll_losses += loss
    loss.backward(retain_graph=should_train)
    iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

    unit_size = torch.Size([len(optimizee.params), 1])
    grid_size = torch.Size([len(optimizee.params), self.n_coord])
    batch_manager = OptimizerBatchManager(
        params=optimizee.params,
        states=StatesSlicingWrapper(optimizer_states),
        units=OptimizerBatchManager.placeholder(unit_size),
        grid_abs=OptimizerBatchManager.placeholder(grid_size),
        grid_rel=OptimizerBatchManager.placeholder(grid_size),
        batch_size=30000,
    )
    ##########################################################################
    for get, set in batch_manager.iterator():
      units, new_states = self.unit_maker(
          C.detach(get.params.grad), get.states.clone())
      units = units * out_mul / self.n_coord / 2
      grid_abs, grid_rel = self.grid_maker(units, get.params)
      set.states(new_states)
      set.units(units)
      set.grid_abs(grid_abs)
      set.grid_rel(grid_rel)
    ##########################################################################
    landscape = self.loss_observer(batch_manager.grid_abs, model_cls, target)
    updates = self.step_maker(batch_manager.grid_rel, landscape)
    optimizee.params.add_flat_(updates)

    if should_train and iteration % unroll == 0:
      meta_optimizer.zero_grad()
      unroll_losses.backward()
      meta_optimizer.step()

    if not should_train or iteration % unroll == 0:
      unroll_losses = 0
      optimizee.params = optimizee.params.detach()
      optimizer_states = optimizer_states.detach()

    optimizee = C(model_cls(optimizee.params))
    result_dict.append(loss=loss)
    if not should_train:
      result_dict.append(
          grid=coordinate_tracker(batch_manager.units),
          update=coordinate_tracker(updates),
          walltime=iter_watch.touch(),
      )
  return result_dict
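# The unit/grid/landscape stage above is essentially a learned line search: propose a
# few candidate steps per coordinate, evaluate the loss at the resulting points, and
# pick the step with the lowest observed loss. The sketch below is a simplified,
# assumption-labeled version (one global candidate grid instead of the per-coordinate
# grids produced by unit_maker/grid_maker):
import torch


def grid_search_step(params, grad, loss_fn, unit, n_coord=5):
  """Try n_coord scaled steps along -grad and return the best relative update."""
  candidates = [unit * (i + 1) for i in range(n_coord)]  # relative step sizes
  losses = []
  for c in candidates:
    losses.append(loss_fn(params - c * grad))            # loss at absolute points
  best = torch.stack(losses).argmin().item()
  return -candidates[best] * grad                        # chosen relative update


# toy usage on a quadratic loss:
#   loss_fn = lambda w: (w ** 2).sum()
#   w = torch.tensor([1.0, -2.0]); g = 2 * w
#   w = w + grid_search_step(w, g, loss_fn, unit=0.1)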
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True): if should_train: self.train() else: self.eval() unroll = 1 target = data_cls(training=should_train) optimizee = C(model_cls()) n_params = 0 for p in optimizee.parameters(): n_params += int(np.prod(p.size())) update_track_num = 10 if update_track_num > n_params: update_track_num = n_params update_track_id = np.random.randint(0, n_params, update_track_num) updates_tracks = [] grids_tracks = [] hidden_states = [ C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2) ] cell_states = [ C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2) ] all_losses_ever = [] timestamps = [] if should_train: meta_optimizer.zero_grad() all_losses = None ############################################################################ iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') for iteration in iter_pbar: #watch.go('forward') loss = optimizee(target) #forward_time.append(watch.stop('forward')) iter_pbar.set_description(f'optim_iteration[loss:{loss}]') if all_losses is None: all_losses = loss else: all_losses += loss all_losses_ever.append(loss.data.cpu().numpy()) #watch.go('backward') loss.backward(retain_graph=should_train) #backward_time.append(watch.stop('backward')) offset = 0 navi_step_params = [dict() for _ in range(self.n_navi_step)] navi_step_optimizee = [] navi_step_loss = [] navi_step_delta = [] result_params = {} hidden_states2 = [ C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2) ] cell_states2 = [ C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2) ] ########################################################################## # torch.Parameter-wise linear navigation: get linear step points. for name, p in optimizee.all_named_parameters(): cur_sz = int(np.prod(p.size())) # We do this so the gradients are disconnected from the graph # but we still get gradients from the rest gradients = C.detach(p.grad.view(cur_sz, 1)) navi_step, new_hidden, new_cell = self.navi( gradients, [h[offset:offset + cur_sz] for h in hidden_states], [c[offset:offset + cur_sz] for c in cell_states], ) for i in range(len(new_hidden)): hidden_states2[i][offset:offset + cur_sz] = new_hidden[i] cell_states2[i][offset:offset + cur_sz] = new_cell[i] offset += cur_sz for i in range(self.n_navi_step): navi_delta = navi_step[:, i].view( *p.size()) * out_mul / self.n_navi_step / 2 if i == 0: navi_step_delta.append(navi_delta) navi_step_params[i][name] = p + navi_delta navi_step_params[i][name].retain_grad() # non-leaf tensor ########################################################################## # Evalute loss at every linear navigation points # and choose a single point(proper step size in linear grid) between them. 
for i in range(self.n_navi_step): optimizee = C(model_cls(**navi_step_params[i])) assert len(list(optimizee.all_named_parameters())) navi_step_loss.append(optimizee(target)) navi_step_loss = torch.stack(navi_step_loss) min = torch.min(navi_step_loss) max = torch.max(navi_step_loss) navi_step_loss = (navi_step_loss - min) / (max - min + 1e-10) step_size = self.step(navi_step_loss) ########################################################################## result_params = {} updates_track = [] grids_track = [] pairs = zip(optimizee.all_named_parameters(), navi_step_delta) for (name, param), delta in pairs: update = delta * step_size grid = delta * self.n_navi_step if step_size > self.n_navi_step: import pdb pdb.set_trace() if step_size < 0: import pdb pdb.set_trace() result_params[name] = param + update #* out_mul result_params[name].retain_grad() updates_track.append(update) grids_track.append(grid) updates_track = torch.cat(updates_track, 0).data.cpu().numpy() updates_track = np.take(updates_track, update_track_id) updates_tracks.append(updates_track) grids_track = torch.cat(grids_track, 0).data.cpu().numpy() grids_tracks.append(grids_track) if iteration % unroll == 0: if should_train: meta_optimizer.zero_grad() all_losses.backward() meta_optimizer.step() all_losses = None detached_params = { k: C.detach(v) for k, v in result_params.items() } optimizee = C(model_cls(**detached_params)) hidden_states = [C.detach(s) for s in hidden_states2] cell_states = [C.detach(s) for s in cell_states2] else: optimizee = C(model_cls(**result_params)) assert len(list(optimizee.all_named_parameters())) hidden_states = hidden_states2 cell_states = cell_states2 timestamps.append(iter_watch.touch()) return all_losses_ever, updates_tracks, grids_tracks, timestamps
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll, out_mul, mode):
  assert mode in ['train', 'valid', 'test']
  if mode == 'train':
    self.train()
    # data.new_train_data()
    inner_data = data.loaders['inner_train']
    outer_data = data.loaders['inner_valid']
    # inner_data = data.loaders['train']
    # outer_data = inner_data
    drop_mode = 'soft_drop'
  elif mode == 'valid':
    self.eval()
    inner_data = data.loaders['valid']
    outer_data = None
    drop_mode = 'hard_drop'
  elif mode == 'test':
    self.eval()
    inner_data = data.loaders['test']
    outer_data = None
    drop_mode = 'hard_drop'

  # data = dataset(mode=mode)
  model = C(model_cls(sb_mode=self.sb_mode))
  model(*data.pseudo_sample())
  params = model.params

  result_dict = ResultDict()
  unroll_losses = 0
  walltime = 0

  self.feature_gen.new()

  ################
  mask_dict = ResultDict()
  analyze_mask = False
  sample_mask = False
  draw_loss = False
  ################

  iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
  iter_watch = utils.StopWatch('optim_iteration')

  res1 = None
  res2 = []
  res3 = []
  lamb = []
  gamm_g = []
  gamm_l1 = []
  gamm_l2 = []
  iters = []

  for iteration in iter_pbar:
    iter_watch.touch()

    if debug_sigint.signal_on:
      debug_1 = iteration == 1 or iteration % 10 == 0
    else:
      debug_1 = False
    if debug_sigstp.signal_on:
      debug_2 = True
    else:
      debug_2 = False

    model_detached = C(model_cls(params.detach()))
    inner_data_s = inner_data.load()
    loss_detached = model_detached(*inner_data_s)
    loss_detached.backward()

    g = model_detached.params.grad.flat.detach()
    w = model_detached.params.flat.detach()

    cand_params = []
    cand_losses = []
    best_loss = 9999999
    best_params = None
    n_samples = 5

    for m in range(1):
      feature = self.feature_gen(g, w, n=n_samples, m_update=(m == 0))
      p_size = params.size().unflat()
      mask_gen_out = self.mask_gen(feature, p_size, n=n_samples)
      step_out = self.step_gen(feature, n=n_samples)

      # step & mask generation
      for i in range(n_samples):
        mask, _, kld = mask_gen_out[i]
        step = step_out[i]
        mask = ParamsFlattener(mask)
        mask_layout = mask.expand_as(params)
        step = params.new_from_flat(step)
        # pruning
        step = step * mask_layout
        params_ = params + step
        # cand_params.append(params_.flat)
        sparse_params = params_.prune(mask > 1e-6)
        if sparse_params.size().unflat()['mat_0'][1] == 0:
          continue
        if debug_2:
          import pdb; pdb.set_trace()
        # cand_loss = model(*outer_data_s)
        sparse_model = C(model_cls(sparse_params.detach()))
        loss = sparse_model(*inner_data_s)
        try:
          if loss < best_loss:
            best_loss = loss
            best_params = params_
            best_kld = kld
        except:
          best_params = params_
          best_kld = kld

    if best_params is not None:
      params = best_params
      best_kld = 0

    if mode == 'train':
      model = C(model_cls(params))
      optim_loss = model(*outer_data.load())
      if torch.isnan(optim_loss):
        import pdb; pdb.set_trace()
      unroll_losses += optim_loss + best_kld / outer_data.full_size

    if mode == 'train' and iteration % unroll == 0:
      meta_optimizer.zero_grad()
      unroll_losses.backward()
      # import pdb; pdb.set_trace()
      nn.utils.clip_grad_value_(self.parameters(), 0.1)
      meta_optimizer.step()
      unroll_losses = 0

    if not mode == 'train' or iteration % unroll == 0:
      # self.mask_gen.detach_lambdas_()
      params = params.detach_()

    # import pdb; pdb.set_trace()
    if params is None:
      import pdb; pdb.set_trace()

    iter_pbar.set_description(
        f'optim_iteration'
        f'[optim_loss:{loss_detached.tolist():5.5}')
    # f' sparse_loss:{sparsity_loss.tolist():5.5}]')
    # iter_pbar.set_description(f'optim_iteration[loss:{loss_dense.tolist()}/dist:{dist.tolist()}]')
    # torch.optim.SGD(sparse_params, lr=0.1).step()

    walltime += iter_watch.touch('interval')

    ##############################################################
    text_dir = 'test/analyze_mask'
    result_dir = 'test/drawloss'
    sample_dir = 'test/mask_compare'
    iter_interval = 10
    sample_num = 10000

    if mode == 'test' and analyze_mask:
      analyzing_mask(self.mask_gen, layer_size, mode, iteration,
                     iter_interval, text_dir)
    if mode == 'test' and sample_mask:
      mask_result = sampling_mask(self.mask_gen, layer_size, model_detached,
                                  params, sample_num, mode, iteration,
                                  iter_interval, sample_dir)
      if mask_result is not None:
        mask_dict.append(mask_result)
    if mode == 'test' and draw_loss:
      plot_loss(model_cls=model_cls, model=model_detached, params=params,
                input_data=data['in_train'].load(), dataset=data['in_train'],
                feature_gen=self.feature_gen, mask_gen=self.mask_gen,
                step_gen=self.step_gen, scale_way=None, xmin=-2.0, xmax=0.5,
                num_x=20, mode=mode, iteration=iteration,
                iter_interval=iter_interval, loss_dir=result_dir)
    ##############################################################

    result_dict.append(loss=loss_detached)
    if not mode == 'train':
      result_dict.append(
          walltime=walltime,
          # **self.params_tracker(
          #   grad=grad,
          #   update=batch_arg.updates,
          # )
      )
  return result_dict
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True):
  if should_train:
    self.train()
  else:
    self.eval()

  unroll = 1
  target = data_cls(training=should_train)
  optimizee = C(model_cls())

  n_params = 0
  for p in optimizee.parameters():
    n_params += int(np.prod(p.size()))

  update_track_num = 10
  if update_track_num > n_params:
    update_track_num = n_params
  update_track_id = np.random.randint(0, n_params, update_track_num)
  updates_tracks = []

  hidden_states = [
      C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2)
  ]
  cell_states = [
      C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2)
  ]
  all_losses_ever = []
  timestamps = []

  if should_train:
    meta_optimizer.zero_grad()
  all_losses = None
  ############################################################################
  iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
  iter_watch = utils.StopWatch('optim_iteration')

  for iteration in iter_pbar:
    loss = optimizee(target)
    iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

    if all_losses is None:
      all_losses = loss
    else:
      all_losses += loss
    all_losses_ever.append(loss.data.cpu().numpy())
    loss.backward(retain_graph=should_train)
    # import pdb; pdb.set_trace()

    offset = 0
    result_params = {}
    updates_track = []
    hidden_states2 = [
        C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2)
    ]
    cell_states2 = [
        C(torch.zeros(n_params, self.hidden_sz)) for _ in range(2)
    ]
    ##########################################################################
    # named_shape = dict()
    # params_flat = []
    # grads_flat = []
    # for name, p in optimizee.all_named_parameters():
    #   named_shape[name] = p.size()
    #   params_flat.append(p.view(-1, 1))
    #   grads_flat.append(C.detach(p.grad.view(-1, 1)))
    # params_flat = torch.cat(params_flat, 0)
    # grads_flat = torch.cat(grads_flat, 0)
    #
    # batch_size = 10000
    # batch_sizes = []
    # if batch_size > n_params:
    #   batch_size = [n_params]
    # else:
    #   batch_size = [batch_size for _ in range(n_params // batch_size)]
    #   batch_sizes.append(n_params % batch_size)
    #
    # batch_sizes = deque(batch_sizes)
    #
    # offset = 0
    # params_batches = []
    # grads_batches = []
    # for _ in range(len(batch_sz)):
    #   b_size = batch_sz.popleft()
    #   params_batches.append(params_flat[offset:offset+b_size, :])
    #   grads_batches.append(grads_flat[offset:offset+b_size, :])
    #   offset += b_size
    # assert(offset == params_flat.size(0))
    # batches = zip(batch_sz, params_batches, grads_batches)
    ##########################################################################
    for name, p in optimizee.all_named_parameters():
      cur_sz = int(np.prod(p.size()))
      # We do this so the gradients are disconnected from the graph
      # but we still get gradients from the rest
      gradients = C.detach(p.grad.view(cur_sz, 1))
      updates, new_hidden, new_cell = self(
          gradients,
          [h[offset:offset + cur_sz] for h in hidden_states],
          [c[offset:offset + cur_sz] for c in cell_states])
      for i in range(len(new_hidden)):
        hidden_states2[i][offset:offset + cur_sz] = new_hidden[i]
        cell_states2[i][offset:offset + cur_sz] = new_cell[i]
      result_params[name] = p + updates.view(*p.size()) * out_mul
      result_params[name].retain_grad()
      offset += cur_sz
      updates_track.append(updates)
    ##########################################################################
    updates_track = torch.cat(updates_track, 0)
    updates_track = updates_track.data.cpu().numpy()
    updates_track = np.take(updates_track, update_track_id)
    updates_tracks.append(updates_track)

    if iteration % unroll == 0:
      if should_train:
        meta_optimizer.zero_grad()
        all_losses.backward()
        meta_optimizer.step()
        all_losses = None
        detached_params = {
            k: C.detach(v) for k, v in result_params.items()
        }
        optimizee = C(model_cls(**detached_params))
        hidden_states = [C.detach(s) for s in hidden_states2]
        cell_states = [C.detach(s) for s in cell_states2]
      else:
        optimizee = C(model_cls(**result_params))
        assert len(list(optimizee.all_named_parameters()))
        hidden_states = hidden_states2
        cell_states = cell_states2
    timestamps.append(iter_watch.touch())
  return all_losses_ever, updates_tracks, timestamps
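# The per-parameter loop above is the classic coordinate-wise learned optimizer: every
# scalar coordinate is one batch element of a small shared recurrent net whose input is
# that coordinate's gradient and whose output is its update, with hidden/cell states
# persisting across outer iterations. A minimal, self-contained sketch of such a module
# (hypothetical sizes; two stacked LSTMCells to mirror the two state lists above):
import torch
import torch.nn as nn


class CoordinateWiseLSTM(nn.Module):
  def __init__(self, hidden_sz=20):
    super().__init__()
    self.cell0 = nn.LSTMCell(1, hidden_sz)
    self.cell1 = nn.LSTMCell(hidden_sz, hidden_sz)
    self.out = nn.Linear(hidden_sz, 1)

  def forward(self, grad, hidden, cell):
    # grad: [n_coords, 1]; hidden/cell: lists of two [n_coords, hidden_sz] tensors
    h0, c0 = self.cell0(grad, (hidden[0], cell[0]))
    h1, c1 = self.cell1(h0, (hidden[1], cell[1]))
    update = self.out(h1)  # [n_coords, 1], later scaled by out_mul
    return update, [h0, h1], [c0, c1]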
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True): data = data_cls(training=should_train) model = C(model_cls()) states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.params), self.hidden_sz))) if should_train: self.train() else: self.eval() result_dict = ResultDict() model_dt = None model_losses = 0 #grad_real_prev = None iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') batch_arg = BatchManagerArgument( params=model.params, states=StatesSlicingWrapper(states), update_prev=C(torch.zeros(model.params.size().flat())), mu=C(torch.zeros(model.params.size().flat())), sigma=C(torch.zeros(model.params.size().flat())), ) for iteration in iter_pbar: data_ = data.sample() model_loss = model(*data_) model_losses += model_loss # if iteration == 21: # import pdb; pdb.set_trace() try: make_dot(model_losses) except: import pdb pdb.set_trace() # if iteration == 1 or iteration == 21: # import pdb; pdb.set_trace() model_dt = C(model_cls(params=model.params.detach())) model_loss_dt = model_dt(*data_) # model.params.register_grad_cut_hook_() model_loss_dt.backward() assert model_dt.params.flat.grad is not None grad_cur = model_dt.params.flat.grad # model.params.remove_grad_cut_hook_() batch_arg.update(BatchManagerArgument(grad_cur=grad_cur)) # NOTE: read_only, write_only ######################################################################## batch_manager = OptimizerBatchManager(batch_arg=batch_arg) for get, set in batch_manager.iterator(): grad_cur = get.grad_cur.detach() # free-running update_prev = get.update_prev.detach() points_rel, states = self.point_sampler( grad_cur, update_prev, get.states, self.n_points) points_rel = points_rel * out_mul points_abs = get.params.detach( ) + points_rel # NOTE: check detach losses = self.loss_observer(points_abs, model_cls, data_) update = self.point_selector(points_rel, losses) set.params(get.params + update) set.states(states) set.update_prev(update) set.mu(self.point_sampler.mu) set.sigma(self.point_sampler.sigma) ########################################################################## batch_arg = batch_manager.batch_arg_to_set iter_pbar.set_description('optim_iteration[loss:{:10.6f}]'.format( model_loss.tolist())) if should_train and iteration % unroll == 0: meta_optimizer.zero_grad() try: model_losses.backward() except: import pdb pdb.set_trace() meta_optimizer.step() if not should_train or iteration % unroll == 0: model_losses = 0 batch_arg.detach_() # batch_arg.params.detach_() # batch_arg.states.detach_() # NOTE: just do batch_arg.detach_() model = C(model_cls(params=batch_arg.params)) result_dict.append( step_num=iteration, loss=model_loss, ) if not should_train: result_dict.append(walltime=iter_watch.touch(), **self.params_tracker( grad=grad_cur, update=batch_arg.update_prev, mu=batch_arg.mu, sigma=batch_arg.sigma, )) return result_dict
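# Every variant above shares the same outer schedule: sum the optimizee loss over
# `unroll` inner steps, backpropagate that sum through the learned optimizer's
# parameters, step the meta-optimizer, then detach parameters and states so the next
# window starts a fresh graph (truncated backprop through time). A condensed sketch of
# that schedule alone, with the inner update step left abstract (hypothetical helper
# names `inner_step` and `detach_all`):
def truncated_meta_training(inner_step, detach_all, meta_optimizer,
                            optim_it=100, unroll=20, should_train=True):
  unroll_losses = 0
  for iteration in range(1, optim_it + 1):
    loss = inner_step()                 # one optimizee forward + learned update
    unroll_losses = unroll_losses + loss
    if should_train and iteration % unroll == 0:
      meta_optimizer.zero_grad()
      unroll_losses.backward()          # credit assignment across the whole window
      meta_optimizer.step()
    if not should_train or iteration % unroll == 0:
      unroll_losses = 0
      detach_all()                      # truncate the graph: detach params and states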