def model_analyzer(optim, mode, model_train, params, model_cls, set_size, data,
                   iter, optim_it, analyze_mask=False, sample_mask=False,
                   draw_loss=False):
  text_dir = 'analyze_model/analyze_mask'
  result_dir = 'analyze_model/drawloss'
  sample_dir = 'analyze_model/mask_compare'
  mask_dict = ResultDict()
  iter_interval = 10
  sample_num = 10000
  is_mask_dict = False

  if mode == 'test' and analyze_mask:
    analyzing_mask(optim.mask_gen, set_size, mode, iter, iter_interval,
                   text_dir)

  if mode == 'test' and sample_mask:
    mask_result = sampling_mask(optim.mask_gen, set_size, model_train, params,
                                sample_num, mode, iter, iter_interval,
                                sample_dir)
    if mask_result is not None:
      mask_dict.append(mask_result)
      is_mask_dict = True

  if mode == 'test' and draw_loss:
    plot_loss(model_cls=model_cls, model=model_train, params=params,
              input_data=data['in_train'].load(), dataset=data['in_train'],
              feature_gen=optim.feature_gen, mask_gen=optim.mask_gen,
              step_gen=optim.step_gen, scale_way=None, xmin=-2.0, xmax=0.5,
              num_x=20, mode=mode, iteration=iter, iter_interval=iter_interval,
              loss_dir=result_dir)

  if sample_mask and is_mask_dict:
    plot_mask_result(mask_dict, sample_dir)
  return
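# ----------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): model_analyzer is
# meant to be called from inside a test-time inner loop, mirroring the calls
# made by the meta_optimize variants below.  `optim` is the learned-optimizer
# instance that owns mask_gen / step_gen / feature_gen; the other names match
# the arguments above.
#
#   for iter in range(1, optim_it + 1):
#       ...one inner update...
#       model_analyzer(optim, 'test', model_train, params, model_cls,
#                      mask.tsize(0), data, iter, optim_it,
#                      analyze_mask=True, sample_mask=True, draw_loss=False)
# ----------------------------------------------------------------------------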
def meta_optimize(self, cfg, meta_optim, data, model_cls, writer=None, mode='train'): assert mode in ['train', 'valid', 'test'] self.set_mode(mode) # mask_mode = ['no_mask', 'structured', 'unstructured'][2] # data_mode = ['in_train', 'in_test'][0] ############################################################################ analyze_model = False analyze_surface = False ############################################################################ result_dict = ResultDict() unroll_losses = 0 walltime = Walltime() test_kld = torch.tensor(0.) params = C(model_cls()).params self.feature_gen.new() self.step_gen.new() sparse_r = {} # sparsity iter_pbar = tqdm(range(1, cfg['iter_' + mode] + 1), 'Inner_loop') for iter in iter_pbar: debug_1 = sigint.is_active(iter == 1 or iter % 10 == 0) debug_2 = sigstp.is_active() with WalltimeChecker(walltime): model_train = C(model_cls(params.detach())) data_train = data['in_train'].load() train_nll, train_acc = model_train(*data_train) train_nll.backward() grad = model_train.params.grad.detach() # g = model_train.params.grad.flat.detach() # w = model_train.params.flat.detach() # step & mask genration feature, v_sqrt = self.feature_gen(grad.flat.detach()) # step = self.step_gen(feature, v_sqrt, debug=debug_1) # step = params.new_from_flat(step[0]) size = params.size().unflat() if cfg.mask_mode == 'structured': mask = self.mask_gen(feature, size, debug=debug_1) mask = ParamsFlattener(mask) mask_layout = mask.expand_as(params) params = params + grad.detach() * mask_layout * ( -cfg.inner_lr) elif cfg.mask_mode == 'unstructured': mask_flat = self.mask_gen.unstructured(feature, size) mask = params.new_from_flat(mask_flat) params = params + grad.detach() * mask * (-cfg.inner_lr) # update = params.new_from_flat(params.flat + grad.flat.detach() * mask * (-cfg.lr)) # params = params + update elif cfg.mask_mode == 'no_mask': params = params + grad.detach() * (-cfg.inner_lr) else: raise Exception('Unknown setting!') # import pdb; pdb.set_trace() # step_masked = step * mask_layout # params = params + step_masked with WalltimeChecker(walltime if mode == 'train' else None): model_test = C(model_cls(params)) if cfg.data_mode == 'in_train': data_test = data_train elif cfg.data_mode == 'in_test': data_test = data['in_test'].load() test_nll, test_acc = utils.isnan(*model_test(*data_test)) if debug_2: pdb.set_trace() if mode == 'train': unroll_losses += test_nll # + test_kld if iter % cfg.unroll == 0: meta_optim.zero_grad() unroll_losses.backward() nn.utils.clip_grad_value_(self.parameters(), 0.01) meta_optim.step() unroll_losses = 0 with WalltimeChecker(walltime): if not mode == 'train' or iter % cfg.unroll == 0: params = params.detach_() ########################################################################## if analyze_model: analyzers.model_analyzer(self, mode, model_train, params, model_cls, mask.tsize(0), data, iter, optim_it, analyze_mask=True, sample_mask=True, draw_loss=False) if analyze_surface: analyzers.surface_analyzer(params, best_mask, step, writer, iter) ########################################################################## result = dict( train_nll=train_nll.tolist(), test_nll=test_nll.tolist(), train_acc=train_acc.tolist(), test_acc=test_acc.tolist(), test_kld=test_kld.tolist(), walltime=walltime.time, ) if not cfg.mask_mode == 'no_mask': result.update( **mask.sparsity(overall=True), **mask.sparsity(overall=False), ) result_dict.append(result) log_pbar(result, iter_pbar) return result_dict, params
def meta_optimize(self, meta_optim, data, model_cls, optim_it, unroll, out_mul,
                  k_obsrv=0, no_mask=False, writer=None, mode='train'):
  assert mode in ['train', 'valid', 'test']
  self.set_mode(mode)

  ############################################################################
  analyze_model = False
  analyze_surface = False
  ############################################################################

  result_dict = ResultDict()
  unroll_losses = 0
  walltime = Walltime()
  test_kld = torch.tensor(0.)

  params = C(model_cls()).params
  self.feature_gen.new()
  self.step_gen.new()
  sparse_r = {}  # sparsity

  iter_pbar = tqdm(range(1, optim_it + 1), 'Inner_loop')

  for iter in iter_pbar:
    debug_1 = sigint.is_active(iter == 1 or iter % 10 == 0)
    debug_2 = sigstp.is_active()

    with WalltimeChecker(walltime):
      model_train = C(model_cls(params.detach()))
      data_ = data['in_train'].load()
      train_nll, train_acc = model_train(*data_)
      train_nll.backward()

      g = model_train.params.grad.flat.detach()
      w = model_train.params.flat.detach()

      # step & mask generation
      feature, v_sqrt = self.feature_gen(g)
      step = self.step_gen(feature, v_sqrt, debug=debug_1)
      step = params.new_from_flat(step[0])
      size = params.size().unflat()

      if no_mask:
        params = params + step
      else:
        kld = self.mask_gen(feature, size, debug=debug_1)
        test_kld = kld / data['in_test'].full_size / unroll
        # kl annealing function: 'linear' / 'logistic' / None
        test_kld2 = test_kld * kl_anneal_function(
            anneal_function=None, step=iter, k=0.0025, x0=optim_it)
        mask = self.mask_gen.sample_mask()
        mask = ParamsFlattener(mask)
        mask_layout = mask.expand_as(params)
        # import pdb; pdb.set_trace()  # leftover debug trap (disabled)
        step_masked = step * mask_layout
        params = params + step_masked

    with WalltimeChecker(walltime if mode == 'train' else None):
      model_test = C(model_cls(params))
      test_nll, test_acc = utils.isnan(*model_test(*data['in_test'].load()))
      if debug_2:
        pdb.set_trace()

      if mode == 'train':
        unroll_losses += test_nll  # + test_kld
        if iter % unroll == 0:
          meta_optim.zero_grad()
          unroll_losses.backward()
          nn.utils.clip_grad_value_(self.parameters(), 0.01)
          meta_optim.step()
          unroll_losses = 0

    with WalltimeChecker(walltime):
      if not mode == 'train' or iter % unroll == 0:
        params = params.detach_()

    ##########################################################################
    if analyze_model:
      analyzers.model_analyzer(
          self, mode, model_train, params, model_cls, mask.tsize(0), data,
          iter, optim_it, analyze_mask=True, sample_mask=True, draw_loss=False)
    if analyze_surface:
      analyzers.surface_analyzer(params, best_mask, step, writer, iter)
    ##########################################################################

    result = dict(
        train_nll=train_nll.tolist(),
        test_nll=test_nll.tolist(),
        train_acc=train_acc.tolist(),
        test_acc=test_acc.tolist(),
        test_kld=test_kld.tolist(),
        walltime=walltime.time,
    )
    if no_mask is False:
      result.update(
          **mask.sparsity(overall=True),
          **mask.sparsity(overall=False),
      )
    result_dict.append(result)
    log_pbar(result, iter_pbar)

  return result_dict, params
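# ----------------------------------------------------------------------------
# Hedged sketch of the kl_anneal_function referenced above (the actual helper
# lives elsewhere in the repo; this is only an assumption consistent with the
# call kl_anneal_function(anneal_function=None, step=iter, k=0.0025,
# x0=optim_it)).  It returns a weight in [0, 1] that scales the KL term.
import math


def kl_anneal_function(anneal_function, step, k, x0):
  """Assumed KL-annealing schedule: 'logistic' / 'linear' / None."""
  if anneal_function == 'logistic':
    return float(1 / (1 + math.exp(-k * (step - x0))))
  elif anneal_function == 'linear':
    return min(1.0, step / x0)
  else:  # None: no annealing, full KL weight
    return 1.0
# ----------------------------------------------------------------------------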
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True): data = data_cls(training=should_train) model = C(model_cls()) g_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.params), self.hidden_sz))) if should_train: self.train() else: self.eval() result_dict = ResultDict() model_dt = None grad_losses = 0 train_with_lazy_tf = True eval_test_grad_loss = True grad_prev = C(torch.zeros(model.params.size().flat())) update_prev = C(torch.zeros(model.params.size().flat())) #grad_real_prev = None grad_real_cur = None grad_real_prev = C(torch.zeros(model.params.size().flat())) iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') """grad_real_prev: teacher-forcing current input of RNN. required during all of training steps + every 'k' test steps. grad_real_cur: MSE loss target for current output of RNN. required during all the training steps only. """ batch_arg = BatchManagerArgument( params=model.params, grad_prev=grad_prev, update_prev=update_prev, g_states=StatesSlicingWrapper(g_states), mse_loss=C(torch.zeros(1)), sigma=C(torch.zeros(model.params.size().flat())), #updates=OptimizerBatchManager.placeholder(model.params.flat.size()), ) for iteration in iter_pbar: data_ = data.sample() # if grad_real_cur is not None: # if iteration % self.k == 0 # grad_cur_prev = grad_real_cur # back up previous grad if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss: # get grad every step while training # get grad one step earlier than teacher-forcing steps while testing model_dt = C(model_cls(params=model.params.detach())) model_loss_dt = model_dt(*data_) # model.params.register_grad_cut_hook_() model_loss_dt.backward() # model.params.remove_grad_cut_hook_() grad_real_cur = model_dt.params.flat.grad assert grad_real_cur is not None if should_train or eval_test_grad_loss: # mse loss target # grad_real_cur is not used as a mse target while testing assert grad_real_cur is not None batch_arg.update( BatchManagerArgument( grad_real_cur=grad_real_cur).immutable()) if iteration % self.k == 0 or (should_train and not train_with_lazy_tf): # teacher-forcing assert grad_real_prev is not None batch_arg.update( BatchManagerArgument( grad_real_prev=grad_real_prev).immutable().volatile()) grad_real_prev = None ######################################################################## batch_manager = OptimizerBatchManager(batch_arg=batch_arg) for get, set in batch_manager.iterator(): if iteration % self.k == 0 or (should_train and not train_with_lazy_tf): grad_prev = get.grad_real_prev # teacher-forcing else: grad_prev = get.grad_prev # free-running grad_cur, g_states = self(grad_prev.detach(), get.update_prev.detach(), get.params.detach(), get.g_states) set.grad_prev(grad_cur) set.g_states(g_states) set.params(get.params - 1.0 * grad_cur.detach()) set.sigma(self.sigma) if should_train or eval_test_grad_loss: set.mse_loss( self.mse_loss(grad_cur, get.grad_real_cur.detach())) ########################################################################## batch_arg = batch_manager.batch_arg_to_set grad_real = batch_manager.batch_arg_to_get.grad_real_cur.detach() cosim_loss = nn.functional.cosine_similarity(batch_arg.grad_prev, grad_real, dim=0) grad_loss = batch_arg.mse_loss * 100 cosim_loss = -torch.log((cosim_loss + 1) * 0.5) * 100 grad_losses += (grad_loss + cosim_loss) #to_float = lambda x: float(x.data.cpu().numpy()) sigma = batch_arg.sigma.detach().mean() iter_pbar.set_description( 
'optim_iteration[loss(model):{:10.6f}/(mse):{:10.6f}/(cosim):{:10.6f}/(sigma):{:10.6f}]' ''.format(model_loss_dt.tolist(), grad_loss.tolist()[0], cosim_loss.tolist()[0], sigma.tolist())) if should_train and iteration % unroll == 0: w_decay_loss = 0 for p in self.parameters(): w_decay_loss += torch.sum(p * p) w_decay_loss = (w_decay_loss * 0.5) * 1e-4 total_loss = grad_losses + w_decay_loss meta_optimizer.zero_grad() total_loss.backward() meta_optimizer.step() if not should_train or iteration % unroll == 0: model_losses = 0 grad_losses = 0 batch_arg.detach_() # batch_arg.params.detach_() # batch_arg.g_states.detach_() # batch_arg.u_states.detach_() result_dict.append( step_num=iteration, loss=model_loss_dt, grad_loss=grad_loss, ) if not should_train: result_dict.append(walltime=iter_watch.touch(), **self.params_tracker( grad_real=grad_real_cur, grad_pred=batch_arg.grad_prev, update=batch_arg.update_prev, sigma=batch_arg.sigma, )) if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss: grad_real_prev = grad_real_cur grad_real_cur = None return result_dict
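# ----------------------------------------------------------------------------
# Illustrative note (assumption, not project code): the cosine term above maps
# cosine similarity in [-1, 1] to a non-negative loss via
# -log((cos + 1) * 0.5) * 100, so perfectly aligned gradients (cos = 1) give 0
# and orthogonal ones (cos = 0) give roughly 69.3.
import torch

for cos in (1.0, 0.5, 0.0):
  print(cos, (-torch.log((torch.tensor(cos) + 1) * 0.5) * 100).item())
# ----------------------------------------------------------------------------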
def meta_optimize(self, meta_optim, data, model_cls, optim_it, unroll, out_mul, k_obsrv=1, no_mask=False, writer=None, mode='train'): assert mode in ['train', 'valid', 'test'] if no_mask is True: raise Exception( "this module currently does NOT suport no_mask option") self.set_mode(mode) ############################################################################ n_samples = k_obsrv """MSG: better postfix?""" analyze_model = False analyze_surface = False ############################################################################ if analyze_surface: writer.new_subdirs('best', 'any', 'rand', 'inv', 'dense') result_dict = ResultDict() unroll_losses = 0 walltime = Walltime() test_kld = torch.tensor(0.) params = C(model_cls()).params self.feature_gen.new() self.step_gen.new() iter_pbar = tqdm(range(1, optim_it + 1), 'Inner_loop') set_size = {'layer_0': 500, 'layer_1': 10} # NOTE: make it smarter for iter in iter_pbar: debug_1 = sigint.is_active(iter == 1 or iter % 10 == 0) debug_2 = sigstp.is_active() best_loss = 9999999 best_params = None with WalltimeChecker(walltime): model_train = C(model_cls(params.detach())) train_nll, train_acc = model_train(*data['in_train'].load()) train_nll.backward() g = model_train.params.grad.flat.detach() w = model_train.params.flat.detach() feature, v_sqrt = self.feature_gen(g) size = params.size().unflat() kld = self.mask_gen(feature, size) losses = [] lips = [] valid_mask_patience = 100 assert n_samples > 0 """FIX LATER: when n_samples == 0 it can behave like no_mask flag is on.""" for i in range(n_samples): # step & mask genration for j in range(valid_mask_patience): mask = self.mask_gen.sample_mask() mask = ParamsFlattener(mask) if mask.is_valid_sparsity(): if j > 0: print( f'\n\n[!]Resampled {j + 1} times to get valid mask!' 
) break if j == valid_mask_patience - 1: raise Exception( "[!]Could not sample valid mask for " f"{j+1} trials.") step_out = self.step_gen(feature, v_sqrt, debug=debug_1) step = params.new_from_flat(step_out[0]) mask_layout = mask.expand_as(params) step_sparse = step * mask_layout params_sparse = params + step_sparse params_pruned = params_sparse.prune(mask > 0.5) if params_pruned.size().unflat()['mat_0'][1] == 0: continue # cand_loss = model(*outer_data_s) sparse_model = C(model_cls(params_pruned.detach())) loss, _ = sparse_model(*data['in_train'].load()) if (loss < best_loss) or i == 0: best_loss = loss best_params = params_sparse best_pruned = params_pruned best_mask = mask if best_params is not None: params = best_params with WalltimeChecker(walltime if mode == 'train' else None): model_test = C(model_cls(params)) test_nll, test_acc = utils.isnan(*model_test( *data['in_test'].load())) test_kld = kld / data['in_test'].full_size / unroll ## kl annealing function 'linear' / 'logistic' / None test_kld2 = test_kld * kl_anneal_function( anneal_function=None, step=iter, k=0.0025, x0=optim_it) total_test = test_nll + test_kld2 if mode == 'train': unroll_losses += total_test if iter % unroll == 0: meta_optim.zero_grad() unroll_losses.backward() nn.utils.clip_grad_value_(self.parameters(), 0.01) meta_optim.step() unroll_losses = 0 with WalltimeChecker(walltime): if not mode == 'train' or iter % unroll == 0: params = params.detach_() ########################################################################## """Analyzers""" if analyze_model: analyzers.model_analyzer(self, mode, model_train, params, model_cls, set_size, data, iter, optim_it, analyze_mask=True, sample_mask=True, draw_loss=False) if analyze_surface: analyzers.surface_analyzer(params, best_mask, step, writer, iter) ########################################################################## result = dict( train_nll=train_nll.tolist(), test_nll=test_nll.tolist(), train_acc=train_acc.tolist(), test_acc=test_acc.tolist(), test_kld=test_kld.tolist(), walltime=walltime.time, **best_mask.sparsity(overall=True), **best_mask.sparsity(overall=False), ) result_dict.append(result) log_pbar(result, iter_pbar) return result_dict, params
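# ----------------------------------------------------------------------------
# Hedged driver sketch (hypothetical; the real training script lives
# elsewhere): how a meta_optimize variant above might be driven for one
# meta-epoch.  `optimizer` is an instance of the class defining meta_optimize;
# `data` and `model_cls` follow the interfaces assumed by the code above, and
# the hyperparameter values are placeholders, not the project's settings.
#
#   meta_optim = torch.optim.Adam(optimizer.parameters(), lr=1e-3)
#   result_train, _ = optimizer.meta_optimize(
#       meta_optim, data, model_cls, optim_it=100, unroll=20, out_mul=1.0,
#       k_obsrv=1, mode='train')
#   result_valid, _ = optimizer.meta_optimize(
#       meta_optim, data, model_cls, optim_it=100, unroll=20, out_mul=1.0,
#       k_obsrv=1, mode='valid')
# ----------------------------------------------------------------------------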
def loop(mode, data, outer_steps, inner_steps, log_steps, fig_epochs, inner_lr, log_mask=True, unroll_steps=None, meta_batchsize=0, sampler=None, epoch=1, outer_optim=None, save_path=None, is_RL=False): ##################################################################### """Args: meta_batchsize(int): If meta_batchsize |m| > 0, gradients for multiple unrollings from each episodes size of |m| will be accumulated in sequence but updated all at once. (Which can be done in parallel when VRAM is large enough, but will be simulated in this code.) If meta_batchsize |m| = 0(default), then update will be performed after each unrollings. """ assert mode in ['train', 'valid', 'test'] assert meta_batchsize >= 0 print(f'Start_of_{mode}.') if mode == 'train' and unroll_steps is None: raise Exception("unroll_steps has to be specied when mode='train'.") if mode != 'train' and unroll_steps is not None: raise Warning("unroll_steps has no effect when mode mode!='train'.") train = True if mode == 'train' else False force_base = True # TODO: to gin configuration fc_pulling = False # does not support for now class_balanced = False if class_balanced: classes_per_episode = 10 samples_per_class = 3 else: batch_size = 128 sample_split_ratio = 0.5 # sample_split_ratio = None anneal_outer_steps = 50 concrete_resample = True detach_param = False split_method = {1: 'inclusive', 2: 'exclusive'}[1] sampler_type = {1: 'pre_sampler', 2: 'post_sampler'}[2] ##################################################################### mask_unit = { 0: MaskUnit.SAMPLE, 1: MaskUnit.CLASS, }[0] mask_dist = { 0: MaskDist.SOFT, 1: MaskDist.DISCRETE, 2: MaskDist.CONCRETE, 3: MaskDist.RL, }[3 if is_RL else 2] mask_mode = MaskMode(mask_unit, mask_dist) ##################################################################### # scheduler inner_step_scheduler = InnerStepScheduler( outer_steps, inner_steps, anneal_outer_steps) # 100 classes in total if split_method == 'exclusive': # meta_support, remainder = data.split_classes(0.1) # 10 classes # meta_query = remainder.sample_classes(50) # 50 classes meta_support, meta_query = data.split_classes(1 / 5) # 1(100) : 4(400) # meta_support, meta_query = data.split_classes(0.3) # 30 : 70 classes elif split_method == 'inclusive': # subdata = data.sample_classes(10) # 50 classes meta_support, meta_query = data.split_instances(0.5) # 5:5 instances else: raise Exception() if train: if meta_batchsize > 0: # serial processing of meta-minibatch update_epochs = meta_batchsize update_steps = None else: # update at every unrollings update_steps = unroll_steps update_epochs = None assert (update_epochs is None) != (update_steps is None) # for result recordin result_frame = ResultFrame() if save_path: writer = SummaryWriter(os.path.join(save_path, 'tfevent')) ################################## if is_RL: env = Environment() policy = Policy(sampler) neg_rw = None ################################## for i in range(1, outer_steps + 1): outer_loss = 0 result_dict = ResultDict() # initialize sampler sampler.initialize() # initialize base learner model_q = Model(len(meta_support), mode='metric') if fc_pulling: model_s = Model(len(meta_support), mode='fc') params = C(model_s.get_init_params('ours')) else: model_s = model_q params = C(model_s.get_init_params('ours')) ################################## if is_RL: policy.reset() # env.reset() ################################## if not train or force_base: """baseline 0: naive single task learning baseline 1: single task learning with the same loss scale """ # baseline parameters 
params_b0 = C(params.copy('b0'), 4) params_b1 = C(params.copy('b1'), 4) if class_balanced: batch_sampler = BalancedBatchSampler.get(class_size, sample_size) else: batch_sampler = RandomBatchSampler.get(batch_size) # episode iterator episode_iterator = MetaEpisodeIterator( meta_support=meta_support, meta_query=meta_query, batch_sampler=batch_sampler, match_inner_classes=match_inner_classes, inner_steps=inner_step_scheduler(i, verbose=True), sample_split_ratio=sample_split_ratio, num_workers=4, pin_memory=True, ) do_sample = True for k, (meta_s, meta_q) in episode_iterator(do_sample): outs = [] meta_s = meta_s.cuda() ##################################################################### if do_sample: do_sample = False # task encoding (very first step and right after a meta-update) with torch.set_grad_enabled(train): def feature_fn(): # determein whether to untilize model features if sampler_type == 'post_sampler': return model_s(meta_s, params) elif sampler_type == 'pre_sampler': return None # sampler do the work! mask, lr_ = sampler(mask_mode, feature_fn) # out_for_sampler = model_s(meta_s, params, debug=do_sample) # mask_, lr = sampler( # pairwise_dist=out_for_sampler.pairwise_dist, # classwise_loss=out_for_sampler.loss, # classwise_acc=out_for_sampler.acc, # n_classes=out_for_sampler.n_classes, # mask_mode=mask_mode, # ) # sample from concrete distribution # while keeping original distribution for simple resampling if mask_mode.dist is MaskDist.CONCRETE: mask = mask_.rsample() else: mask = mask_ if is_RL: mask, neg_rw = policy.predict(mask) if neg_rw: # TODO fix iteratively [0 class]s are sampled. do_sample = True print('[!] select 0 class!') policy.rewards.append(neg_rw) import pdb pdb.set_trace() # policy.update() continue else: # we can simply resample masks when using concrete distribution # without seriously going through the sampler if mask_mode.dist is MaskDist.CONCRETE and concrete_resample: mask = mask_.rsample() ##################################################################### # use learned learning rate if available # lr = inner_lr if lr is None else lr # inner_lr: preset / lr: learned # train on support set params, mask = C([params, mask], 2) ##################################################################### out_s = model_s(meta_s, params, mask=mask, mask_mode=mask_mode) # out_s_loss_masked = out_s.loss_masked outs.append(out_s) # inner gradient step out_s_loss_mean = out_s.loss.mean() if mask_mode == MaskMode.RL \ else out_s.loss_masked_mean out_s_loss_mean, lr = C([out_s_loss_mean, lr], 2) params = params.sgd_step( out_s_loss_mean, lr, second_order=False if is_RL else True, detach_param=True if is_RL else detach_param ) ##################################################################### # baseline if not train or force_base: out_s_b0 = model_s(meta_s, params_b0) out_s_b1 = model_s(meta_s, params_b1) outs.extend([out_s_b0, out_s_b1]) # attach mask to get loss_s out_s_b1.attach_mask(mask) # inner gradient step (baseline) params_b0 = params_b0.sgd_step(out_s_b0.loss.mean(), inner_lr) params_b1 = params_b1.sgd_step(out_s_b1.loss_scaled_mean, lr.detach()) meta_s = meta_s.cpu() meta_q = meta_q.cuda() # test on query set with torch.set_grad_enabled(train): params = C(params, 3) out_q = model_q(meta_q, params) outs.append(out_q) # baseline if not train or force_base: with torch.no_grad(): # test on query set out_q_b0 = model_q(meta_q, params_b0) out_q_b1 = model_q(meta_q, params_b1) outs.extend([out_q_b0, out_q_b1]) ################################## if is_RL: reward = 
env.step(outs) policy.rewards.append(reward) outs.append(policy) ################################## # record result result_dict.append( outer_step=epoch * i, inner_step=k, **ModelOutput.as_merged_dict(outs) ) # append to the dataframe result_frame = result_frame.append_dict( result_dict.index_all(-1).mean_all(-1)) # logging if k % log_steps == 0: logger.step_info( epoch, mode, i, outer_steps, k, inner_step_scheduler(i), lr) logger.split_info(meta_support, meta_query, episode_iterator) logger.colorized_mask( mask, fmt="2d", vis_num=20, cond=not sig_1.is_active() and log_mask) logger.outputs(outs, print_conf=sig_1.is_active()) logger.flush() # to debug in the middle of running process. if k == inner_steps and sig_2.is_active(): import pdb pdb.set_trace() # compute outer gradient if train and (k % unroll_steps == 0 or k == inner_steps): ##################################################################### if not is_RL: outer_loss += out_q.loss.mean() outer_loss.backward() outer_loss = 0 params.detach_().requires_grad_() mask = mask.detach() lr = lr.detach() do_sample = True ##################################################################### if not train: if not is_RL: # when params is not leaf node created by user, # requires_grad attribute is False by default. params.requires_grad_() # meta(outer) learning if train and update_steps and k % update_steps == 0: # when you have meta_batchsize == 0, update_steps == unroll_steps if is_RL: policy.update() else: outer_optim.step() sampler.zero_grad() ### end of inner steps (k) ### if train and update_epochs and i % update_epochs == 0: # when you have meta_batchsize > 0, update_epochs == meta_batchsize if is_RL: policy.update() else: outer_optim.step() sampler.zero_grad() if update_epochs: print(f'Meta-batchsize is {meta_batchsize}: Sampler updated.') if train and update_steps: print(f'Meta-batchsize is zero. Updating after every unrollings.') # tensorboard if save_path and train: step = (epoch * (outer_steps - 1)) + i res = ResultFrame(result_frame[result_frame['outer_step'] == i]) loss = res.get_best_loss().mean() acc = res.get_best_acc().mean() writer.add_scalars( 'Loss/train', {n: loss[n] for n in loss.index}, step) writer.add_scalars('Acc/train', {n: acc[n] for n in acc.index}, step) # dump figures if save_path and i % fig_epochs == 0: # meta_s.s.save_fig(f'imgs/meta_support', save_path, i) # meta_q.s.save_fig(f'imgs/meta_query', save_path, i) result_dict['ours_s_mask'].save_fig(f'imgs/masks', save_path, i) result_dict.get_items( ['ours_s_mask', 'ours_s_loss', 'ours_s_loss_masked', 'b0_s_loss', 'b1_s_loss', 'ours_q_loss', 'b0_q_loss', 'b1_q_loss'] ).save_csv(f'classwise/{mode}', save_path, i) # distinguishable episodes if not i == outer_steps: print(f'Path for saving: {save_path}') print(f'End_of_episode: {i}') ### end of episode (i) ### print(f'End_of_{mode}.') # del metadata return sampler, result_frame
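# ----------------------------------------------------------------------------
# Hedged sketch of the two meta-update schedules described in the docstring of
# loop() above (illustrative only; `sampler`, `outer_optim`, `outer_loss` and
# `episode_idx` are stand-ins for the real objects):
#
#   if meta_batchsize == 0:
#       # update after every unrolling
#       outer_loss.backward(); outer_optim.step(); sampler.zero_grad()
#   else:
#       # accumulate gradients over `meta_batchsize` episodes, update once
#       outer_loss.backward()   # grads keep accumulating in sampler params
#       if episode_idx % meta_batchsize == 0:
#           outer_optim.step(); sampler.zero_grad()
# ----------------------------------------------------------------------------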
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll,
                  out_mul, should_train=True):
  data = data_cls(training=should_train)
  model = C(model_cls())
  optimizer_states = C(OptimizerStates.initial_zeros(
      (self.n_layers, len(model.params), self.hidden_sz)))

  if should_train:
    self.train()
  else:
    self.eval()

  result_dict = ResultDict()
  unroll_losses = 0
  update = C(torch.zeros(model.params.size().flat()))

  iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
  iter_watch = utils.StopWatch('optim_iteration')

  batch_arg = BatchManagerArgument(
      params=model.params,
      states=StatesSlicingWrapper(optimizer_states),
      updates=update,
  )

  for iteration in iter_pbar:
    data_ = data.sample()
    loss = model(*data_)
    unroll_losses += loss

    model_dt = C(model_cls(params=batch_arg.params.detach()))
    model_dt_loss = model_dt(*data_)
    assert loss == model_dt_loss
    model_dt_loss.backward()
    assert model_dt.params.flat.grad is not None
    grad = model_dt.params.flat.grad
    batch_arg.update(BatchManagerArgument(grad=grad).immutable().volatile())
    iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

    ##########################################################################
    batch_manager = OptimizerBatchManager(batch_arg=batch_arg)
    for get, set in batch_manager.iterator():
      updates, new_states = self(get.grad.detach(), get.states)
      set.states(new_states)
      set.params(get.params + updates * out_mul)
      if not should_train:
        set.updates(updates)
    ##########################################################################

    batch_arg = batch_manager.batch_arg_to_set

    if should_train and iteration % unroll == 0:
      meta_optimizer.zero_grad()
      unroll_losses.backward()
      meta_optimizer.step()

    if not should_train or iteration % unroll == 0:
      unroll_losses = 0
      batch_arg.detach_()
      # model.params = model.params.detach()
      # optimizer_states = optimizer_states.detach()

    model = C(model_cls(params=batch_arg.params))
    result_dict.append(loss=loss)
    if not should_train:
      result_dict.append(
          walltime=iter_watch.touch(),
          **self.params_tracker(
              grad=grad,
              update=batch_arg.updates,
          ))
  return result_dict
def test_normal(name, cfg):
  """Test function for static (non-learned) optimizers."""
  cfg = Config.merge(
      [cfg.problem, cfg.args, cfg.neural_optimizers[name].test_args])
  data_cls = _get_attr_by_name(cfg.problem.data_cls)
  model_cls = _get_attr_by_name(cfg.problem.model_cls)
  optim_cls = _get_attr_by_name(cfg.problem.optim_cls)

  writer = TFWriter(cfg.save_dir, name) if cfg.save_dir else None
  tests_pbar = tqdm(range(cfg.n_test), 'outer_test')
  result_all = ResultDict()
  result_final = ResultDict()  # test loss & acc on the whole test set
  best_test_loss = 999999
  # params_tracker = ParamsIndexTracker(n_tracks=10)

  # Outer loop (n_test)
  for j in tests_pbar:
    data = data_cls().sample_meta_test()
    model = C(model_cls())
    optimizer = optim_cls(model.parameters(), **cfg.optim_args)

    walltime = Walltime()
    result = ResultDict()
    iter_pbar = tqdm(range(cfg.iter_test), 'inner_train')

    # Inner loop (iter_test)
    for k in iter_pbar:
      with WalltimeChecker(walltime):
        # Inner-training loss
        train_nll, train_acc = model(*data['in_train'].load())
        optimizer.zero_grad()
        train_nll.backward()
        # before = model.params.detach('hard').flat
        optimizer.step()
      # Inner-test loss
      test_nll, test_acc = model(*data['in_test'].load())
      # update = model.params.detach('hard').flat - before
      # grad = model.params.get_flat().grad
      result_ = dict(
          train_nll=train_nll.tolist(),
          test_nll=test_nll.tolist(),
          train_acc=train_acc.tolist(),
          test_acc=test_acc.tolist(),
          walltime=walltime.time,
      )
      log_pbar(result_, iter_pbar)
      result.append(result_)

    if cfg.save_dir:
      save_figure(name, cfg.save_dir, writer, result, j, 'normal')
    result_all.append(result)
    result_final.append(final_inner_test(model, data['in_test'], mode='test'))

    result_final_mean = result_final.mean()
    last_test_loss = result_final_mean['final_loss_mean']
    last_test_acc = result_final_mean['final_acc_mean']
    if last_test_loss < best_test_loss:
      best_test_loss = last_test_loss
      best_test_acc = last_test_acc
    result_test = dict(
        best_test_loss=best_test_loss,
        best_test_acc=best_test_acc,
        last_test_loss=last_test_loss,
        last_test_acc=last_test_acc,
    )
    log_pbar(result_test, tests_pbar)

  # TF-events for the inner loop (train & test)
  mean_j = result_all.mean(0)
  if cfg.save_dir:
    for i in range(cfg.iter_test):
      mean = mean_j.getitem(i)
      walltime = mean_j['walltime'][i]
      log_tf_event(writer, 'meta_test_inner', mean, i, walltime)
    result_all.save(name, cfg.save_dir)
    result_all.save_as_csv(name, cfg.save_dir, 'meta_test_inner')
    result_final.save_as_csv(
        name, cfg.save_dir, 'meta_test_final', trans_1d=True)

  return result_all
def meta_optimize(self, meta_optim, data, model_cls, optim_it, unroll, out_mul,
                  k_obsrv=None, no_mask=None, writer=None, mode='train'):
  """FIX LATER: do something about the dummy arguments."""
  assert mode in ['train', 'valid', 'test']
  self.set_mode(mode)

  use_indexer = False

  params = C(model_cls()).params
  states = C(OptimizerStates.initial_zeros(
      size=(self.n_layers, len(params), self.hidden_sz),
      rnn_cell=self.rnn_cell))

  result_dict = ResultDict()
  unroll_losses = 0
  walltime = Walltime()
  update = C(torch.zeros(params.size().flat()))

  if use_indexer:
    batch_arg = BatchManagerArgument(
        params=params,
        states=StatesSlicingWrapper(states),
        updates=update,
    )

  iter_pbar = tqdm(range(1, optim_it + 1), 'inner_train')

  for iter in iter_pbar:
    with WalltimeChecker(walltime):
      model_train = C(model_cls(params=params.detach()))
      train_nll, train_acc = model_train(*data['in_train'].load())
      train_nll.backward()

      if model_train.params.flat.grad is not None:
        g = model_train.params.flat.grad
      else:
        g = model_train._grad2params(model_train.params)
      assert g is not None

      if use_indexer:
        # The indexer was originally meant for outer-level batch splitting,
        # in case the whole parameter set does not fit in VRAM.  Avoid it
        # unless unavoidable, since the serial computation adds time cost.
        batch_arg.update(
            BatchManagerArgument(grad=g).immutable().volatile())
        indexer = DefaultIndexer(input=g, shuffle=False)
        batch_manager = OptimizerBatchManager(
            batch_arg=batch_arg, indexer=indexer)
        ######################################################################
        for get, set in batch_manager.iterator():
          updates, new_states = self(get.grad().detach(), get.states())
          set.states(new_states)
          set.params(get.params() + updates * out_mul)
          if not mode == 'train':
            set.updates(updates)
        ######################################################################
        batch_arg = batch_manager.batch_arg_to_set
        params = batch_arg.params
      else:
        updates, states = self(g.detach(), states)
        updates = params.new_from_flat(updates)
        params = params + updates * out_mul
        # params = params - params.new_from_flat(g) * 0.1

    with WalltimeChecker(walltime if mode == 'train' else None):
      model_test = C(model_cls(params=params))
      test_nll, test_acc = model_test(*data['in_test'].load())

      if mode == 'train':
        unroll_losses += test_nll
        if iter % unroll == 0:
          meta_optim.zero_grad()
          unroll_losses.backward()
          nn.utils.clip_grad_value_(self.parameters(), 0.01)
          meta_optim.step()
          unroll_losses = 0

    with WalltimeChecker(walltime):
      if not mode == 'train' or iter % unroll == 0:
        if use_indexer:
          batch_arg.detach_()
        else:
          params.detach_()
          states.detach_()
        # model.params = model.params.detach()
        # states = states.detach()

    result = dict(
        train_nll=train_nll.tolist(),
        test_nll=test_nll.tolist(),
        train_acc=train_acc.tolist(),
        test_acc=test_acc.tolist(),
        walltime=walltime.time,
    )
    result_dict.append(result)
    log_pbar(result, iter_pbar)

  return result_dict, params
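# ----------------------------------------------------------------------------
# Hedged sketch (an assumption, not the project's module): the coordinate-wise
# RNN update that meta_optimize above expects from `self(g, states)`.  Every
# scalar gradient is fed through one shared small LSTM, and the output is the
# per-parameter update (scaled by out_mul in the caller).
import torch
import torch.nn as nn


class CoordinateWiseRNNSketch(nn.Module):

  def __init__(self, hidden_sz=20, n_layers=2):
    super().__init__()
    self.rnn = nn.LSTM(1, hidden_sz, n_layers)
    self.head = nn.Linear(hidden_sz, 1)

  def forward(self, grad_flat, states):
    # grad_flat: (n_params,) -> (seq_len=1, batch=n_params, input_size=1)
    out, states = self.rnn(grad_flat.view(1, -1, 1), states)
    return self.head(out).view(-1), states
# ----------------------------------------------------------------------------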
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True): data = data_cls(training=should_train) model = C(model_cls()) optimizer_states = C(OptimizerStates.initial_zeros( size=(self.n_layers, len(model.params), self.hidden_sz), rnn_cell=self.rnn_cell)) if should_train: self.train() else: self.eval() result_dict = ResultDict() walltime = 0 unroll_losses = 0 update = C(torch.zeros(model.params.size().flat())) batch_arg = BatchManagerArgument( params=model.params, states=StatesSlicingWrapper(optimizer_states), updates=update, ) iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') for iteration in iter_pbar: iter_watch.touch() # entire_loop.touch() # interval.touch() data_ = data.sample() if should_train: model = C(model_cls(params=batch_arg.params)) loss = model(*data_) unroll_losses += loss # print("\n[1] : "+str(interval.touch('interval', False))) # model_ff.touch() model_detached = C(model_cls(params=batch_arg.params.detach())) loss_detached = model_detached(*data_) # model_ff.touch('interval', True) if should_train: assert loss == loss_detached # model_bp.touch() loss_detached.backward() # model_bp.touch('interval', True) # interval.touch() assert model_detached.params.flat.grad is not None grad = model_detached.params.flat.grad batch_arg.update(BatchManagerArgument(grad=grad).immutable().volatile()) iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]') # print("\n[2-1] : "+str(interval.touch('interval', False))) ########################################################################## # indexer = DefaultIndexer(input=grad, shuffle=False) indexer = TopKIndexer(input=grad, k_ratio=0.1, random_k_mode=False)# bound=model_dt.params.size().bound_layerwise()) #indexer = DefaultIndexer(input=grad) batch_manager = OptimizerBatchManager( batch_arg=batch_arg, indexer=indexer) # print("\n[2-2] : "+str(interval.touch('interval', False))) # optim_ff.touch() # optim_ff_1.touch() for get, set in batch_manager.iterator(): # optim_ff_1.touch('interval', True) # optim_ff_2.touch() # optim_ff_2.touch('interval', True) # optim_ff_3.touch() updates, new_states = self(get.grad().detach(), get.states()) # optim_ff_3.touch('interval', True) # optim_ff_4.touch() updates = updates * out_mul set.states(new_states) set.params(get.params() + updates) if not should_train: set.updates(updates) # optim_ff_4.touch('interval', True) # optim_ff.touch('interval', True) ########################################################################## # interval.touch() batch_arg = batch_manager.batch_arg_to_set # print("\n[3] : "+str(interval.touch('interval', False))) if should_train and iteration % unroll == 0: # optim_bp.touch() meta_optimizer.zero_grad() unroll_losses.backward() nn.utils.clip_grad_norm_(self.parameters(), 10) meta_optimizer.step() unroll_losses = 0 # optim_bp.touch('interval', True) # interval.touch() if not should_train or iteration % unroll == 0: batch_arg.detach_() # print("\n[4] : "+str(interval.touch('interval', False))) # interval.touch() # model.params = model.params.detach() # optimizer_states = optimizer_states.detach() walltime += iter_watch.touch('interval') result_dict.append(loss=loss_detached) # print("\n[5] : "+str(interval.touch('interval', False))) if not should_train: result_dict.append( walltime=iter_watch.touch(), **self.params_tracker( grad=grad, update=batch_arg.updates, ) ) # entire_loop.touch('interval', True) return result_dict
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll, out_mul, mode): assert mode in ['train', 'valid', 'test'] if mode == 'train': self.train() # data.new_train_data() inner_data = data.loaders['inner_train'] outer_data = data.loaders['inner_valid'] # inner_data = data.loaders['train'] # outer_data = inner_data drop_mode = 'soft_drop' elif mode == 'valid': self.eval() inner_data = data.loaders['valid'] outer_data = None drop_mode = 'hard_drop' elif mode == 'eval': self.eval() inner_data = data.loaders['test'] outer_data = None drop_mode = 'hard_drop' # data = dataset(mode=mode) model = C(model_cls(sb_mode=self.sb_mode)) model(*data.pseudo_sample()) # layer_sz = sum([v for v in model.params.size().unflat(1, 'mat').values()]) mask_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.activations.mean(0)), self.hidden_sz))) update_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.params), self.hidden_sz))) result_dict = ResultDict() unroll_losses = 0 walltime = 0 update = C(torch.zeros(model.params.size().flat())) batch_arg = BatchManagerArgument( params=model.params, states=StatesSlicingWrapper(update_states), updates=model.params.new_zeros(model.params.size()), ) iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') bp_watch = utils.StopWatch('bp') loss_decay = 1.0 for iteration in iter_pbar: iter_watch.touch() model_detached = C( model_cls(params=batch_arg.params.detach(), sb_mode=self.sb_mode)) loss_detached = model_detached(*inner_data.load()) loss_detached.backward() # import pdb; pdb.set_trace() iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]') # np.set_printoptions(suppress=True, precision=3, linewidth=200, edgeitems=20) if debug_sigint.signal_on: debug = iteration == 1 or iteration % 10 == 0 # debug = True else: debug = False backprop_mask, mask_states, sparsity_loss = self.mask_generator( activations=model_detached.activations.grad.detach(), debug=debug, states=mask_states) # if debug_sigstp.signal_on and (iteration == 1 or iteration % 10 == 0): # mask_list.append(backprop_mask) # import pdb; pdb.set_trace() # min = 0.001 # max = 0.01 unroll_losses += sparsity_loss * 0.01 * \ loss_decay # * ((max - min) * lambda_ + min) # import pdb; pdb.set_trace() # model_detached.apply_backprop_mask(backprop_mask.unflat, drop_mode) # loss_detached.backward() assert model_detached.params.grad is not None if mode == 'train': model = C( model_cls(params=batch_arg.params, sb_mode=self.sb_mode)) loss = model(*outer_data.load()) # assert loss == loss_detached if iteration == 1: initial_loss = loss loss_decay = 1.0 else: loss_decay = (loss / initial_loss).detach() unroll_losses += loss # model_detached.apply_backprop_mask(backprop_mask.unflat, drop_mode) #model_detached.params.apply_grad_mask(mask, 'soft_drop') coord_mask = backprop_mask.expand_as_params(model_detached.params) if drop_mode == 'soft_drop': # grad = model_detached.params.grad.mul(coord_mask).flat # model_detached.params = model_detached.params.mul(params_mask) index = None # updates, states = self.updater(grad, states, mask=coord_mask) # updates = updates * out_mul * coord_mask # import pdb; pdb.set_trace() elif drop_mode == 'hard_drop': index = (coord_mask.flat.squeeze() > 0.5).nonzero().squeeze() grad = model_detached.params.grad.flat batch_arg.update(BatchManagerArgument(grad=grad, mask=coord_mask)) ########################################################################## indexer = DefaultIndexer(input=grad, 
index=index, shuffle=False) batch_manager = OptimizerBatchManager(batch_arg=batch_arg, indexer=indexer) for get, set in batch_manager.iterator(): inp = get.grad().detach() #inp = [get.grad().detach(), get.tag().detach()] if drop_mode == 'soft_drop': updates, new_states = self.updater( inp, get.states()) # , mask=get.mask()) # if iteration == 100: # import pdb; pdb.set_trace() # aaa = get.states() *2 updates = updates * out_mul * get.mask() set.states(get.states() * (1 - get.mask()) + new_states * get.mask()) # set.states(new_states) elif drop_mode == 'hard_drop': updates, new_states = self.updater(inp, get.states()) updates = updates * out_mul set.states(new_states) else: raise RuntimeError('Unknown drop mode!') set.params(get.params() + updates) # if not mode != 'train': # set.updates(updates) ########################################################################## batch_arg = batch_manager.batch_arg_to_set # updates = model.params.new_from_flat(updates) if mode == 'train' and iteration % unroll == 0: meta_optimizer.zero_grad() unroll_losses.backward() nn.utils.clip_grad_norm_(self.parameters(), 5) meta_optimizer.step() unroll_losses = 0 # if iteration == 1 or iteration % 10 == 0: # import pdb; pdb.set_trace() if not mode == 'train' or iteration % unroll == 0: mask_states.detach_() batch_arg.detach_() # model.params = model.params.detach() # update_states = update_states.detach() walltime += iter_watch.touch('interval') result_dict.append(loss=loss_detached) if not mode == 'train': result_dict.append( walltime=walltime, # **self.params_tracker( # grad=grad, # update=batch_arg.updates, # ) ) return result_dict
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll,
                  out_mul, should_train=True):
  target = data_cls(training=should_train)
  optimizee = C(model_cls())
  optimizer_states = OptimizerStates.zero_states(
      self.n_layers, len(optimizee.params), self.hidden_sz)

  if should_train:
    self.train()
  else:
    self.eval()

  coordinate_tracker = ParamsIndexTracker(
      n_params=len(optimizee.params), n_tracks=30)
  result_dict = ResultDict()
  unroll_losses = 0

  iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
  iter_watch = utils.StopWatch('optim_iteration')

  for iteration in iter_pbar:
    loss = optimizee(target)
    unroll_losses += loss
    loss.backward(retain_graph=should_train)
    iter_pbar.set_description(f'optim_iteration[loss:{loss}]')

    unit_size = torch.Size([len(optimizee.params), 1])
    grid_size = torch.Size([len(optimizee.params), self.n_coord])
    batch_manager = OptimizerBatchManager(
        params=optimizee.params,
        states=StatesSlicingWrapper(optimizer_states),
        units=OptimizerBatchManager.placeholder(unit_size),
        grid_abs=OptimizerBatchManager.placeholder(grid_size),
        grid_rel=OptimizerBatchManager.placeholder(grid_size),
        batch_size=30000,
    )

    ##########################################################################
    for get, set in batch_manager.iterator():
      units, new_states = self.unit_maker(
          C.detach(get.params.grad), get.states.clone())
      units = units * out_mul / self.n_coord / 2
      grid_abs, grid_rel = self.grid_maker(units, get.params)
      set.states(new_states)
      set.units(units)
      set.grid_abs(grid_abs)
      set.grid_rel(grid_rel)
    ##########################################################################

    landscape = self.loss_observer(batch_manager.grid_abs, model_cls, target)
    updates = self.step_maker(batch_manager.grid_rel, landscape)
    optimizee.params.add_flat_(updates)

    if should_train and iteration % unroll == 0:
      meta_optimizer.zero_grad()
      unroll_losses.backward()
      meta_optimizer.step()

    if not should_train or iteration % unroll == 0:
      unroll_losses = 0
      optimizee.params = optimizee.params.detach()
      optimizer_states = optimizer_states.detach()

    optimizee = C(model_cls(optimizee.params))
    result_dict.append(loss=loss)
    if not should_train:
      result_dict.append(
          grid=coordinate_tracker(batch_manager.units),
          update=coordinate_tracker(updates),
          walltime=iter_watch.touch(),
      )
  return result_dict
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll, out_mul, mode): assert mode in ['train', 'valid', 'test'] if mode == 'train': self.train() # data.new_train_data() inner_data = data.loaders['inner_train'] outer_data = data.loaders['inner_valid'] # inner_data = data.loaders['train'] # outer_data = inner_data drop_mode = 'soft_drop' elif mode == 'valid': self.eval() inner_data = data.loaders['valid'] outer_data = None drop_mode = 'hard_drop' elif mode == 'test': self.eval() inner_data = data.loaders['test'] outer_data = None drop_mode = 'hard_drop' # data = dataset(mode=mode) model = C(model_cls(sb_mode=self.sb_mode)) model(*data.pseudo_sample()) params = model.params result_dict = ResultDict() unroll_losses = 0 walltime = 0 self.feature_gen.new() ################ mask_dict = ResultDict() analyze_mask = False sample_mask = False draw_loss = False ################ iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') res1 = None res2 = [] res3 = [] lamb = [] gamm_g = [] gamm_l1 = [] gamm_l2 = [] iters = [] for iteration in iter_pbar: iter_watch.touch() if debug_sigint.signal_on: debug_1 = iteration == 1 or iteration % 10 == 0 else: debug_1 = False if debug_sigstp.signal_on: debug_2 = True else: debug_2 = False model_detached = C(model_cls(params.detach())) inner_data_s = inner_data.load() loss_detached = model_detached(*inner_data_s) loss_detached.backward() g = model_detached.params.grad.flat.detach() w = model_detached.params.flat.detach() cand_params = [] cand_losses = [] best_loss = 9999999 best_params = None n_samples = 5 for m in range(1): feature = self.feature_gen(g, w, n=n_samples, m_update=(m == 0)) p_size = params.size().unflat() mask_gen_out = self.mask_gen(feature, p_size, n=n_samples) step_out = self.step_gen(feature, n=n_samples) # step & mask genration for i in range(n_samples): mask, _, kld = mask_gen_out[i] step = step_out[i] mask = ParamsFlattener(mask) mask_layout = mask.expand_as(params) step = params.new_from_flat(step) # prunning step = step * mask_layout params_ = params + step # cand_params.append(params_.flat) sparse_params = params_.prune(mask > 1e-6) if sparse_params.size().unflat()['mat_0'][1] == 0: continue if debug_2: import pdb pdb.set_trace() # cand_loss = model(*outer_data_s) sparse_model = C(model_cls(sparse_params.detach())) loss = sparse_model(*inner_data_s) try: if loss < best_loss: best_loss = loss best_params = params_ best_kld = kld except: best_params = params_ best_kld = kld if best_params is not None: params = best_params best_kld = 0 if mode == 'train': model = C(model_cls(params)) optim_loss = model(*outer_data.load()) if torch.isnan(optim_loss): import pdb pdb.set_trace() unroll_losses += optim_loss + best_kld / outer_data.full_size if mode == 'train' and iteration % unroll == 0: meta_optimizer.zero_grad() unroll_losses.backward() # import pdb; pdb.set_trace() nn.utils.clip_grad_value_(self.parameters(), 0.1) meta_optimizer.step() unroll_losses = 0 if not mode == 'train' or iteration % unroll == 0: # self.mask_gen.detach_lambdas_() params = params.detach_() # import pdb; pdb.set_trace() if params is None: import pdb pdb.set_trace() iter_pbar.set_description( f'optim_iteration' f'[optim_loss:{loss_detached.tolist():5.5}') # f' sparse_loss:{sparsity_loss.tolist():5.5}]') # iter_pbar.set_description(f'optim_iteration[loss:{loss_dense.tolist()}/dist:{dist.tolist()}]') # torch.optim.SGD(sparse_params, lr=0.1).step() walltime += iter_watch.touch('interval') 
############################################################## text_dir = 'test/analyze_mask' result_dir = 'test/drawloss' sample_dir = 'test/mask_compare' iter_interval = 10 sample_num = 10000 if mode == 'test' and analyze_mask: analyzing_mask(self.mask_gen, layer_size, mode, iter, iter_interval, text_dir) if mode == 'test' and sample_mask: mask_result = sampling_mask(self.mask_gen, layer_size, model_train, params, sample_num, mode, iter, iter_interval, sample_dir) if mask_result is not None: mask_dict.append(mask_result) if mode == 'test' and draw_loss: plot_loss(model_cls=model_cls, model=model_train, params, input_data=data['in_train'].load(), dataset=data['in_train'], feature_gen=self.feature_gen, mask_gen=self.mask_gen, step_gen=self.step_gen, scale_way=None, xmin=-2.0, xmax=0.5, num_x=20, mode=mode, iteration=iter, iter_interval=iter_interval, loss_dir=result_dir) ############################################################## result_dict.append(loss=loss_detached) if not mode == 'train': result_dict.append( walltime=walltime, # **self.params_tracker( # grad=grad, # update=batch_arg.updates, # ) ) return result_dict
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll, out_mul, should_train=True): data = data_cls(training=should_train) model = C(model_cls()) states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.params), self.hidden_sz))) if should_train: self.train() else: self.eval() result_dict = ResultDict() model_dt = None model_losses = 0 #grad_real_prev = None iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') batch_arg = BatchManagerArgument( params=model.params, states=StatesSlicingWrapper(states), update_prev=C(torch.zeros(model.params.size().flat())), mu=C(torch.zeros(model.params.size().flat())), sigma=C(torch.zeros(model.params.size().flat())), ) for iteration in iter_pbar: data_ = data.sample() model_loss = model(*data_) model_losses += model_loss # if iteration == 21: # import pdb; pdb.set_trace() try: make_dot(model_losses) except: import pdb pdb.set_trace() # if iteration == 1 or iteration == 21: # import pdb; pdb.set_trace() model_dt = C(model_cls(params=model.params.detach())) model_loss_dt = model_dt(*data_) # model.params.register_grad_cut_hook_() model_loss_dt.backward() assert model_dt.params.flat.grad is not None grad_cur = model_dt.params.flat.grad # model.params.remove_grad_cut_hook_() batch_arg.update(BatchManagerArgument(grad_cur=grad_cur)) # NOTE: read_only, write_only ######################################################################## batch_manager = OptimizerBatchManager(batch_arg=batch_arg) for get, set in batch_manager.iterator(): grad_cur = get.grad_cur.detach() # free-running update_prev = get.update_prev.detach() points_rel, states = self.point_sampler( grad_cur, update_prev, get.states, self.n_points) points_rel = points_rel * out_mul points_abs = get.params.detach( ) + points_rel # NOTE: check detach losses = self.loss_observer(points_abs, model_cls, data_) update = self.point_selector(points_rel, losses) set.params(get.params + update) set.states(states) set.update_prev(update) set.mu(self.point_sampler.mu) set.sigma(self.point_sampler.sigma) ########################################################################## batch_arg = batch_manager.batch_arg_to_set iter_pbar.set_description('optim_iteration[loss:{:10.6f}]'.format( model_loss.tolist())) if should_train and iteration % unroll == 0: meta_optimizer.zero_grad() try: model_losses.backward() except: import pdb pdb.set_trace() meta_optimizer.step() if not should_train or iteration % unroll == 0: model_losses = 0 batch_arg.detach_() # batch_arg.params.detach_() # batch_arg.states.detach_() # NOTE: just do batch_arg.detach_() model = C(model_cls(params=batch_arg.params)) result_dict.append( step_num=iteration, loss=model_loss, ) if not should_train: result_dict.append(walltime=iter_watch.touch(), **self.params_tracker( grad=grad_cur, update=batch_arg.update_prev, mu=batch_arg.mu, sigma=batch_arg.sigma, )) return result_dict
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll, out_mul, mode): assert mode in ['train', 'valid', 'test'] if mode == 'train': self.train() # data.new_train_data() inner_data = data.loaders['inner_train'] outer_data = data.loaders['inner_valid'] # inner_data = data.loaders['train'] # outer_data = inner_data drop_mode = 'soft_drop' elif mode == 'valid': self.eval() inner_data = data.loaders['valid'] outer_data = None drop_mode = 'hard_drop' elif mode == 'eval': self.eval() inner_data = data.loaders['test'] outer_data = None drop_mode = 'hard_drop' # data = dataset(mode=mode) model = C(model_cls(sb_mode=self.sb_mode)) model(*data.pseudo_sample()) # layer_sz = sum([v for v in model.params.size().unflat(1, 'mat').values()]) mask_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.activations.mean(0)), self.hidden_sz))) update_states = C( OptimizerStates.initial_zeros( (self.n_layers, len(model.params), self.hidden_sz))) result_dict = ResultDict() unroll_losses = 0 walltime = 0 update = C(torch.zeros(model.params.size().flat())) batch_arg = BatchManagerArgument( params=model.params, states=StatesSlicingWrapper(update_states), updates=model.params.new_zeros(model.params.size()), ) params = model.params lr_sgd = 0.2 lr_adam = 0.001 sgd = torch.optim.SGD(model.parameters(), lr=lr_sgd) adam = torch.optim.Adam(model.parameters()) adagrad = torch.optim.Adagrad(model.parameters()) # print('###', lr_adam, '####') iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration') iter_watch = utils.StopWatch('optim_iteration') bp_watch = utils.StopWatch('bp') loss_decay = 1.0 dist_list = [] dense_rank = 0 sparse_rank = 0 dense_rank_list = [] sparse_rank_list = [] for iteration in iter_pbar: iter_watch.touch() # # conventional optimizers with API # model.zero_grad() # loss = model(*inner_data.load()) # loss.backward() # adam.step() # loss_detached = loss model_detached = C(model_cls(params=params.detach())) loss_detached = model_detached(*inner_data.load()) loss_detached.backward() # # conventional optimizers with manual SGD # params = model_detached.params + model_detached.params.grad * (-0.2) # iter_pbar.set_description(f'optim_iteration[loss:{loss_detached}]') masks = [] # sparse_old_loss = [] # sparse_new_loss = [] sparse_loss = [] dense_loss = [] candidates = [] mode = 2 n_sample = 10 if mode == 1: lr = 1.0 elif mode == 2: lr = 0.2 else: raise Exception() for i in range(n_sample): if i == 9: mask = MaskGenerator.topk( model_detached.activations.grad.detach()) else: mask = MaskGenerator.randk( model_detached.activations.grad.detach()) if model == 1: # recompute gradients again using prunned network sparse_params = model_detached.params.prune(mask) model_prunned = C(model_cls(params=sparse_params.detach())) loss_prunned = model_prunned(*inner_data.load()) loss_prunned.backward() grads = model_prunned.params.grad elif mode == 2: # use full gradients but sparsified sparse_params = model_detached.params.prune(mask) grads = model_detached.params.grad.prune(mask) else: raise Exception('unknown model.') sparse_params = sparse_params + grads * (-lr) model_prunned = C(model_cls(params=sparse_params.detach())) loss_prunned = model_prunned(*inner_data.load()) # params = model_detached.params params = model_detached.params.sparse_update( mask, grads * (-lr)) candidates.append(params) model_dense = C(model_cls(params=params)) loss_dense = model_dense(*inner_data.load()) sparse_loss.append(loss_prunned) dense_loss.append(loss_dense) sparse_loss = torch.stack(sparse_loss, 0) 
dense_loss = torch.stack(dense_loss, 0) min_id = ( sparse_loss.min() == sparse_loss).nonzero().squeeze().tolist() params = candidates[min_id] sparse_order = sparse_loss.sort()[1].float() dense_order = dense_loss.sort()[1].float() dist = (sparse_order - dense_order).abs().sum() dist_list.append(dist) dense_rank += (dense_order == 9).nonzero().squeeze().tolist() sparse_rank += (sparse_order == 9).nonzero().squeeze().tolist() dense_rank_list.append( (dense_order == 9).nonzero().squeeze().tolist()) sparse_rank_list.append( (sparse_order == 9).nonzero().squeeze().tolist()) iter_pbar.set_description( f'optim_iteration[dense_loss:{loss_dense.tolist():5.5}/sparse_loss:{loss_prunned.tolist():5.5}]' ) # iter_pbar.set_description(f'optim_iteration[loss:{loss_dense.tolist()}/dist:{dist.tolist()}]') # torch.optim.SGD(sparse_params, lr=0.1).step() result_dict.append(loss=loss_detached) if not mode == 'train': result_dict.append( walltime=walltime, # **self.params_tracker( # grad=grad, # update=batch_arg.updates, # ) ) dist_mean = torch.stack(dist_list, 0).mean() import pdb pdb.set_trace() return result_dict
def meta_optimize(self, meta_optimizer, data_cls, model_cls, optim_it, unroll,
                  out_mul, should_train=True):
    data = data_cls(training=should_train)
    model = C(model_cls())
    g_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.params), self.hidden_sz)))
    u_states = C(OptimizerStates.initial_zeros(
        (self.n_layers, len(model.params), self.hidden_sz)))

    if should_train:
        self.train()
    else:
        self.eval()

    grad_pred_tracker = ParamsIndexTracker(
        name='grad_pred', n_tracks=10, n_params=len(model.params))
    grad_real_tracker = grad_pred_tracker.clone_new_name('grad_real')
    update_tracker = grad_pred_tracker.clone_new_name('update')

    result_dict = ResultDict()
    model_dt = None
    model_losses = 0
    grad_losses = 0
    train_with_lazy_tf = False
    eval_test_grad_loss = True
    grad_prev = C(torch.zeros(model.params.size().flat()))
    update_prev = C(torch.zeros(model.params.size().flat()))
    # grad_real_prev = None
    grad_real_cur = None
    grad_real_prev = C(torch.zeros(model.params.size().flat()))

    iter_pbar = tqdm(range(1, optim_it + 1), 'optim_iteration')
    iter_watch = utils.StopWatch('optim_iteration')

    """grad_real_prev: teacher-forcing input for the current RNN step;
         needed at every training step and at every k-th test step.
       grad_real_cur: MSE-loss target for the current RNN output;
         needed during training steps only.
    """
    batch_arg = BatchManagerArgument(
        params=model.params,
        grad_prev=grad_prev,
        update_prev=update_prev,
        g_states=StatesSlicingWrapper(g_states),
        u_states=StatesSlicingWrapper(u_states),
        mse_loss=C(torch.zeros(1)),
        # updates=OptimizerBatchManager.placeholder(model.params.flat.size()),
    )

    for iteration in iter_pbar:
        inp, out = data.sample()
        model_loss = model(inp, out)
        model_losses += model_loss

        # if grad_real_cur is not None:  # if iteration % self.k == 0
        #     grad_cur_prev = grad_real_cur  # back up previous grad
        if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss:
            # get grad every step while training
            # get grad one step earlier than teacher-forcing steps while testing
            model_dt = C(model_cls(params=model.params.detach()))
            model_loss_dt = model_dt(inp, out)
            # model.params.register_grad_cut_hook_()
            model_loss_dt.backward()
            # model.params.remove_grad_cut_hook_()
            grad_real_cur = model_dt.params.flat.grad
            assert grad_real_cur is not None
            del model_dt, model_loss_dt

        if should_train or eval_test_grad_loss:
            # MSE-loss target
            # grad_real_cur is not used as an MSE target while testing
            assert grad_real_cur is not None
            batch_arg.update(BatchManagerArgument(
                grad_real_cur=grad_real_cur).immutable().volatile())

        if iteration % self.k == 0 or (should_train and not train_with_lazy_tf):
            # teacher-forcing
            assert grad_real_prev is not None
            batch_arg.update(BatchManagerArgument(
                grad_real_prev=grad_real_prev).immutable().volatile())
            grad_real_prev = None

        ########################################################################
        batch_manager = OptimizerBatchManager(batch_arg=batch_arg)

        for get, set in batch_manager.iterator():
            if iteration % self.k == 0 or (should_train and not train_with_lazy_tf):
                grad_prev = get.grad_real_prev.detach()  # teacher-forcing
            else:
                grad_prev = get.grad_prev.detach()  # free-running
            update_prev = get.update_prev.detach()
            grad_cur, update_cur, g_states, u_states = self(
                grad_prev, update_prev, get.g_states, get.u_states)
            set.grad_prev(grad_cur)
            set.update_prev(update_cur)
            set.g_states(g_states)
            set.u_states(u_states)
            set.params(get.params + update_cur * out_mul)
            if should_train or eval_test_grad_loss:
                set.mse_loss(self.mse_loss(grad_cur, get.grad_real_cur.detach()))
        ########################################################################

        batch_arg = batch_manager.batch_arg_to_set
        grad_loss = batch_arg.mse_loss * 100
        grad_losses += grad_loss
        # to_float = lambda x: float(x.data.cpu().numpy())
        iter_pbar.set_description(
            'optim_iteration[loss(model):{:10.6f}/(grad):{:10.6f}]'.format(
                model_loss.tolist(), grad_loss.tolist()[0]))

        if should_train and iteration % unroll == 0:
            meta_optimizer.zero_grad()
            losses = model_losses + grad_losses
            losses.backward()
            meta_optimizer.step()

        if not should_train or iteration % unroll == 0:
            model_losses = 0
            grad_losses = 0
            batch_arg.detach_()
            # batch_arg.params.detach_()
            # batch_arg.g_states.detach_()
            # batch_arg.u_states.detach_()

        model = C(model_cls(params=batch_arg.params))
        result_dict.append(
            step_num=iteration,
            loss=model_loss,
            grad_loss=grad_loss,
        )

        if not should_train:
            result_dict.append(
                walltime=iter_watch.touch(),
                **grad_pred_tracker.track_num,
                **grad_pred_tracker(batch_arg.grad_prev),
                **grad_real_tracker(grad_real_cur),
                **update_tracker(batch_arg.update_prev),
            )

        if (iteration + 1) % self.k == 0 or should_train or eval_test_grad_loss:
            grad_real_prev = grad_real_cur
            grad_real_cur = None

    return result_dict
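# The teacher-forcing schedule used above can be summarized by a single
# predicate. The helper below is a sketch assuming the same `k`,
# `should_train`, and `train_with_lazy_tf` semantics as in meta_optimize;
# it is illustrative and not part of the original code.
def use_teacher_forcing(iteration, k, should_train, train_with_lazy_tf):
    """True if the RNN input at this step should be the real gradient
    (teacher forcing) rather than its own previous output (free running)."""
    return iteration % k == 0 or (should_train and not train_with_lazy_tf)


# e.g. with k=5 at test time, only every 5th step is teacher-forced:
# use_teacher_forcing(5, 5, should_train=False, train_with_lazy_tf=False)  -> True
# use_teacher_forcing(6, 5, should_train=False, train_with_lazy_tf=False)  -> False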
def test_neural(name, cfg, learned_params):
    """Function for meta-test of learned optimizers."""
    cfg = Config.merge(
        [cfg.problem, cfg.neural_optimizers[name].test_args, cfg.args])
    data_cls = _get_attr_by_name(cfg.data_cls)
    model_cls = _get_attr_by_name(cfg.model_cls)
    optim_cls = _get_optim_by_name(cfg.optim_module)
    writer = TFWriter(cfg.save_dir, name) if cfg.save_dir else None

    optimizer = C(optim_cls())
    if learned_params is not None:
        optimizer.params = learned_params

    meta_optim = None
    unroll = 1

    best_test_loss = 999999
    result_all = ResultDict()
    result_final = ResultDict()  # test loss & acc for the whole testset
    tests_pbar = tqdm(range(cfg.n_test), 'test')

    # Meta-test
    for j in tests_pbar:
        data = data_cls().sample_meta_test()
        result, params = optimizer.meta_optimize(
            cfg, meta_optim, data, model_cls, writer, 'test')
        result_all.append(result)
        result_final.append(final_inner_test(
            C(model_cls(params)), data['in_test'], mode='test'))

        result_final_mean = result_final.mean()
        last_test_loss = result_final_mean['final_loss_mean']
        last_test_acc = result_final_mean['final_acc_mean']
        if last_test_loss < best_test_loss:
            best_test_loss = last_test_loss
            best_test_acc = last_test_acc

        result_test = dict(
            best_test_loss=best_test_loss,
            best_test_acc=best_test_acc,
            last_test_loss=last_test_loss,
            last_test_acc=last_test_acc,
        )
        log_pbar(result_test, tests_pbar)
        if cfg.save_dir:
            save_figure(name, cfg.save_dir, writer, result, j, 'test')

    result_all_mean = result_all.mean(0)

    # TF-events for the inner loop (train & test)
    if cfg.save_dir:
        # test_result_mean = ResultDict()
        for i in range(cfg.iter_test):
            mean_i = result_all_mean.getitem(i)
            walltime = result_all_mean['walltime'][i]
            log_tf_event(writer, 'meta_test_inner', mean_i, i, walltime)
            # test_result_mean.update(**mean_i, step=i, walltime=walltime)
        result_all.save(name, cfg.save_dir)
        result_all.save_as_csv(name, cfg.save_dir, 'meta_test_inner')
        result_final.save_as_csv(
            name, cfg.save_dir, 'meta_test_final', trans_1d=True)

    return result_all
def test_neural(name, save_dir, learned_params, data_cls, model_cls,
                optim_module, n_test=100, iter_test=100, preproc=False,
                out_mul=1.0, no_mask=False, k_obsrv=10):
    """Function for meta-test of learned optimizers."""
    data_cls = _get_attr_by_name(data_cls)
    model_cls = _get_attr_by_name(model_cls)
    optim_cls = _get_optim_by_name(optim_module)
    writer = TFWriter(save_dir, name) if save_dir else None

    optimizer = C(optim_cls())
    if learned_params is not None:
        optimizer.params = learned_params

    meta_optim = None
    unroll = 1

    best_test_loss = 999999
    result_all = ResultDict()
    result_final = ResultDict()  # test loss & acc for the whole testset
    tests_pbar = tqdm(range(n_test), 'test')

    # Meta-test
    for j in tests_pbar:
        data = data_cls().sample_meta_test()
        result, params = optimizer.meta_optimize(
            meta_optim, data, model_cls, iter_test, unroll, out_mul, k_obsrv,
            no_mask, writer, 'test')
        result_all.append(result)
        result_final.append(final_inner_test(
            C(model_cls(params)), data['in_test'], mode='test'))

        result_final_mean = result_final.mean()
        last_test_loss = result_final_mean['loss_mean']
        last_test_acc = result_final_mean['acc_mean']
        if last_test_loss < best_test_loss:
            best_test_loss = last_test_loss
            best_test_acc = last_test_acc

        result_test = dict(
            best_test_loss=best_test_loss,
            best_test_acc=best_test_acc,
            last_test_loss=last_test_loss,
            last_test_acc=last_test_acc,
        )
        log_pbar(result_test, tests_pbar)
        if save_dir:
            save_figure(name, save_dir, writer, result, j, 'test')

    result_all_mean = result_all.mean(0)

    # TF-events for the inner loop (train & test)
    if save_dir:
        # test_result_mean = ResultDict()
        for i in range(iter_test):
            mean_i = result_all_mean.getitem(i)
            walltime = result_all_mean['walltime'][i]
            log_tf_event(writer, 'meta_test_inner', mean_i, i, walltime)
            # test_result_mean.update(**mean_i, step=i, walltime=walltime)
        result_all.save(name, save_dir)
        result_all.save_as_csv(name, save_dir, 'meta_test_inner')
        result_final.save_as_csv(name, save_dir, 'meta_test_final', trans_1d=True)

    return result_all
def train_neural(name, cfg):
    """Function for meta-training and meta-validation of learned optimizers."""
    cfg = Config.merge(
        [cfg.problem, cfg.neural_optimizers[name].train_args, cfg.args])
    # NOTE: lr will be overwritten here.

    # Options
    lr_scheduling = False  # learning rate scheduling
    tr_scheduling = False  # truncation scheduling
    separate_lr = False
    print(f'data_cls: {cfg.data_cls}')
    print(f'model_cls: {cfg.model_cls}')
    data_cls = _get_attr_by_name(cfg.data_cls)
    model_cls = _get_attr_by_name(cfg.model_cls)
    optim_cls = _get_optim_by_name(cfg.optim_module)
    meta_optim = cfg.meta_optim

    optimizer = C(optim_cls())
    if len([p for p in optimizer.parameters()]) == 0:
        return None  # no need to train if there is no parameter
    writer = TFWriter(cfg.save_dir, name) if cfg.save_dir else None
    # TODO: handle variable arguments according to different neural optimizers

    # meta-optimizer
    meta_optim = {'sgd': 'SGD', 'adam': 'Adam'}[meta_optim.lower()]
    if cfg.mask_lr == 0.0:
        print(f'meta optimizer: {meta_optim} / lr: {cfg.lr} / wd: {cfg.wd}\n')
        meta_optim = getattr(torch.optim, meta_optim)(
            optimizer.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
        # scheduler = torch.optim.lr_scheduler.MultiStepLR(
        #     meta_optim, milestones=[1, 2, 3, 4], gamma=0.1)
    else:
        print(f'meta optimizer: {meta_optim} / gen_lr: {cfg.lr} '
              f'/ mask_lr: {cfg.mask_lr} / wd: {cfg.wd}\n')
        p_feat = optimizer.feature_gen.parameters()
        p_step = optimizer.step_gen.parameters()
        p_mask = optimizer.mask_gen.parameters()
        meta_optim = getattr(torch.optim, meta_optim)([
            {'params': p_feat, 'lr': cfg.lr, 'weight_decay': cfg.wd},
            {'params': p_step, 'lr': cfg.lr, 'weight_decay': cfg.wd},
            {'params': p_mask, 'lr': cfg.mask_lr},
        ])
    if lr_scheduling:
        lr_scheduler = ReduceLROnPlateau(
            meta_optim, mode='min', factor=0.5, patience=10, verbose=True)
    # tr_scheduler = Truncation(cfg.unroll, mode='min', factor=1.5,
    #                           patience=1, max_len=50, verbose=True)

    data = data_cls()
    best_params = None
    best_valid_loss = 999999
    # best_converge = 999999
    train_result_mean = ResultDict()
    valid_result_mean = ResultDict()

    epoch_pbar = tqdm(range(cfg.n_epoch), 'epoch')
    for i in epoch_pbar:
        train_pbar = tqdm(range(cfg.n_train), 'outer_train')

        # Meta-training
        for j in train_pbar:
            train_data = data.sample_meta_train()
            result, _ = optimizer.meta_optimize(
                cfg, meta_optim, train_data, model_cls, writer, 'train')
            result_mean = result.mean()
            log_pbar(result_mean, train_pbar)
            if cfg.save_dir:
                step = (cfg.n_train * i) + j
                train_result_mean.append(result_mean, step=step)
                log_tf_event(writer, 'meta_train_outer', result_mean, step)

        if cfg.save_dir:
            save_figure(name, cfg.save_dir, writer, result, i, 'train')

        valid_pbar = tqdm(range(cfg.n_valid), 'outer_valid')
        result_all = ResultDict()
        result_final = ResultDict()

        # Meta-validation
        for j in valid_pbar:
            valid_data = data.sample_meta_valid()
            result, params = optimizer.meta_optimize(
                cfg, meta_optim, valid_data, model_cls, writer, 'valid')
            result_all.append(**result)
            result_final.append(final_inner_test(
                C(model_cls(params)), valid_data['in_test'], mode='valid'))
            log_pbar(result.mean(), valid_pbar)

        result_final_mean = result_final.mean()
        result_all_mean = result_all.mean().w_postfix('mean')
        last_valid_loss = result_final_mean['final_loss_mean']
        last_valid_acc = result_final_mean['final_acc_mean']

        # Learning rate scheduling
        if lr_scheduling:  # and last_valid < 0.5:
            lr_scheduler.step(last_valid_loss, i)

        # Truncation scheduling
        if tr_scheduling:
            tr_scheduler.step(last_valid_loss, i)

        # Save TF events and figures
        if cfg.save_dir:
            step = cfg.n_train * (i + 1)
            result_mean = ResultDict(
                **result_all_mean, **result_final_mean, **get_lr(meta_optim),
                step=step)
            # trunc_len=tr_scheduler.len, step=step)
            valid_result_mean.append(result_mean)
            log_tf_event(writer, 'meta_valid_outer', result_mean, step)
            save_figure(name, cfg.save_dir, writer, result_all.mean(0), i, 'valid')

        # Save the best snapshot
        if last_valid_loss < best_valid_loss:
            best_valid_loss = last_valid_loss
            best_valid_acc = last_valid_acc
            best_params = copy.deepcopy(optimizer.params)
            if cfg.save_dir:
                optimizer.params.save(name, cfg.save_dir)

        # Update the epoch progress bar
        result_epoch = dict(
            best_valid_loss=best_valid_loss,
            best_valid_acc=best_valid_acc,
            last_valid_loss=last_valid_loss,
            last_valid_acc=last_valid_acc,
        )
        log_pbar(result_epoch, epoch_pbar)

    if cfg.save_dir:
        train_result_mean.save_as_csv(name, cfg.save_dir, 'meta_train_outer', True)
        valid_result_mean.save_as_csv(name, cfg.save_dir, 'meta_valid_outer', True)

    return best_params
def loop(mode, outer_steps, inner_steps, log_steps, fig_epochs, inner_lr,
         log_mask=True, unroll_steps=None, meta_batchsize=0, sampler=None,
         epoch=1, outer_optim=None, save_path=None):
    """Args:
        meta_batchsize (int): If meta_batchsize |m| > 0, gradients from
            unrollings over |m| episodes are accumulated in sequence and
            applied in a single update. (This could run in parallel given
            enough VRAM, but is simulated serially in this code.)
            If meta_batchsize |m| = 0 (default), an update is performed after
            every unrolling.
    """
    assert mode in ['train', 'valid', 'test']
    assert meta_batchsize >= 0
    print(f'Start_of_{mode}.')
    if mode == 'train' and unroll_steps is None:
        raise Exception("unroll_steps has to be specified when mode='train'.")
    if mode != 'train' and unroll_steps is not None:
        raise Warning("unroll_steps has no effect when mode != 'train'.")

    train = True if mode == 'train' else False
    force_base = True
    metadata = MetaMultiDataset(split=mode)
    mask_based = 'query'
    mask_type = 5
    mask_sample = False
    mask_scale = True
    easy_ratio = 17 / 50  # 0.3
    scale_manual = 0.8  # 1.0 when lr=0.001 and 0.8 when lr=0.00125
    inner_lr *= scale_manual

    if train:
        if meta_batchsize > 0:
            # serial processing of the meta-minibatch
            update_epochs = meta_batchsize
            update_steps = None
        else:
            # update at every unrolling
            update_steps = unroll_steps
            update_epochs = None
        assert (update_epochs is None) != (update_steps is None)

    # for result recording
    result_frame = ResultFrame()
    if save_path:
        writer = SummaryWriter(os.path.join(save_path, 'tfevent'))

    for i in range(1, outer_steps + 1):
        outer_loss = 0

        for j, epi in enumerate(metadata.loader(n_batches=1), 1):
            # initialize the base learner
            model = Model(epi.n_classes)
            params = model.get_init_params('ours')
            epi.s = C(epi.s)
            epi.q = C(epi.q)

            # baseline parameters
            params_b0 = C(params.copy('b0'))
            params_b1 = C(params.copy('b1'))
            params_b2 = C(params.copy('b2'))

            result_dict = ResultDict()
            for k in range(1, inner_steps + 1):
                # feed the support set (baseline)
                out_s_b0 = model(epi.s, params_b0, None)
                out_s_b1 = model(epi.s, params_b1, None)
                out_s_b2 = model(epi.s, params_b2, None)

                if mask_based == 'support':
                    out = out_s_b1
                elif mask_based == 'query':
                    with torch.no_grad():
                        # test on the query set
                        out_q_b1 = model(epi.q, params_b1, mask=None)
                    out = out_q_b1
                else:
                    print('WARNING')

                # attach a mask to get loss_s
                if mask_type == 1:
                    mask = (out.loss.exp().mean().log() - out.loss).exp()
                elif mask_type == 2:
                    mask = out.loss.exp().mean().log() / out.loss
                elif mask_type == 3:
                    mask = out.loss.mean() / out.loss
                elif mask_type == 4:
                    mask = out.loss.min() / out.loss
                elif mask_type == 5 or mask_type == 6:
                    mask_scale = False
                    # weight by magnitude
                    if mask_type == 5:
                        mask = ([scale_manual] * 5 +
                                [(1 - easy_ratio) * scale_manual] * 5)
                    # weight by ordering
                    elif mask_type == 6:
                        if k < inner_steps * easy_ratio:
                            mask = [scale_manual] * 5 + [0.0] * 5
                        else:
                            mask = [scale_manual] * 5 + [scale_manual] * 5
                    # sampling from 0 < p < 1
                    if mask_sample:
                        mask = [np.random.binomial(1, m) for m in mask]
                    mask = C(torch.tensor(mask).float())
                else:
                    print('WARNING')

                if mask_scale:
                    mask = mask / (mask.max() + 0.05)

                # to debug in the middle of the running process
                if sig_2.is_active():
                    import pdb
                    pdb.set_trace()

                mask = mask.unsqueeze(1)
                out_s_b1.attach_mask(mask)
                out_s_b2.attach_mask(mask)
                # lll = out_s_b0.loss * 0.5
                params_b0 = params_b0.sgd_step(
                    out_s_b0.loss.mean(), inner_lr, 'no_grad')
                params_b1 = params_b1.sgd_step(
                    out_s_b1.loss_masked_mean, inner_lr, 'no_grad')
                params_b2 = params_b2.sgd_step(
                    out_s_b2.loss_scaled_mean, inner_lr, 'no_grad')

                with torch.no_grad():
                    # test on the query set
                    out_q_b0 = model(epi.q, params_b0, mask=None)
                    out_q_b1 = model(epi.q, params_b1, mask=None)
                    out_q_b2 = model(epi.q, params_b2, mask=None)

                # record the result
                result_dict.append(
                    outer_step=epoch * i, inner_step=k,
                    **out_s_b0.as_dict(), **out_s_b1.as_dict(),
                    **out_s_b2.as_dict(), **out_q_b0.as_dict(),
                    **out_q_b1.as_dict(), **out_q_b2.as_dict())
            ### end of inner steps (k) ###

            # append to the dataframe
            result_frame = result_frame.append_dict(
                result_dict.index_all(-1).mean_all(-1))

            # logging
            if k % log_steps == 0:
                # print info
                msg = Printer.step_info(
                    epoch, mode, i, outer_steps, k, inner_steps, inner_lr)
                msg += Printer.way_shot_query(epi)
                # print the mask
                if not sig_1.is_active() and log_mask:
                    msg += Printer.colorized_mask(mask, fmt="3d", vis_num=20)
                # print outputs (loss, acc, etc.)
                msg += Printer.outputs(
                    [out_s_b0, out_s_b1, out_s_b2], sig_1.is_active())
                msg += Printer.outputs(
                    [out_q_b0, out_q_b1, out_q_b2], sig_1.is_active())
                print(msg)
        ### end of meta-minibatch (j) ###

        # tensorboard
        if save_path and train:
            step = (epoch * (outer_steps - 1)) + i
            res = ResultFrame(result_frame[result_frame['outer_step'] == i])
            loss = res.get_best_loss().mean()
            acc = res.get_best_acc().mean()
            writer.add_scalars(
                'Loss/train', {n: loss[n] for n in loss.index}, step)
            writer.add_scalars(
                'Acc/train', {n: acc[n] for n in acc.index}, step)

        # dump figures
        if save_path and i % fig_epochs == 0:
            epi.s.save_fig(f'imgs/support', save_path, i)
            epi.q.save_fig(f'imgs/query', save_path, i)
            # result_dict['ours_s_mask'].save_fig(f'imgs/masks', save_path, i)
            result_dict.get_items([
                'b0_s_loss', 'b1_s_loss', 'b2_s_loss',
                'b0_q_loss', 'b1_q_loss', 'b2_q_loss',
            ]).save_csv(f'classwise/{mode}', save_path, i)

        # distinguishable episodes
        if not i == outer_steps:
            print(f'Path for saving: {save_path}')
            print(f'End_of_episode: {i}')
            import pdb
            pdb.set_trace()
    ### end of episode (i) ###

    print(f'End_of_{mode}.')
    # del metadata
    return sampler, result_frame
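# For clarity, the loss-based weighting rules for mask_type 1-4 above can be
# written as standalone expressions. The helper below evaluates them on a
# 1-D per-class loss tensor (a sketch for illustration only); all four assign
# larger weights to classes with smaller (easier) losses.
import torch


def loss_based_masks(loss):
    """Return the four loss-based mask variants for a 1-D loss tensor."""
    mask_1 = (loss.exp().mean().log() - loss).exp()  # exp(log-mean-exp - loss)
    mask_2 = loss.exp().mean().log() / loss          # log-mean-exp over loss
    mask_3 = loss.mean() / loss                      # mean over loss
    mask_4 = loss.min() / loss                       # min over loss
    return mask_1, mask_2, mask_3, mask_4


# toy usage with dummy per-class losses:
# loss_based_masks(torch.tensor([0.2, 0.5, 1.0, 2.0]))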
def meta_optimize(self, meta_optimizer, data, model_cls, optim_it, unroll,
                  out_mul, tf_writer=None, mode='train'):
    assert mode in ['train', 'valid', 'test']
    if mode == 'train':
        self.train()
    else:
        self.eval()

    result_dict = ResultDict()
    unroll_losses = 0
    walltime = 0

    params = C(model_cls()).params
    # lr_sgd = 0.2
    # lr_adam = 0.001
    # sgd = torch.optim.SGD(model.parameters(), lr=lr_sgd)
    # adam = torch.optim.Adam(model.parameters())
    # adagrad = torch.optim.Adagrad(model.parameters())

    iter_pbar = tqdm(range(1, optim_it + 1), 'inner_train')
    iter_watch = utils.StopWatch('inner_train')

    lr = 0.1
    topk = False
    n_sample = 10
    topk_decay = True
    recompute_grad = False
    self.topk_ratio = 0.5
    # self.topk_ratio = 1.0
    set_size = {'layer_0': 500, 'layer_1': 10}  # NOTE: do this smarter

    for iteration in iter_pbar:
        # # conventional optimizers with the API
        # model.zero_grad()
        # loss = model(*inner_data.load())
        # loss.backward()
        # adam.step()
        # loss_detached = loss
        iter_watch.touch()
        model_train = C(model_cls(params=params.detach()))
        train_nll = model_train(*data['in_train'].load())
        train_nll.backward()

        params = model_train.params.detach()
        grad = model_train.params.grad.detach()
        act = model_train.activations.detach()

        # conventional optimizers with manual SGD
        params_good = params + grad * (-lr)
        grad_good = grad

        best_loss_sparse = 999999
        for i in range(n_sample):
            mask_gen = MaskGenerator.topk if topk else MaskGenerator.randk
            mask = mask_gen(grad=grad, set_size=set_size, topk=self.topk_ratio)

            if recompute_grad:
                # recompute gradients again using the pruned network
                params_sparse = params.prune(mask)
                model_sparse = C(model_cls(params=params_sparse.detach()))
                loss_sparse = model_sparse(*data['in_train'].load())
                loss_sparse.backward()
                grad_sparse = model_sparse.params.grad
            else:
                # use full gradients but sparsified
                params_sparse = params.prune(mask)
                grad_sparse = grad.prune(mask)

            params_sparse_ = params_sparse + grad_sparse * (-lr)  # SGD
            model_sparse = C(model_cls(params=params_sparse_.detach()))
            loss_sparse = model_sparse(*data['in_train'].load())
            if loss_sparse < best_loss_sparse:
                best_loss_sparse = loss_sparse
                best_params_sparse = params_sparse_
                best_grads = grad_sparse
                best_mask = mask

        # the lines below are only for evaluation purposes (excluded from time cost)
        walltime += iter_watch.touch('interval')

        # decay topk_ratio
        if topk_decay:
            self.topk_ratio *= 0.999

        # update!
        params = params.sparse_update(best_mask, best_grads * (-lr))

        # generalization performance
        model_test = C(model_cls(params=params.detach()))
        test_nll = model_test(*data['in_test'].load())

        # result dict
        result = dict(
            train_nll=train_nll.tolist(),
            test_nll=test_nll.tolist(),
            sparse_0=self.topk_ratio,
            sparse_1=self.topk_ratio,
        )
        result_dict.append(result)
        log_pbar(result, iter_pbar)
        # desc = [f'{k}: {v:5.5}' for k, v in result.items()]
        # iter_pbar.set_description(f"inner_train [ {' / '.join(desc)} ]")

    return result_dict
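# Side note: with topk_decay enabled, the sparsity ratio above follows a
# simple exponential schedule, ratio_t = 0.5 * 0.999 ** t. The tiny helper
# below just tabulates that schedule (illustrative only, not part of the
# training code).
def topk_ratio_at(t, init=0.5, decay=0.999):
    """Top-k ratio after t inner iterations under multiplicative decay."""
    return init * decay ** t


# e.g. topk_ratio_at(0) -> 0.5, topk_ratio_at(1000) -> ~0.18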
def loop(mode, data, outer_steps, inner_steps, log_steps, fig_epochs, inner_lr,
         log_mask=True, unroll_steps=None, meta_batchsize=0, sampler=None,
         epoch=1, outer_optim=None, save_path=None):
    """Args:
        meta_batchsize (int): If meta_batchsize |m| > 0, gradients from
            unrollings over |m| episodes are accumulated in sequence and
            applied in a single update. (This could run in parallel given
            enough VRAM, but is simulated serially in this code.)
            If meta_batchsize |m| = 0 (default), an update is performed after
            every unrolling.
    """
    assert mode in ['train', 'valid', 'test']
    assert meta_batchsize >= 0
    print(f'Start_of_{mode}.')
    if mode == 'train' and unroll_steps is None:
        raise Exception("unroll_steps has to be specified when mode='train'.")
    if mode != 'train' and unroll_steps is not None:
        raise Warning("unroll_steps has no effect when mode != 'train'.")

    samples_per_class = 3
    train = True if mode == 'train' else False
    force_base = True
    model_mode = 'metric'
    # split_method = 'inclusive'
    split_method = 'exclusive'

    # 100 classes in total
    if split_method == 'exclusive':
        meta_support, remainder = data.split_classes(0.1)  # 10 classes
        meta_query = remainder.sample_classes(50)  # 50 classes
    elif split_method == 'inclusive':
        subdata = data.sample_classes(10)  # 50 classes
        meta_support, meta_query = subdata.split_instances(0.5)  # 5:5 instances
    else:
        raise Exception()

    if train:
        if meta_batchsize > 0:
            # serial processing of the meta-minibatch
            update_epochs = meta_batchsize
            update_steps = None
        else:
            # update at every unrolling
            update_steps = unroll_steps
            update_epochs = None
        assert (update_epochs is None) != (update_steps is None)

    # for result recording
    result_frame = ResultFrame()
    if save_path:
        writer = SummaryWriter(os.path.join(save_path, 'tfevent'))

    for i in range(1, outer_steps + 1):
        fc_lr = 0.01
        metric_lr = 0.01
        import pdb
        pdb.set_trace()
        model_fc = Model(len(meta_support), mode='fc')
        model_metric = Model(len(meta_support), mode='metric')

        # baseline parameters
        params_fc = C(model_fc.get_init_params('fc'))
        params_metric = C(model_metric.get_init_params('metric'))

        result_dict = ResultDict()
        episode_iterator = EpisodeIterator(
            support=meta_support,
            query=meta_query,
            split_ratio=0.5,
            resample_every_iteration=True,
            inner_steps=inner_steps,
            samples_per_class=samples_per_class,
            num_workers=2,
            pin_memory=True,
        )
        episode_iterator.sample_episode()

        for k, (meta_s, meta_q) in enumerate(episode_iterator, 1):
            # import pdb; pdb.set_trace()

            # feed the support set
            # [standard network]
            out_s_fc = model_fc(meta_s, params_fc, None)
            with torch.no_grad():
                # embedding-space metric
                params_fc_m = C(params_fc.copy('fc_m'))
                out_s_fc_m = model_metric(meta_s, params_fc_m, mask=None)

            # [prototypical network]
            out_s_metric = model_metric(meta_s, params_metric, None)

            # inner gradient step (baseline)
            params_fc = params_fc.sgd_step(
                out_s_fc.loss.mean(), fc_lr, 'no_grad')
            params_metric = params_metric.sgd_step(
                out_s_metric.loss.mean(), metric_lr, 'no_grad')

            # test on the query set
            with torch.no_grad():
                if split_method == 'inclusive':
                    out_q_fc = model_fc(meta_q, params_fc, mask=None)
                params_fc_m = C(params_fc.copy('fc_m'))
                out_q_fc_m = model_metric(meta_q, params_fc_m, mask=None)
                out_q_metric = model_metric(meta_q, params_metric, mask=None)

            if split_method == 'inclusive':
                outs = [out_s_fc, out_s_fc_m, out_s_metric,
                        out_q_fc, out_q_fc_m, out_q_metric]
            elif split_method == 'exclusive':
                outs = [out_s_fc, out_s_fc_m, out_s_metric,
                        out_q_fc_m, out_q_metric]

            # record the result
            result_dict.append(
                outer_step=epoch * i, inner_step=k,
                **ModelOutput.as_merged_dict(outs))
        ### end of inner steps (k) ###

        # append to the dataframe
        result_frame = result_frame.append_dict(
            result_dict.index_all(-1).mean_all(-1))

        # logging
        if k % log_steps == 0:
            # print info
            msg = Printer.step_info(
                epoch, mode, i, outer_steps, k, inner_steps, inner_lr)
            # msg += Printer.way_shot_query(epi)
            # print the mask
            if not sig_1.is_active() and log_mask:
                # print outputs (loss, acc, etc.)
                msg += Printer.outputs(outs, True)
            print(msg)

        # to debug in the middle of the running process
        if k == inner_steps and sig_2.is_active():
            import pdb
            pdb.set_trace()

        # # tensorboard
        # if save_path and train:
        #     step = (epoch * (outer_steps - 1)) + i
        #     res = ResultFrame(result_frame[result_frame['outer_step'] == i])
        #     loss = res.get_best_loss().mean()
        #     acc = res.get_best_acc().mean()
        #     writer.add_scalars('Loss/train', {n: loss[n] for n in loss.index}, step)
        #     writer.add_scalars('Acc/train', {n: acc[n] for n in acc.index}, step)
        #
        # # dump figures
        # if save_path and i % fig_epochs == 0:
        #     meta_s.s.save_fig(f'imgs/meta_support', save_path, i)
        #     meta_q.s.save_fig(f'imgs/meta_query', save_path, i)
        #     result_dict['ours_s_mask'].save_fig(f'imgs/masks', save_path, i)
        #     result_dict.get_items(['ours_s_mask', 'ours_s_loss',
        #         'ours_s_loss_masked', 'b0_s_loss', 'b1_s_loss',
        #         'ours_q_loss', 'b0_q_loss', 'b1_q_loss']).save_csv(
        #             f'classwise/{mode}', save_path, i)

        # distinguishable episodes
        if not i == outer_steps:
            print(f'Path for saving: {save_path}')
            print(f'End_of_episode: {i}')
    ### end of episode (i) ###

    print(f'End_of_{mode}.')
    # del metadata
    return sampler, result_frame
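# The two split_method options above differ in whether meta-support and
# meta-query share classes. The sketch below mimics that distinction on a toy
# list of class ids; split_classes / split_instances in the loop are dataset
# API calls, while everything here is a hypothetical stand-in.
import random


def toy_splits(n_classes=100, seed=0):
    """Return (exclusive, inclusive) toy splits over integer class ids."""
    rng = random.Random(seed)
    classes = list(range(n_classes))
    rng.shuffle(classes)

    # 'exclusive': support and query use disjoint class sets
    support_cls = classes[:10]                   # e.g. 10 support classes
    query_cls = rng.sample(classes[10:], 50)     # 50 of the remaining classes
    assert not set(support_cls) & set(query_cls)

    # 'inclusive': same classes, disjoint instances (a 5/5 split per class)
    instances = {c: list(range(10)) for c in classes[:10]}
    support_inst = {c: idx[:5] for c, idx in instances.items()}
    query_inst = {c: idx[5:] for c, idx in instances.items()}

    return (support_cls, query_cls), (support_inst, query_inst)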