def forward(self, x, states, mask=None):
    assert isinstance(x, (tuple, list, torch.Tensor))
    if isinstance(x, (list, tuple)):
        assert len(x) == len(self.preproc)
        x_p = []
        for do_preproc, inp in zip(self.preproc, x):  # per-input preprocessing flags
            inp = C(self.process(inp)) if do_preproc else C(inp)
            x_p.append(inp)
        x = torch.cat(x_p, dim=1)
    elif isinstance(x, torch.Tensor) and self.preproc:
        assert len(self.preproc) == 1
        x = C(self.process(x))
    out_states = []
    if self.rnn_cell == 'lstm':
        for i, rnn in enumerate(self.rnn):
            out_states.append(LSTMStates(*rnn(x, (states[i].h, states[i].c))))
            x = out_states[-1].h  # feed each layer's hidden state to the next
    elif self.rnn_cell == 'gru':
        for i, rnn in enumerate(self.rnn):
            out_states.append(rnn(x, states[i], mask))
            x = out_states[-1]
    else:
        raise RuntimeError(f'Unknown rnn_cell type: {self.rnn_cell}')
    return self.output(x), OptimizerStates(out_states)
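# `C(...)` and `self.process(...)` are defined elsewhere in the repo; `C` is
# assumed to be a to-device (CUDA) wrapper. Below is a minimal sketch of the
# gradient preprocessing that `process` is assumed to implement, following the
# log-magnitude/sign scheme of Andrychowicz et al. (2016). The function name
# is hypothetical, for illustration only:
import math

def _preprocess_sketch(grad, p=10.0):
    """Map g -> (log|g| / p, sign(g)) for large |g|, (-1, e^p * g) otherwise."""
    eps = math.exp(-p)
    big = grad.abs() >= eps
    log_part = torch.where(big, grad.abs().clamp_min(eps).log() / p,
                           -torch.ones_like(grad))
    sign_part = torch.where(big, grad.sign(), math.exp(p) * grad)
    return torch.cat([log_part, sign_part], dim=-1)  # doubles the last dim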
def forward(self, inp, states, n_points):
    if self.preproc:
        inp = C(self.preproc(inp))
    s_0 = LSTMStates(*self.recurs(inp, (states[0].h, states[0].c)))
    s_1 = LSTMStates(*self.recurs2(s_0.h, (states[1].h, states[1].c)))
    mu_logvar = self.mu_logvar(s_1.h)
    mu, logvar = mu_logvar[0], mu_logvar[1]
    points = self._sample(mu, logvar, n_points)  # points: [n_points, 1]
    return points, OptimizerStates([s_0, s_1])
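# `self._sample` is defined elsewhere in the repo; a minimal sketch of the
# reparameterization trick it is assumed to implement (name hypothetical):
def _sample_sketch(mu, logvar, n_points):
    """Draw n_points differentiable samples from N(mu, exp(logvar))."""
    std = logvar.mul(0.5).exp()  # sigma = exp(logvar / 2)
    eps = torch.randn(n_points, *std.size(), device=std.device)
    return mu + std * eps  # broadcasts to [n_points, *mu.shape]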
def forward(self, g_prev, u_prev, g_states, u_states):
    if self.preproc:
        g_prev = C(self.preproc(g_prev))
        u_prev = C(self.preproc(u_prev))
    g_u_cat = torch.cat([g_prev, u_prev], dim=1)  # [n_param, 4 or 2 (g + u)]
    g_s_0 = LSTMStates(*self.g_rnn(g_u_cat, (g_states[0].h, g_states[0].c)))
    g_s_1 = LSTMStates(*self.g_rnn2(g_s_0.h, (g_states[1].h, g_states[1].c)))
    g_cur = self.g_output(g_s_1.h) * 0.1
    g_states_ = OptimizerStates([g_s_0, g_s_1])
    if self.preproc:
        g_cur_p = C(self.preproc(g_cur.detach()))  # NOTE: detach!
    else:
        g_cur_p = g_cur.detach()  # still detach in the non-preproc path
    u_s_0 = LSTMStates(*self.u_rnn(g_cur_p, (u_states[0].h, u_states[0].c)))
    u_s_1 = LSTMStates(*self.u_rnn2(u_s_0.h, (u_states[1].h, u_states[1].c)))
    u_cur = self.u_output(u_s_1.h)
    u_states_ = OptimizerStates([u_s_0, u_s_1])
    return g_cur, u_cur, g_states_, u_states_
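# On the detach above: feeding g_cur.detach() into the update RNN (u_rnn)
# keeps the update head's loss from backpropagating into the
# gradient-prediction head (g_rnn); presumably each head is meant to be
# trained only by its own loss.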
def forward(self, inp, states):
    if self.preproc:
        inp = C(self.preproc(inp))
    if self.rnn_cell == 'lstm':
        s_0 = LSTMStates(*self.recurs(inp, (states[0].h, states[0].c)))
        s_1 = LSTMStates(*self.recurs2(s_0.h, (states[1].h, states[1].c)))
        out = s_1.h
    elif self.rnn_cell == 'gru':
        s_0 = self.recurs(inp, states[0])
        s_1 = self.recurs2(s_0, states[1])
        out = s_1
    else:
        raise RuntimeError(f'Unknown rnn_cell type: {self.rnn_cell}')
    return self.output(out), OptimizerStates([s_0, s_1])
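# A minimal sketch of the state containers these forwards rely on. The real
# LSTMStates/OptimizerStates live elsewhere in the repo and also provide
# initial_zeros(), detach_(), slicing, etc.; the names below are hypothetical,
# for illustration only:
import collections

LSTMStatesSketch = collections.namedtuple('LSTMStatesSketch', ['h', 'c'])

class OptimizerStatesSketch(list):
    """Per-layer RNN states: LSTMStatesSketch for LSTM, plain tensors for GRU."""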
def forward(self, g_cur, u_prev, states, n_points):
    if self.preproc:
        # g_cur = C(self.preproc(g_cur))
        u_prev = C(self.preproc(u_prev))
    # inp = torch.cat([g_cur, u_prev], dim=1)  # [n_params, 4 or 2]
    inp = u_prev
    s_0 = LSTMStates(*self.recurs(inp, (states[0].h, states[0].c)))
    s_1 = LSTMStates(*self.recurs2(s_0.h, (states[1].h, states[1].c)))
    mu_logvar = self.mu_logvar(s_1.h)  # [n_params, 2]
    self._mu = mu_logvar[:, 0].unsqueeze(1)
    # softplus keeps sigma positive (alternative: .mul(0.5).exp_())
    self._sigma = self.softplus(mu_logvar[:, 1]).unsqueeze(1)
    # self._mu, self._sigma = mu_logvar[0], mu_logvar[1].mul(0.5).exp_()
    # mu: [n_params, 1] / sigma: [n_params, 1]
    points_size = [n_points, *self._sigma.size()]
    eps = self._sigma.data.new(size=points_size).normal_()
    points_rel = self._mu + self._sigma * eps  # [n_points, n_params, 1]
    return points_rel, OptimizerStates([s_0, s_1])
def forward(self, g_prev, u_prev, params, g_states):
    if self.preproc:
        g_prev = C(self.preproc(self.ig_drop(g_prev)))
        u_prev = C(self.preproc(self.iu_drop(u_prev)))
        params = C(self.preproc(self.ip_drop(params)))
    g_u_cat = torch.cat([g_prev, u_prev, params], dim=1)  # [n_param, 6 or 3 (g + u + p)]
    g_s_0 = LSTMStates(*self.g_rnn(g_u_cat, (g_states[0].h, g_states[0].c)))
    g_s_0_h = self.i2_drop(g_s_0.h)
    g_s_1 = LSTMStates(*self.g_rnn2(g_s_0_h, (g_states[1].h, g_states[1].c)))
    mu_logvar = self.g_output(g_s_1.h) * 0.1
    self._mu = mu_logvar[:, 0].unsqueeze(1)
    self._sigma = self.softplus(mu_logvar[:, 1]).unsqueeze(1)
    eps = self._sigma.data.new(size=self._sigma.size()).normal_()
    g_cur = self._mu + self._sigma * eps
    g_states_ = OptimizerStates([g_s_0, g_s_1])
    return self.output_drop(g_cur), self.state_drop(g_states_)
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it,
                  unroll, out_mul, should_train=True):
    data = data_cls(training=should_train)
    model = C(model_cls())
    optimizer_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.params), self.hidden_sz)))

    if should_train:
        self.train()
    else:
        self.eval()

    result_dict = ResultDict()
    unroll_losses = 0
    update = C(torch.zeros(model.params.size().flat()))

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')

    batch_arg = BatchManagerArgument(
        params=model.params,
        states=StatesSlicingWrapper(optimizer_states),
        updates=update,
    )

    for iteration in iter_pbar:
        data_ = data.sample()
        loss = model(*data_)
        unroll_losses += loss

        model_dt = C(model_cls(params=batch_arg.params.detach()))
        model_dt_loss = model_dt(*data_)
        assert loss == model_dt_loss
        model_dt_loss.backward()
        assert model_dt.params.flat.grad is not None
        grad = model_dt.params.flat.grad
        batch_arg.update(BatchManagerArgument(grad=grad).immutable().volatile())
        iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

        ##########################################################################
        batch_manager = OptimizerBatchManager(batch_arg=batch_arg)
        for get, set in batch_manager.iterator():
            updates, new_states = self(get.grad.detach(), get.states)
            set.states(new_states)
            set.params(get.params + updates * out_mul)
            if not should_train:
                set.updates(updates)
        ##########################################################################

        batch_arg = batch_manager.batch_arg_to_set

        if should_train and iteration % unroll == 0:
            meta_optimizer.zero_grad()
            unroll_losses.backward()
            meta_optimizer.step()

        if not should_train or iteration % unroll == 0:
            unroll_losses = 0
            batch_arg.detach_()
            # model.params = model.params.detach()
            # optimizer_states = optimizer_states.detach()
            model = C(model_cls(params=batch_arg.params))

        result_dict.append(loss=loss)
        if not should_train:
            result_dict.append(
                walltime=iter_watch.touch(),
                **self.params_tracker(
                    grad=grad,
                    update=batch_arg.updates,
                ))
    return result_dict
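# The loop above follows the usual truncated-BPTT pattern for learned
# optimizers: inner losses are summed into unroll_losses, every `unroll`
# steps the meta-optimizer backpropagates through the unrolled segment, and
# batch_arg.detach_() then cuts the graph so the next segment starts fresh.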
def forward(self, inp, states):
    if self.preproc:
        inp = C(self.preproc(inp))
    s_0 = LSTMStates(*self.recurs(inp, (states[0].h, states[0].c)))
    s_1 = LSTMStates(*self.recurs2(s_0.h, (states[1].h, states[1].c)))
    return self.output(s_1.h), OptimizerStates([s_0, s_1])
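# Hypothetical usage of the two-layer forward above, assuming a flattened
# [n_params, in_dim] gradient and hidden_sz-wide states (shapes depend on the
# module's constructor, which is defined elsewhere):
#   states = C(OptimizerStates.initial_zeros((2, n_params, hidden_sz)))
#   update, states = optimizer(grad_flat, states)  # update: [n_params, 1]
#   params = params + update * out_mul             # as in meta_optimize below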
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll,
                  out_mul, mode):
    assert mode in ['train', 'valid', 'test']
    if mode == 'train':
        self.train()
        # data.new_train_data()
        inner_data = data.loaders['inner_train']
        outer_data = data.loaders['inner_valid']
        # inner_data = data.loaders['train']
        # outer_data = inner_data
        drop_mode = 'soft_drop'
    elif mode == 'valid':
        self.eval()
        inner_data = data.loaders['valid']
        outer_data = None
        drop_mode = 'hard_drop'
    elif mode == 'test':
        self.eval()
        inner_data = data.loaders['test']
        outer_data = None
        drop_mode = 'hard_drop'

    # data = dataset(mode=mode)
    model = C(model_cls(sb_mode=self.sb_mode))
    model(*data.pseudo_sample())
    # layer_sz = sum([v for v in model.params.size().unflat(1, 'mat').values()])
    mask_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.activations.mean(0)), self.hidden_sz)))
    update_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.params), self.hidden_sz)))

    result_dict = ResultDict()
    unroll_losses = 0
    walltime = 0
    update = C(torch.zeros(model.params.size().flat()))

    batch_arg = BatchManagerArgument(
        params=model.params,
        states=StatesSlicingWrapper(update_states),
        updates=model.params.new_zeros(model.params.size()),
    )
    params = model.params

    lr_sgd = 0.2
    lr_adam = 0.001
    sgd = torch.optim.SGD(model.parameters(), lr=lr_sgd)
    adam = torch.optim.Adam(model.parameters())
    adagrad = torch.optim.Adagrad(model.parameters())
    # print('###', lr_adam, '####')

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')
    bp_watch = utils.StopWatch('bp')
    loss_decay = 1.0
    dist_list = []
    dense_rank = 0
    sparse_rank = 0
    dense_rank_list = []
    sparse_rank_list = []

    for iteration in iter_pbar:
        iter_watch.touch()
        # # conventional optimizers with API
        # model.zero_grad()
        # loss = model(*inner_data.load())
        # loss.backward()
        # adam.step()
        # loss_detached = loss

        model_detached = C(model_cls(params=params.detach()))
        loss_detached = model_detached(*inner_data.load())
        loss_detached.backward()

        # # conventional optimizers with manual SGD
        # params = model_detached.params + model_detached.params.grad * (-0.2)
        # iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]')

        masks = []
        # sparse_old_loss = []
        # sparse_new_loss = []
        sparse_loss = []
        dense_loss = []
        candidates = []
        sample_mode = 2  # 1: recompute gradients on the pruned net / 2: sparsify full gradients
        n_sample = 10

        if sample_mode == 1:
            lr = 1.0
        elif sample_mode == 2:
            lr = 0.2
        else:
            raise Exception()

        for i in range(n_sample):
            if i == n_sample - 1:  # last candidate uses the top-k mask; the rest are random-k
                mask = MaskGenerator.topk(model_detached.activations.grad.detach())
            else:
                mask = MaskGenerator.randk(model_detached.activations.grad.detach())

            if sample_mode == 1:
                # recompute gradients again using the pruned network
                sparse_params = model_detached.params.prune(mask)
                model_pruned = C(model_cls(params=sparse_params.detach()))
                loss_pruned = model_pruned(*inner_data.load())
                loss_pruned.backward()
                grads = model_pruned.params.grad
            elif sample_mode == 2:
                # use full gradients but sparsified
                sparse_params = model_detached.params.prune(mask)
                grads = model_detached.params.grad.prune(mask)
            else:
                raise Exception('unknown sample mode.')

            sparse_params = sparse_params + grads * (-lr)
            model_pruned = C(model_cls(params=sparse_params.detach()))
            loss_pruned = model_pruned(*inner_data.load())

            # params = model_detached.params
            params = model_detached.params.sparse_update(mask, grads * (-lr))
            candidates.append(params)
            model_dense = C(model_cls(params=params))
            loss_dense = model_dense(*inner_data.load())
            sparse_loss.append(loss_pruned)
            dense_loss.append(loss_dense)

        sparse_loss = torch.stack(sparse_loss, 0)
        dense_loss = torch.stack(dense_loss, 0)
        min_id = (sparse_loss.min() == sparse_loss).nonzero().squeeze().tolist()
        params = candidates[min_id]
        sparse_order = sparse_loss.sort()[1].float()
        dense_order = dense_loss.sort()[1].float()
        dist = (sparse_order - dense_order).abs().sum()
        dist_list.append(dist)
        dense_rank += (dense_order == 9).nonzero().squeeze().tolist()
        sparse_rank += (sparse_order == 9).nonzero().squeeze().tolist()
        dense_rank_list.append((dense_order == 9).nonzero().squeeze().tolist())
        sparse_rank_list.append((sparse_order == 9).nonzero().squeeze().tolist())

        iter_pbar.set_description(
            f'optim_iteration'
            f'[dense_loss:{loss_dense.tolist():5.5}/sparse_loss:{loss_pruned.tolist():5.5}]')
        # iter_pbar.set_description(
        #     f'optim_iteration[loss:{loss_dense.tolist()}/dist:{dist.tolist()}]')
        # torch.optim.SGD(sparse_params, lr=0.1).step()

        result_dict.append(loss=loss_detached)
        if not mode == 'train':
            result_dict.append(
                walltime=walltime,
                # **self.params_tracker(
                #     grad=grad,
                #     update=batch_arg.updates,
                # )
            )

    dist_mean = torch.stack(dist_list, 0).mean()
    # import pdb; pdb.set_trace()
    return result_dict
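# `dist` above is Spearman's footrule: the sum of absolute rank differences
# between the sparse-loss and dense-loss orderings of the candidates. A small
# value means the pruned-model loss ranks candidates almost the same way the
# dense loss does. `dense_rank`/`sparse_rank` accumulate where the top-k
# candidate (sample index 9) lands in each ordering.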
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it,
                  unroll, out_mul, should_train=True):
    data = data_cls(training=should_train)
    model = C(model_cls())
    g_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.params), self.hidden_sz)))

    if should_train:
        self.train()
    else:
        self.eval()

    result_dict = ResultDict()
    model_dt = None
    grad_losses = 0
    train_with_lazy_tf = True
    eval_test_grad_loss = True
    grad_prev = C(torch.zeros(model.params.size().flat()))
    update_prev = C(torch.zeros(model.params.size().flat()))
    # grad_real_prev = None
    grad_real_cur = None
    grad_real_prev = C(torch.zeros(model.params.size().flat()))

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')

    # grad_real_prev: teacher-forcing input for the RNN's current step;
    #   required at every training step and at every k-th test step.
    # grad_real_cur: MSE-loss target for the RNN's current output;
    #   required at training steps only.

    batch_arg = BatchManagerArgument(
        params=model.params,
        grad_prev=grad_prev,
        update_prev=update_prev,
        g_states=StatesSlicingWrapper(g_states),
        mse_loss=C(torch.zeros(1)),
        sigma=C(torch.zeros(model.params.size().flat())),
        # updates=OptimizerBatchManager.placeholder(model.params.flat.size()),
    )

    for iteration in iter_pbar:
        data_ = data.sample()

        # if grad_real_cur is not None:  # if iteration % self.k == 0
        #     grad_cur_prev = grad_real_cur  # back up previous grad
        if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss:
            # get grad every step while training;
            # get grad one step earlier than teacher-forcing steps while testing
            model_dt = C(model_cls(params=model.params.detach()))
            model_loss_dt = model_dt(*data_)
            # model.params.register_grad_cut_hook_()
            model_loss_dt.backward()
            # model.params.remove_grad_cut_hook_()
            grad_real_cur = model_dt.params.flat.grad
            assert grad_real_cur is not None

        if should_train or eval_test_grad_loss:
            # grad_real_cur is the MSE target; it is not used as a target while testing
            assert grad_real_cur is not None
            batch_arg.update(
                BatchManagerArgument(grad_real_cur=grad_real_cur).immutable())

        if iteration % self.k == 0 or (should_train and not train_with_lazy_tf):
            # teacher-forcing
            assert grad_real_prev is not None
            batch_arg.update(
                BatchManagerArgument(
                    grad_real_prev=grad_real_prev).immutable().volatile())
            grad_real_prev = None

        ########################################################################
        batch_manager = OptimizerBatchManager(batch_arg=batch_arg)
        for get, set in batch_manager.iterator():
            if iteration % self.k == 0 or (should_train and not train_with_lazy_tf):
                grad_prev = get.grad_real_prev  # teacher-forcing
            else:
                grad_prev = get.grad_prev  # free-running
            grad_cur, g_states = self(grad_prev.detach(),
                                      get.update_prev.detach(),
                                      get.params.detach(), get.g_states)
            set.grad_prev(grad_cur)
            set.g_states(g_states)
            set.params(get.params - 1.0 * grad_cur.detach())
            set.sigma(self.sigma)
            if should_train or eval_test_grad_loss:
                set.mse_loss(self.mse_loss(grad_cur, get.grad_real_cur.detach()))
        ##########################################################################

        batch_arg = batch_manager.batch_arg_to_set
        grad_real = batch_manager.batch_arg_to_get.grad_real_cur.detach()
        cosim_loss = nn.functional.cosine_similarity(
            batch_arg.grad_prev, grad_real, dim=0)
        grad_loss = batch_arg.mse_loss * 100
        cosim_loss = -torch.log((cosim_loss + 1) * 0.5) * 100
        grad_losses += (grad_loss + cosim_loss)
        # to_float = lambda x: float(x.data.cpu().numpy())
        sigma = batch_arg.sigma.detach().mean()
        iter_pbar.set_description(
            'optim_iteration[loss(model):{:10.6f}/(mse):{:10.6f}'
            '/(cosim):{:10.6f}/(sigma):{:10.6f}]'.format(
                model_loss_dt.tolist(), grad_loss.tolist()[0],
                cosim_loss.tolist()[0], sigma.tolist()))

        if should_train and iteration % unroll == 0:
            w_decay_loss = 0
            for p in self.parameters():
                w_decay_loss += torch.sum(p * p)
            w_decay_loss = (w_decay_loss * 0.5) * 1e-4
            total_loss = grad_losses + w_decay_loss
            meta_optimizer.zero_grad()
            total_loss.backward()
            meta_optimizer.step()

        if not should_train or iteration % unroll == 0:
            model_losses = 0
            grad_losses = 0
            batch_arg.detach_()
            # batch_arg.params.detach_()
            # batch_arg.g_states.detach_()
            # batch_arg.u_states.detach_()

        result_dict.append(
            step_num=iteration,
            loss=model_loss_dt,
            grad_loss=grad_loss,
        )
        if not should_train:
            result_dict.append(
                walltime=iter_watch.touch(),
                **self.params_tracker(
                    grad_real=grad_real_cur,
                    grad_pred=batch_arg.grad_prev,
                    update=batch_arg.update_prev,
                    sigma=batch_arg.sigma,
                ))

        if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss:
            grad_real_prev = grad_real_cur
            grad_real_cur = None

    return result_dict
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it,
                  unroll, out_mul, should_train=True):
    data = data_cls(training=should_train)
    model = C(model_cls())
    g_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.params), self.hidden_sz)))
    u_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.params), self.hidden_sz)))

    if should_train:
        self.train()
    else:
        self.eval()

    grad_pred_tracker = ParamsIndexTracker(
        name='grad_pred', n_tracks=10, n_params=len(model.params))
    grad_real_tracker = grad_pred_tracker.clone_new_name('grad_real')
    update_tracker = grad_pred_tracker.clone_new_name('update')

    result_dict = ResultDict()
    model_dt = None
    model_losses = 0
    grad_losses = 0
    train_with_lazy_tf = False
    eval_test_grad_loss = True
    grad_prev = C(torch.zeros(model.params.size().flat()))
    update_prev = C(torch.zeros(model.params.size().flat()))
    # grad_real_prev = None
    grad_real_cur = None
    grad_real_prev = C(torch.zeros(model.params.size().flat()))

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')

    # grad_real_prev: teacher-forcing input for the RNN's current step;
    #   required at every training step and at every k-th test step.
    # grad_real_cur: MSE-loss target for the RNN's current output;
    #   required at training steps only.

    batch_arg = BatchManagerArgument(
        params=model.params,
        grad_prev=grad_prev,
        update_prev=update_prev,
        g_states=StatesSlicingWrapper(g_states),
        u_states=StatesSlicingWrapper(u_states),
        mse_loss=C(torch.zeros(1)),
        # updates=OptimizerBatchManager.placeholder(model.params.flat.size()),
    )

    for iteration in iter_pbar:
        inp, out = data.sample()
        model_loss = model(inp, out)
        model_losses += model_loss

        # if grad_real_cur is not None:  # if iteration % self.k == 0
        #     grad_cur_prev = grad_real_cur  # back up previous grad
        if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss:
            # get grad every step while training;
            # get grad one step earlier than teacher-forcing steps while testing
            model_dt = C(model_cls(params=model.params.detach()))
            model_loss_dt = model_dt(inp, out)
            # model.params.register_grad_cut_hook_()
            model_loss_dt.backward()
            # model.params.remove_grad_cut_hook_()
            grad_real_cur = model_dt.params.flat.grad
            assert grad_real_cur is not None
            del model_dt, model_loss_dt

        if should_train or eval_test_grad_loss:
            # grad_real_cur is the MSE target; it is not used as a target while testing
            assert grad_real_cur is not None
            batch_arg.update(
                BatchManagerArgument(
                    grad_real_cur=grad_real_cur).immutable().volatile())

        if iteration % self.k == 0 or (should_train and not train_with_lazy_tf):
            # teacher-forcing
            assert grad_real_prev is not None
            batch_arg.update(
                BatchManagerArgument(
                    grad_real_prev=grad_real_prev).immutable().volatile())
            grad_real_prev = None

        ########################################################################
        batch_manager = OptimizerBatchManager(batch_arg=batch_arg)
        for get, set in batch_manager.iterator():
            if iteration % self.k == 0 or (should_train and not train_with_lazy_tf):
                grad_prev = get.grad_real_prev.detach()  # teacher-forcing
            else:
                grad_prev = get.grad_prev.detach()  # free-running
            update_prev = get.update_prev.detach()
            grad_cur, update_cur, g_states, u_states = self(
                grad_prev, update_prev, get.g_states, get.u_states)
            set.grad_prev(grad_cur)
            set.update_prev(update_cur)
            set.g_states(g_states)
            set.u_states(u_states)
            set.params(get.params + update_cur * out_mul)
            if should_train or eval_test_grad_loss:
                set.mse_loss(self.mse_loss(grad_cur, get.grad_real_cur.detach()))
        ##########################################################################

        batch_arg = batch_manager.batch_arg_to_set
        grad_loss = batch_arg.mse_loss * 100
        grad_losses += grad_loss
        # to_float = lambda x: float(x.data.cpu().numpy())
        iter_pbar.set_description(
            'optim_iteration[loss(model):{:10.6f}/(grad):{:10.6f}]'.format(
                model_loss.tolist(), grad_loss.tolist()[0]))

        if should_train and iteration % unroll == 0:
            meta_optimizer.zero_grad()
            losses = model_losses + grad_losses
            losses.backward()
            meta_optimizer.step()

        if not should_train or iteration % unroll == 0:
            model_losses = 0
            grad_losses = 0
            batch_arg.detach_()
            # batch_arg.params.detach_()
            # batch_arg.g_states.detach_()
            # batch_arg.u_states.detach_()
            model = C(model_cls(params=batch_arg.params))

        result_dict.append(
            step_num=iteration,
            loss=model_loss,
            grad_loss=grad_loss,
        )
        if not should_train:
            result_dict.append(
                walltime=iter_watch.touch(),
                **grad_pred_tracker.track_num,
                **grad_pred_tracker(batch_arg.grad_prev),
                **grad_real_tracker(grad_real_cur),
                **update_tracker(batch_arg.update_prev),
            )

        if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss:
            grad_real_prev = grad_real_cur
            grad_real_cur = None

    return result_dict
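# The teacher-forcing schedule implemented by the two meta_optimize variants
# above, in brief:
#   - training with train_with_lazy_tf=False: the RNN input is always the true
#     gradient grad_real_prev (full teacher forcing).
#   - training with train_with_lazy_tf=True: teacher-force only every k-th
#     step; otherwise feed back the RNN's own prediction (free-running).
#   - testing: the true gradient is computed at steps where
#     (iteration + 1) % k == 0, one step before it is consumed as
#     grad_real_prev at iteration % k == 0.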
def meta_optimize(self, meta_optim, data, model_cls, optim_it, unroll,
                  out_mul, k_obsrv=None, no_mask=None, writer=None,
                  mode='train'):
    """FIX LATER: do something about the dummy arguments."""
    assert mode in ['train', 'valid', 'test']
    self.set_mode(mode)

    use_indexer = False
    params = C(model_cls()).params
    states = C(OptimizerStates.initial_zeros(
        size=(self.n_layers, len(params), self.hidden_sz),
        rnn_cell=self.rnn_cell))

    result_dict = ResultDict()
    unroll_losses = 0
    walltime = Walltime()
    update = C(torch.zeros(params.size().flat()))

    if use_indexer:
        batch_arg = BatchManagerArgument(
            params=params,
            states=StatesSlicingWrapper(states),
            updates=update,
        )

    iter_pbar = tqdm(range(1, optim_it + 1), 'inner_train')

    for it in iter_pbar:  # avoid shadowing the builtin `iter`
        with WalltimeChecker(walltime):
            model_train = C(model_cls(params=params.detach()))
            train_nll, train_acc = model_train(*data['in_train'].load())
            train_nll.backward()

            if model_train.params.flat.grad is not None:
                g = model_train.params.flat.grad
            else:
                g = model_train._grad2params(model_train.params)
            assert g is not None

            if use_indexer:
                # This indexer was originally for outer-level batch splitting,
                # in case the whole parameter set does not fit in VRAM. Avoid
                # it unless unavoidable: the serial computation adds extra
                # time cost.
                batch_arg.update(
                    BatchManagerArgument(grad=g).immutable().volatile())
                indexer = DefaultIndexer(input=g, shuffle=False)
                batch_manager = OptimizerBatchManager(
                    batch_arg=batch_arg, indexer=indexer)
                ######################################################################
                for get, set in batch_manager.iterator():
                    updates, new_states = self(get.grad().detach(), get.states())
                    set.states(new_states)
                    set.params(get.params() + updates * out_mul)
                    if not mode == 'train':
                        set.updates(updates)
                ######################################################################
                batch_arg = batch_manager.batch_arg_to_set
                params = batch_arg.params
            else:
                updates, states = self(g.detach(), states)
                updates = params.new_from_flat(updates)
                params = params + updates * out_mul
                # params = params - params.new_from_flat(g) * 0.1

        with WalltimeChecker(walltime if mode == 'train' else None):
            model_test = C(model_cls(params=params))
            test_nll, test_acc = model_test(*data['in_test'].load())
            if mode == 'train':
                unroll_losses += test_nll
                if it % unroll == 0:
                    meta_optim.zero_grad()
                    unroll_losses.backward()
                    nn.utils.clip_grad_value_(self.parameters(), 0.01)
                    meta_optim.step()
                    unroll_losses = 0

        with WalltimeChecker(walltime):
            if not mode == 'train' or it % unroll == 0:
                if use_indexer:
                    batch_arg.detach_()
                else:
                    params.detach_()
                    states.detach_()
                # model.params = model.params.detach()
                # states = states.detach()

        result = dict(
            train_nll=train_nll.tolist(),
            test_nll=test_nll.tolist(),
            train_acc=train_acc.tolist(),
            test_acc=test_acc.tolist(),
            walltime=walltime.time,
        )
        result_dict.append(result)
        log_pbar(result, iter_pbar)

    return result_dict, params
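# This variant clips meta-gradients by value (clip_grad_value_ at 0.01),
# whereas the neighboring variants clip by global norm (clip_grad_norm_ with
# 10 or 5). Value clipping bounds each coordinate independently, a common
# guard against exploding gradients through long unrolled inner loops.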
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it,
                  unroll, out_mul, should_train=True):
    data = data_cls(training=should_train)
    model = C(model_cls())
    optimizer_states = C(OptimizerStates.initial_zeros(
        size=(self.n_layers, len(model.params), self.hidden_sz),
        rnn_cell=self.rnn_cell))

    if should_train:
        self.train()
    else:
        self.eval()

    result_dict = ResultDict()
    walltime = 0
    unroll_losses = 0
    update = C(torch.zeros(model.params.size().flat()))

    batch_arg = BatchManagerArgument(
        params=model.params,
        states=StatesSlicingWrapper(optimizer_states),
        updates=update,
    )

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')

    for iteration in iter_pbar:
        iter_watch.touch()
        data_ = data.sample()

        if should_train:
            model = C(model_cls(params=batch_arg.params))
            loss = model(*data_)
            unroll_losses += loss

        model_detached = C(model_cls(params=batch_arg.params.detach()))
        loss_detached = model_detached(*data_)
        if should_train:
            assert loss == loss_detached
        loss_detached.backward()

        assert model_detached.params.flat.grad is not None
        grad = model_detached.params.flat.grad
        batch_arg.update(BatchManagerArgument(grad=grad).immutable().volatile())
        iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]')

        ##########################################################################
        # indexer = DefaultIndexer(input=grad, shuffle=False)
        indexer = TopKIndexer(input=grad, k_ratio=0.1, random_k_mode=False)
        # bound=model_dt.params.size().bound_layerwise())
        # indexer = DefaultIndexer(input=grad)
        batch_manager = OptimizerBatchManager(batch_arg=batch_arg, indexer=indexer)

        for get, set in batch_manager.iterator():
            updates, new_states = self(get.grad().detach(), get.states())
            updates = updates * out_mul
            set.states(new_states)
            set.params(get.params() + updates)
            if not should_train:
                set.updates(updates)
        ##########################################################################

        batch_arg = batch_manager.batch_arg_to_set

        if should_train and iteration % unroll == 0:
            meta_optimizer.zero_grad()
            unroll_losses.backward()
            nn.utils.clip_grad_norm_(self.parameters(), 10)
            meta_optimizer.step()
            unroll_losses = 0

        if not should_train or iteration % unroll == 0:
            batch_arg.detach_()
            # model.params = model.params.detach()
            # optimizer_states = optimizer_states.detach()

        walltime += iter_watch.touch('interval')
        result_dict.append(loss=loss_detached)
        if not should_train:
            result_dict.append(
                walltime=iter_watch.touch(),
                **self.params_tracker(
                    grad=grad,
                    update=batch_arg.updates,
                ))
    return result_dict
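# A minimal sketch of the coordinate selection TopKIndexer is assumed to
# perform above (k_ratio=0.1: update only the 10% of coordinates with the
# largest gradient magnitude). The name below is hypothetical:
def _topk_coords_sketch(grad_flat, k_ratio=0.1):
    """Return indices of the top-k coordinates by |grad|."""
    k = max(1, int(grad_flat.numel() * k_ratio))
    return grad_flat.abs().view(-1).topk(k).indices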
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll,
                  out_mul, mode):
    assert mode in ['train', 'valid', 'test']
    if mode == 'train':
        self.train()
        # data.new_train_data()
        inner_data = data.loaders['inner_train']
        outer_data = data.loaders['inner_valid']
        # inner_data = data.loaders['train']
        # outer_data = inner_data
        drop_mode = 'soft_drop'
    elif mode == 'valid':
        self.eval()
        inner_data = data.loaders['valid']
        outer_data = None
        drop_mode = 'hard_drop'
    elif mode == 'test':
        self.eval()
        inner_data = data.loaders['test']
        outer_data = None
        drop_mode = 'hard_drop'

    # data = dataset(mode=mode)
    model = C(model_cls(sb_mode=self.sb_mode))
    model(*data.pseudo_sample())
    # layer_sz = sum([v for v in model.params.size().unflat(1, 'mat').values()])
    mask_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.activations.mean(0)), self.hidden_sz)))
    update_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.params), self.hidden_sz)))

    result_dict = ResultDict()
    unroll_losses = 0
    walltime = 0
    update = C(torch.zeros(model.params.size().flat()))

    batch_arg = BatchManagerArgument(
        params=model.params,
        states=StatesSlicingWrapper(update_states),
        updates=model.params.new_zeros(model.params.size()),
    )

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')
    bp_watch = utils.StopWatch('bp')
    loss_decay = 1.0

    for iteration in iter_pbar:
        iter_watch.touch()
        model_detached = C(model_cls(params=batch_arg.params.detach(),
                                     sb_mode=self.sb_mode))
        loss_detached = model_detached(*inner_data.load())
        loss_detached.backward()
        iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]')
        # np.set_printoptions(suppress=True, precision=3, linewidth=200, edgeitems=20)

        if debug_sigint.signal_on:
            debug = iteration == 1 or iteration % 10 == 0
            # debug = True
        else:
            debug = False

        backprop_mask, mask_states, sparsity_loss = self.mask_generator(
            activations=model_detached.activations.grad.detach(),
            debug=debug,
            states=mask_states)

        # if debug_sigstp.signal_on and (iteration == 1 or iteration % 10 == 0):
        #     mask_list.append(backprop_mask)
        # min = 0.001
        # max = 0.01
        unroll_losses += sparsity_loss * 0.01 * loss_decay
        # * ((max - min) * lambda_ + min)

        # model_detached.apply_backprop_mask(backprop_mask.unflat, drop_mode)
        # loss_detached.backward()
        assert model_detached.params.grad is not None

        if mode == 'train':
            model = C(model_cls(params=batch_arg.params, sb_mode=self.sb_mode))
            loss = model(*outer_data.load())
            # assert loss == loss_detached
            if iteration == 1:
                initial_loss = loss
                loss_decay = 1.0
            else:
                loss_decay = (loss / initial_loss).detach()
            unroll_losses += loss

        # model_detached.apply_backprop_mask(backprop_mask.unflat, drop_mode)
        # model_detached.params.apply_grad_mask(mask, 'soft_drop')
        coord_mask = backprop_mask.expand_as_params(model_detached.params)

        if drop_mode == 'soft_drop':
            # grad = model_detached.params.grad.mul(coord_mask).flat
            # model_detached.params = model_detached.params.mul(params_mask)
            index = None
            # updates, states = self.updater(grad, states, mask=coord_mask)
            # updates = updates * out_mul * coord_mask
        elif drop_mode == 'hard_drop':
            index = (coord_mask.flat.squeeze() > 0.5).nonzero().squeeze()

        grad = model_detached.params.grad.flat
        batch_arg.update(BatchManagerArgument(grad=grad, mask=coord_mask))

        ##########################################################################
        indexer = DefaultIndexer(input=grad, index=index, shuffle=False)
        batch_manager = OptimizerBatchManager(batch_arg=batch_arg,
                                              indexer=indexer)

        for get, set in batch_manager.iterator():
            inp = get.grad().detach()
            # inp = [get.grad().detach(), get.tag().detach()]
            if drop_mode == 'soft_drop':
                updates, new_states = self.updater(inp, get.states())  # , mask=get.mask())
                updates = updates * out_mul * get.mask()
                set.states(get.states() * (1 - get.mask()) + new_states * get.mask())
                # set.states(new_states)
            elif drop_mode == 'hard_drop':
                updates, new_states = self.updater(inp, get.states())
                updates = updates * out_mul
                set.states(new_states)
            else:
                raise RuntimeError('Unknown drop mode!')
            set.params(get.params() + updates)
            # if not mode != 'train':
            #     set.updates(updates)
        ##########################################################################

        batch_arg = batch_manager.batch_arg_to_set
        # updates = model.params.new_from_flat(updates)

        if mode == 'train' and iteration % unroll == 0:
            meta_optimizer.zero_grad()
            unroll_losses.backward()
            nn.utils.clip_grad_norm_(self.parameters(), 5)
            meta_optimizer.step()
            unroll_losses = 0

        if not mode == 'train' or iteration % unroll == 0:
            mask_states.detach_()
            batch_arg.detach_()
            # model.params = model.params.detach()
            # update_states = update_states.detach()

        walltime += iter_watch.touch('interval')
        result_dict.append(loss=loss_detached)
        if not mode == 'train':
            result_dict.append(
                walltime=walltime,
                # **self.params_tracker(
                #     grad=grad,
                #     update=batch_arg.updates,
                # )
            )
    return result_dict
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it,
                  unroll, out_mul, should_train=True):
    target = data_cls(training=should_train)
    optimizee = C(model_cls())
    optimizer_states = OptimizerStates.zero_states(
        self.n_layers, len(optimizee.params), self.hidden_sz)

    if should_train:
        self.train()
    else:
        self.eval()

    coordinate_tracker = ParamsIndexTracker(
        n_params=len(optimizee.params), n_tracks=30)
    result_dict = ResultDict()
    unroll_losses = 0

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')

    for iteration in iter_pbar:
        loss = optimizee(target)
        unroll_losses += loss
        loss.backward(retain_graph=should_train)
        iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

        unit_size = torch.Size([len(optimizee.params), 1])
        grid_size = torch.Size([len(optimizee.params), self.n_coord])
        batch_manager = OptimizerBatchManager(
            params=optimizee.params,
            states=StatesSlicingWrapper(optimizer_states),
            units=OptimizerBatchManager.placeholder(unit_size),
            grid_abs=OptimizerBatchManager.placeholder(grid_size),
            grid_rel=OptimizerBatchManager.placeholder(grid_size),
            batch_size=30000,
        )

        ##########################################################################
        for get, set in batch_manager.iterator():
            units, new_states = self.unit_maker(
                C.detach(get.params.grad), get.states.clone())
            units = units * out_mul / self.n_coord / 2
            grid_abs, grid_rel = self.grid_maker(units, get.params)
            set.states(new_states)
            set.units(units)
            set.grid_abs(grid_abs)
            set.grid_rel(grid_rel)
        ##########################################################################

        landscape = self.loss_observer(batch_manager.grid_abs, model_cls, target)
        updates = self.step_maker(batch_manager.grid_rel, landscape)
        optimizee.params.add_flat_(updates)

        if should_train and iteration % unroll == 0:
            meta_optimizer.zero_grad()
            unroll_losses.backward()
            meta_optimizer.step()

        if not should_train or iteration % unroll == 0:
            unroll_losses = 0
            optimizee.params = optimizee.params.detach()
            optimizer_states = optimizer_states.detach()
            optimizee = C(model_cls(optimizee.params))

        result_dict.append(loss=loss)
        if not should_train:
            result_dict.append(
                grid=coordinate_tracker(batch_manager.units),
                update=coordinate_tracker(updates),
                walltime=iter_watch.touch(),
            )
    return result_dict
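# Pipeline in the variant above: unit_maker proposes a per-parameter step
# scale from the gradient, grid_maker expands it into n_coord candidate
# offsets (grid_rel) and absolute points (grid_abs), loss_observer evaluates
# the loss landscape at grid_abs, and step_maker picks the final update from
# (grid_rel, landscape).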
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it,
                  unroll, out_mul, should_train=True):
    data = data_cls(training=should_train)
    model = C(model_cls())
    states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.params), self.hidden_sz)))

    if should_train:
        self.train()
    else:
        self.eval()

    result_dict = ResultDict()
    model_dt = None
    model_losses = 0
    # grad_real_prev = None

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')

    batch_arg = BatchManagerArgument(
        params=model.params,
        states=StatesSlicingWrapper(states),
        update_prev=C(torch.zeros(model.params.size().flat())),
        mu=C(torch.zeros(model.params.size().flat())),
        sigma=C(torch.zeros(model.params.size().flat())),
    )

    for iteration in iter_pbar:
        data_ = data.sample()
        model_loss = model(*data_)
        model_losses += model_loss
        # if iteration == 21:
        #     import pdb; pdb.set_trace()
        try:
            make_dot(model_losses)
        except Exception:
            import pdb
            pdb.set_trace()
        # if iteration == 1 or iteration == 21:
        #     import pdb; pdb.set_trace()

        model_dt = C(model_cls(params=model.params.detach()))
        model_loss_dt = model_dt(*data_)
        # model.params.register_grad_cut_hook_()
        model_loss_dt.backward()
        assert model_dt.params.flat.grad is not None
        grad_cur = model_dt.params.flat.grad
        # model.params.remove_grad_cut_hook_()
        batch_arg.update(BatchManagerArgument(grad_cur=grad_cur))
        # NOTE: read_only, write_only

        ########################################################################
        batch_manager = OptimizerBatchManager(batch_arg=batch_arg)
        for get, set in batch_manager.iterator():
            grad_cur = get.grad_cur.detach()  # free-running
            update_prev = get.update_prev.detach()
            points_rel, states = self.point_sampler(
                grad_cur, update_prev, get.states, self.n_points)
            points_rel = points_rel * out_mul
            points_abs = get.params.detach() + points_rel  # NOTE: check detach
            losses = self.loss_observer(points_abs, model_cls, data_)
            update = self.point_selector(points_rel, losses)
            set.params(get.params + update)
            set.states(states)
            set.update_prev(update)
            set.mu(self.point_sampler.mu)
            set.sigma(self.point_sampler.sigma)
        ##########################################################################

        batch_arg = batch_manager.batch_arg_to_set
        iter_pbar.set_description(
            'optim_iteration[loss:{:10.6f}]'.format(model_loss.tolist()))

        if should_train and iteration % unroll == 0:
            meta_optimizer.zero_grad()
            try:
                model_losses.backward()
            except Exception:
                import pdb
                pdb.set_trace()
            meta_optimizer.step()

        if not should_train or iteration % unroll == 0:
            model_losses = 0
            batch_arg.detach_()
            # batch_arg.params.detach_()
            # batch_arg.states.detach_()  # NOTE: just do batch_arg.detach_()
            model = C(model_cls(params=batch_arg.params))

        result_dict.append(
            step_num=iteration,
            loss=model_loss,
        )
        if not should_train:
            result_dict.append(
                walltime=iter_watch.touch(),
                **self.params_tracker(
                    grad=grad_cur,
                    update=batch_arg.update_prev,
                    mu=batch_arg.mu,
                    sigma=batch_arg.sigma,
                ))
    return result_dict
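# Sample-evaluate-select loop in the variant above: point_sampler draws
# n_points candidate relative updates from a learned Gaussian (mu, sigma),
# loss_observer evaluates the model loss at each candidate point, and
# point_selector turns (candidates, losses) into the final update.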