def loop(mode, data, outer_steps, inner_steps, log_steps, fig_epochs,
         inner_lr, log_mask=True, unroll_steps=None, meta_batchsize=0,
         sampler=None, epoch=1, outer_optim=None, save_path=None,
         is_RL=False):
  #####################################################################
  """Args:
    meta_batchsize(int): If meta_batchsize |m| > 0, gradients from the
      unrollings of |m| consecutive episodes are accumulated in sequence
      and applied in a single update. (This could be done in parallel
      given enough VRAM, but is simulated serially in this code.)
      If meta_batchsize |m| = 0 (default), an update is performed after
      every unrolling.
  """
  assert mode in ['train', 'valid', 'test']
  assert meta_batchsize >= 0
  print(f'Start_of_{mode}.')

  if mode == 'train' and unroll_steps is None:
    raise Exception("unroll_steps has to be specified when mode='train'.")
  if mode != 'train' and unroll_steps is not None:
    raise Warning("unroll_steps has no effect when mode != 'train'.")

  train = True if mode == 'train' else False
  force_base = True
  # TODO: move to gin configuration
  fc_pulling = False  # not supported for now
  class_balanced = False
  if class_balanced:
    classes_per_episode = 10
    samples_per_class = 3
  else:
    batch_size = 128
  sample_split_ratio = 0.5
  # sample_split_ratio = None
  anneal_outer_steps = 50
  concrete_resample = True
  detach_param = False
  split_method = {1: 'inclusive', 2: 'exclusive'}[1]
  sampler_type = {1: 'pre_sampler', 2: 'post_sampler'}[2]
  #####################################################################
  mask_unit = {
      0: MaskUnit.SAMPLE,
      1: MaskUnit.CLASS,
  }[0]
  mask_dist = {
      0: MaskDist.SOFT,
      1: MaskDist.DISCRETE,
      2: MaskDist.CONCRETE,
      3: MaskDist.RL,
  }[3 if is_RL else 2]
  mask_mode = MaskMode(mask_unit, mask_dist)
  #####################################################################

  # scheduler
  inner_step_scheduler = InnerStepScheduler(
      outer_steps, inner_steps, anneal_outer_steps)

  # 100 classes in total
  if split_method == 'exclusive':
    # meta_support, remainder = data.split_classes(0.1)  # 10 classes
    # meta_query = remainder.sample_classes(50)  # 50 classes
    meta_support, meta_query = data.split_classes(1 / 5)  # 1(100) : 4(400)
    # meta_support, meta_query = data.split_classes(0.3)  # 30 : 70 classes
  elif split_method == 'inclusive':
    # subdata = data.sample_classes(10)
    meta_support, meta_query = data.split_instances(0.5)  # 5:5 instances
  else:
    raise Exception()

  if train:
    if meta_batchsize > 0:
      # serial processing of meta-minibatch
      update_epochs = meta_batchsize
      update_steps = None
    else:
      # update at every unrolling
      update_steps = unroll_steps
      update_epochs = None
    assert (update_epochs is None) != (update_steps is None)

  # for result recording
  result_frame = ResultFrame()
  if save_path:
    writer = SummaryWriter(os.path.join(save_path, 'tfevent'))

  ##################################
  if is_RL:
    env = Environment()
    policy = Policy(sampler)
    neg_rw = None
  ##################################

  for i in range(1, outer_steps + 1):
    outer_loss = 0
    result_dict = ResultDict()

    # initialize sampler
    sampler.initialize()
    # initialize base learner
    model_q = Model(len(meta_support), mode='metric')
    if fc_pulling:
      model_s = Model(len(meta_support), mode='fc')
      params = C(model_s.get_init_params('ours'))
    else:
      model_s = model_q
      params = C(model_s.get_init_params('ours'))

    ##################################
    if is_RL:
      policy.reset()
      # env.reset()
    ##################################

    if not train or force_base:
      """baseline 0: naive single-task learning
         baseline 1: single-task learning with the same loss scale
      """
      # baseline parameters
      params_b0 = C(params.copy('b0'), 4)
      params_b1 = C(params.copy('b1'), 4)

    if class_balanced:
      batch_sampler = BalancedBatchSampler.get(class_size, sample_size)
    else:
      batch_sampler = RandomBatchSampler.get(batch_size)

    # episode iterator
    episode_iterator = MetaEpisodeIterator(
        meta_support=meta_support,
        meta_query=meta_query,
        batch_sampler=batch_sampler,
        match_inner_classes=match_inner_classes,
        inner_steps=inner_step_scheduler(i, verbose=True),
        sample_split_ratio=sample_split_ratio,
        num_workers=4,
        pin_memory=True,
    )

    do_sample = True
    for k, (meta_s, meta_q) in episode_iterator(do_sample):
      outs = []
      meta_s = meta_s.cuda()

      #####################################################################
      if do_sample:
        do_sample = False
        # task encoding (very first step and right after a meta-update)
        with torch.set_grad_enabled(train):

          def feature_fn():
            # determine whether to utilize model features
            if sampler_type == 'post_sampler':
              return model_s(meta_s, params)
            elif sampler_type == 'pre_sampler':
              return None

          # the sampler does the work!
          mask_, lr = sampler(mask_mode, feature_fn)

          # out_for_sampler = model_s(meta_s, params, debug=do_sample)
          # mask_, lr = sampler(
          #     pairwise_dist=out_for_sampler.pairwise_dist,
          #     classwise_loss=out_for_sampler.loss,
          #     classwise_acc=out_for_sampler.acc,
          #     n_classes=out_for_sampler.n_classes,
          #     mask_mode=mask_mode,
          # )

          # sample from the concrete distribution while keeping the
          # original distribution around for simple resampling
          if mask_mode.dist is MaskDist.CONCRETE:
            mask = mask_.rsample()
          else:
            mask = mask_

          if is_RL:
            mask, neg_rw = policy.predict(mask)
            if neg_rw:
              # TODO: fix; [0 class]s keep getting sampled.
              do_sample = True
              print('[!] select 0 class!')
              policy.rewards.append(neg_rw)
              import pdb
              pdb.set_trace()
              # policy.update()
              continue
      else:
        # we can simply resample masks when using the concrete distribution
        # without going through the sampler again
        if mask_mode.dist is MaskDist.CONCRETE and concrete_resample:
          mask = mask_.rsample()
      #####################################################################

      # use the learned learning rate if available
      # lr = inner_lr if lr is None else lr  # inner_lr: preset / lr: learned

      # train on the support set
      params, mask = C([params, mask], 2)
      #####################################################################
      out_s = model_s(meta_s, params, mask=mask, mask_mode=mask_mode)
      # out_s_loss_masked = out_s.loss_masked
      outs.append(out_s)

      # inner gradient step
      out_s_loss_mean = (out_s.loss.mean() if mask_mode == MaskMode.RL
                         else out_s.loss_masked_mean)
      out_s_loss_mean, lr = C([out_s_loss_mean, lr], 2)
      params = params.sgd_step(
          out_s_loss_mean, lr,
          second_order=False if is_RL else True,
          detach_param=True if is_RL else detach_param,
      )
      #####################################################################

      # baseline
      if not train or force_base:
        out_s_b0 = model_s(meta_s, params_b0)
        out_s_b1 = model_s(meta_s, params_b1)
        outs.extend([out_s_b0, out_s_b1])

        # attach mask to get loss_s
        out_s_b1.attach_mask(mask)

        # inner gradient step (baseline)
        params_b0 = params_b0.sgd_step(out_s_b0.loss.mean(), inner_lr)
        params_b1 = params_b1.sgd_step(out_s_b1.loss_scaled_mean, lr.detach())

      meta_s = meta_s.cpu()
      meta_q = meta_q.cuda()

      # test on the query set
      with torch.set_grad_enabled(train):
        params = C(params, 3)
        out_q = model_q(meta_q, params)
        outs.append(out_q)

      # baseline
      if not train or force_base:
        with torch.no_grad():
          # test on the query set
          out_q_b0 = model_q(meta_q, params_b0)
          out_q_b1 = model_q(meta_q, params_b1)
          outs.extend([out_q_b0, out_q_b1])

      ##################################
      if is_RL:
        reward = env.step(outs)
        policy.rewards.append(reward)
        outs.append(policy)
      ##################################

      # record result
      result_dict.append(
          outer_step=epoch * i, inner_step=k,
          **ModelOutput.as_merged_dict(outs))

      # append to the dataframe
      result_frame = result_frame.append_dict(
          result_dict.index_all(-1).mean_all(-1))

      # logging
      if k % log_steps == 0:
        logger.step_info(
            epoch, mode, i, outer_steps, k, inner_step_scheduler(i), lr)
        logger.split_info(meta_support, meta_query, episode_iterator)
        logger.colorized_mask(
            mask, fmt="2d", vis_num=20,
            cond=not sig_1.is_active() and log_mask)
        logger.outputs(outs, print_conf=sig_1.is_active())
        logger.flush()

      # to debug in the middle of the running process
      if k == inner_steps and sig_2.is_active():
        import pdb
        pdb.set_trace()

      # compute outer gradient
      if train and (k % unroll_steps == 0 or k == inner_steps):
        #####################################################################
        if not is_RL:
          outer_loss += out_q.loss.mean()
          outer_loss.backward()
          outer_loss = 0
        params.detach_().requires_grad_()
        mask = mask.detach()
        lr = lr.detach()
        do_sample = True
        #####################################################################

      if not train:
        if not is_RL:
          # when params is not a leaf node created by the user,
          # its requires_grad attribute is False by default.
          params.requires_grad_()

      # meta (outer) learning
      if train and update_steps and k % update_steps == 0:
        # when meta_batchsize == 0, update_steps == unroll_steps
        if is_RL:
          policy.update()
        else:
          outer_optim.step()
          sampler.zero_grad()
    ### end of inner steps (k) ###

    if train and update_epochs and i % update_epochs == 0:
      # when meta_batchsize > 0, update_epochs == meta_batchsize
      if is_RL:
        policy.update()
      else:
        outer_optim.step()
        sampler.zero_grad()

    if train and update_epochs:
      print(f'Meta-batchsize is {meta_batchsize}: Sampler updated.')
    if train and update_steps:
      print('Meta-batchsize is zero. Updating after every unrolling.')

    # tensorboard
    if save_path and train:
      step = (epoch * (outer_steps - 1)) + i
      res = ResultFrame(result_frame[result_frame['outer_step'] == i])
      loss = res.get_best_loss().mean()
      acc = res.get_best_acc().mean()
      writer.add_scalars(
          'Loss/train', {n: loss[n] for n in loss.index}, step)
      writer.add_scalars('Acc/train', {n: acc[n] for n in acc.index}, step)

    # dump figures
    if save_path and i % fig_epochs == 0:
      # meta_s.s.save_fig(f'imgs/meta_support', save_path, i)
      # meta_q.s.save_fig(f'imgs/meta_query', save_path, i)
      result_dict['ours_s_mask'].save_fig(f'imgs/masks', save_path, i)
      result_dict.get_items(
          ['ours_s_mask', 'ours_s_loss', 'ours_s_loss_masked',
           'b0_s_loss', 'b1_s_loss',
           'ours_q_loss', 'b0_q_loss', 'b1_q_loss']
      ).save_csv(f'classwise/{mode}', save_path, i)

    # distinguishable episodes
    if not i == outer_steps:
      print(f'Path for saving: {save_path}')
      print(f'End_of_episode: {i}')
  ### end of episode (i) ###

  print(f'End_of_{mode}.')
  # del metadata
  return sampler, result_frame
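

# Illustrative usage only -- not part of the original source. A minimal sketch
# of how this training loop might be driven for one meta-training epoch, based
# solely on the signature and docstring above. `build_data`, `build_sampler`,
# and `build_outer_optim` are hypothetical factories standing in for whatever
# the surrounding project actually provides; all numeric values are placeholders.
def _example_train_call(build_data, build_sampler, build_outer_optim):
  data = build_data()        # must expose split_classes() / split_instances()
  sampler = build_sampler()  # callable mask sampler with initialize()/zero_grad()
  sampler, frame = loop(
      mode='train',
      data=data,
      outer_steps=1000,
      inner_steps=100,
      log_steps=10,
      fig_epochs=100,
      inner_lr=0.1,
      unroll_steps=5,                          # required when mode='train'
      meta_batchsize=0,                        # 0: meta-update after every unrolling
      sampler=sampler,
      outer_optim=build_outer_optim(sampler),  # presumably optimizes the sampler
      save_path=None,
      is_RL=False,
  )
  return sampler, frame
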
def loop(mode, outer_steps, inner_steps, log_steps, fig_epochs, inner_lr,
         log_mask=True, unroll_steps=None, meta_batchsize=0, sampler=None,
         epoch=1, outer_optim=None, save_path=None):
  """Args:
    meta_batchsize(int): If meta_batchsize |m| > 0, gradients from the
      unrollings of |m| consecutive episodes are accumulated in sequence
      and applied in a single update. (This could be done in parallel
      given enough VRAM, but is simulated serially in this code.)
      If meta_batchsize |m| = 0 (default), an update is performed after
      every unrolling.
  """
  assert mode in ['train', 'valid', 'test']
  assert meta_batchsize >= 0
  print(f'Start_of_{mode}.')

  if mode == 'train' and unroll_steps is None:
    raise Exception("unroll_steps has to be specified when mode='train'.")
  if mode != 'train' and unroll_steps is not None:
    raise Warning("unroll_steps has no effect when mode != 'train'.")

  train = True if mode == 'train' else False
  force_base = True
  metadata = MetaMultiDataset(split=mode)
  mask_based = 'query'
  mask_type = 5
  mask_sample = False
  mask_scale = True
  easy_ratio = 17 / 50  # 0.34
  scale_manual = 0.8  # 1.0 when lr=0.001 and 0.8 when lr=0.00125
  inner_lr *= scale_manual

  if train:
    if meta_batchsize > 0:
      # serial processing of meta-minibatch
      update_epochs = meta_batchsize
      update_steps = None
    else:
      # update at every unrolling
      update_steps = unroll_steps
      update_epochs = None
    assert (update_epochs is None) != (update_steps is None)

  # for result recording
  result_frame = ResultFrame()
  if save_path:
    writer = SummaryWriter(os.path.join(save_path, 'tfevent'))

  for i in range(1, outer_steps + 1):
    outer_loss = 0

    for j, epi in enumerate(metadata.loader(n_batches=1), 1):
      # initialize base learner
      model = Model(epi.n_classes)
      params = model.get_init_params('ours')
      epi.s = C(epi.s)
      epi.q = C(epi.q)

      # baseline parameters
      params_b0 = C(params.copy('b0'))
      params_b1 = C(params.copy('b1'))
      params_b2 = C(params.copy('b2'))

      result_dict = ResultDict()

      for k in range(1, inner_steps + 1):
        # feed support set (baseline)
        out_s_b0 = model(epi.s, params_b0, None)
        out_s_b1 = model(epi.s, params_b1, None)
        out_s_b2 = model(epi.s, params_b2, None)

        if mask_based == 'support':
          out = out_s_b1
        elif mask_based == 'query':
          with torch.no_grad():
            # test on the query set
            out_q_b1 = model(epi.q, params_b1, mask=None)
          out = out_q_b1
        else:
          print('WARNING')

        # attach mask to get loss_s
        if mask_type == 1:
          mask = (out.loss.exp().mean().log() - out.loss).exp()
        elif mask_type == 2:
          mask = (out.loss.exp().mean().log() / out.loss)
        elif mask_type == 3:
          mask = out.loss.mean() / out.loss
        elif mask_type == 4:
          mask = out.loss.min() / out.loss
        elif mask_type == 5 or mask_type == 6:
          mask_scale = False
          # weight by magnitude
          if mask_type == 5:
            mask = [scale_manual] * 5 + [(1 - easy_ratio) * scale_manual] * 5
          # weight by ordering
          elif mask_type == 6:
            if k < inner_steps * easy_ratio:
              mask = [scale_manual] * 5 + [0.0] * 5
            else:
              mask = [scale_manual] * 5 + [scale_manual] * 5
          # sampling from 0 < p < 1
          if mask_sample:
            mask = [np.random.binomial(1, m) for m in mask]
          mask = C(torch.tensor(mask).float())
        else:
          print('WARNING')

        if mask_scale:
          mask = (mask / (mask.max() + 0.05))

        # to debug in the middle of the running process
        if sig_2.is_active():
          import pdb
          pdb.set_trace()

        mask = mask.unsqueeze(1)
        out_s_b1.attach_mask(mask)
        out_s_b2.attach_mask(mask)
        # lll = out_s_b0.loss * 0.5

        params_b0 = params_b0.sgd_step(
            out_s_b0.loss.mean(), inner_lr, 'no_grad')
        params_b1 = params_b1.sgd_step(
            out_s_b1.loss_masked_mean, inner_lr, 'no_grad')
        params_b2 = params_b2.sgd_step(
            out_s_b2.loss_scaled_mean, inner_lr, 'no_grad')
        with torch.no_grad():
          # test on the query set
          out_q_b0 = model(epi.q, params_b0, mask=None)
          out_q_b1 = model(epi.q, params_b1, mask=None)
          out_q_b2 = model(epi.q, params_b2, mask=None)

        # record result
        result_dict.append(
            outer_step=epoch * i, inner_step=k,
            **out_s_b0.as_dict(), **out_s_b1.as_dict(), **out_s_b2.as_dict(),
            **out_q_b0.as_dict(), **out_q_b1.as_dict(), **out_q_b2.as_dict())
      ### end of inner steps (k) ###

      # append to the dataframe
      result_frame = result_frame.append_dict(
          result_dict.index_all(-1).mean_all(-1))

      # logging
      if k % log_steps == 0:
        # print info
        msg = Printer.step_info(
            epoch, mode, i, outer_steps, k, inner_steps, inner_lr)
        msg += Printer.way_shot_query(epi)
        # print mask
        if not sig_1.is_active() and log_mask:
          msg += Printer.colorized_mask(mask, fmt="3d", vis_num=20)
        # print outputs (loss, acc, etc.)
        msg += Printer.outputs(
            [out_s_b0, out_s_b1, out_s_b2], sig_1.is_active())
        msg += Printer.outputs(
            [out_q_b0, out_q_b1, out_q_b2], sig_1.is_active())
        print(msg)
    ### end of meta minibatch (j) ###

    # tensorboard
    if save_path and train:
      step = (epoch * (outer_steps - 1)) + i
      res = ResultFrame(result_frame[result_frame['outer_step'] == i])
      loss = res.get_best_loss().mean()
      acc = res.get_best_acc().mean()
      writer.add_scalars(
          'Loss/train', {n: loss[n] for n in loss.index}, step)
      writer.add_scalars('Acc/train', {n: acc[n] for n in acc.index}, step)

    # dump figures
    if save_path and i % fig_epochs == 0:
      epi.s.save_fig(f'imgs/support', save_path, i)
      epi.q.save_fig(f'imgs/query', save_path, i)
      # result_dict['ours_s_mask'].save_fig(f'imgs/masks', save_path, i)
      result_dict.get_items([
          'b0_s_loss', 'b1_s_loss', 'b2_s_loss',
          'b0_q_loss', 'b1_q_loss', 'b2_q_loss',
      ]).save_csv(f'classwise/{mode}', save_path, i)

    # distinguishable episodes
    if not i == outer_steps:
      print(f'Path for saving: {save_path}')
      print(f'End_of_episode: {i}')
      import pdb
      pdb.set_trace()
  ### end of episode (i) ###

  print(f'End_of_{mode}.')
  # del metadata
  return sampler, result_frame
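

# Illustrative sketch only -- not part of the original source. The hand-designed
# weighting schemes above (mask_type 1-4) pulled out as a pure function over a
# per-class loss tensor, so the formulas are easy to compare side by side; all
# four assign larger weights to classes with smaller loss. The function name and
# argument are hypothetical; `loss` is assumed to be a 1-D torch tensor of
# per-class losses. (mask_type 5 and 6 are fixed schedules handled inline above.)
def _handcrafted_mask(loss, mask_type):
  if mask_type == 1:
    # exp(log-mean-exp(loss) - loss): above 1 for below-average losses
    return (loss.exp().mean().log() - loss).exp()
  if mask_type == 2:
    # log-mean-exp(loss) divided element-wise by each loss
    return loss.exp().mean().log() / loss
  if mask_type == 3:
    # mean loss divided element-wise by each loss
    return loss.mean() / loss
  if mask_type == 4:
    # minimum loss divided element-wise by each loss
    return loss.min() / loss
  raise ValueError(f'mask_type {mask_type} is handled inline above')
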
def loop(mode, data, outer_steps, inner_steps, log_steps, fig_epochs,
         inner_lr, log_mask=True, unroll_steps=None, meta_batchsize=0,
         sampler=None, epoch=1, outer_optim=None, save_path=None):
  """Args:
    meta_batchsize(int): If meta_batchsize |m| > 0, gradients from the
      unrollings of |m| consecutive episodes are accumulated in sequence
      and applied in a single update. (This could be done in parallel
      given enough VRAM, but is simulated serially in this code.)
      If meta_batchsize |m| = 0 (default), an update is performed after
      every unrolling.
  """
  assert mode in ['train', 'valid', 'test']
  assert meta_batchsize >= 0
  print(f'Start_of_{mode}.')

  if mode == 'train' and unroll_steps is None:
    raise Exception("unroll_steps has to be specified when mode='train'.")
  if mode != 'train' and unroll_steps is not None:
    raise Warning("unroll_steps has no effect when mode != 'train'.")

  samples_per_class = 3
  train = True if mode == 'train' else False
  force_base = True
  model_mode = 'metric'
  # split_method = 'inclusive'
  split_method = 'exclusive'

  # 100 classes in total
  if split_method == 'exclusive':
    meta_support, remainder = data.split_classes(0.1)  # 10 classes
    meta_query = remainder.sample_classes(50)  # 50 classes
  elif split_method == 'inclusive':
    subdata = data.sample_classes(10)  # 10 classes
    meta_support, meta_query = subdata.split_instances(0.5)  # 5:5 instances
  else:
    raise Exception()

  if train:
    if meta_batchsize > 0:
      # serial processing of meta-minibatch
      update_epochs = meta_batchsize
      update_steps = None
    else:
      # update at every unrolling
      update_steps = unroll_steps
      update_epochs = None
    assert (update_epochs is None) != (update_steps is None)

  # for result recording
  result_frame = ResultFrame()
  if save_path:
    writer = SummaryWriter(os.path.join(save_path, 'tfevent'))

  for i in range(1, outer_steps + 1):
    fc_lr = 0.01
    metric_lr = 0.01
    import pdb
    pdb.set_trace()
    model_fc = Model(len(meta_support), mode='fc')
    model_metric = Model(len(meta_support), mode='metric')

    # baseline parameters
    params_fc = C(model_fc.get_init_params('fc'))
    params_metric = C(model_metric.get_init_params('metric'))

    result_dict = ResultDict()
    episode_iterator = EpisodeIterator(
        support=meta_support,
        query=meta_query,
        split_ratio=0.5,
        resample_every_iteration=True,
        inner_steps=inner_steps,
        samples_per_class=samples_per_class,
        num_workers=2,
        pin_memory=True,
    )
    episode_iterator.sample_episode()

    for k, (meta_s, meta_q) in enumerate(episode_iterator, 1):
      # import pdb; pdb.set_trace()

      # feed support set
      # [standard network]
      out_s_fc = model_fc(meta_s, params_fc, None)
      with torch.no_grad():
        # embedding space metric
        params_fc_m = C(params_fc.copy('fc_m'))
        out_s_fc_m = model_metric(meta_s, params_fc_m, mask=None)

      # [prototypical network]
      out_s_metric = model_metric(meta_s, params_metric, None)

      # inner gradient step (baseline)
      params_fc = params_fc.sgd_step(
          out_s_fc.loss.mean(), fc_lr, 'no_grad')
      params_metric = params_metric.sgd_step(
          out_s_metric.loss.mean(), metric_lr, 'no_grad')

      # test on the query set
      with torch.no_grad():
        if split_method == 'inclusive':
          out_q_fc = model_fc(meta_q, params_fc, mask=None)
        params_fc_m = C(params_fc.copy('fc_m'))
        out_q_fc_m = model_metric(meta_q, params_fc_m, mask=None)
        out_q_metric = model_metric(meta_q, params_metric, mask=None)

      if split_method == 'inclusive':
        outs = [
            out_s_fc, out_s_fc_m, out_s_metric,
            out_q_fc, out_q_fc_m, out_q_metric,
        ]
      elif split_method == 'exclusive':
        outs = [
            out_s_fc, out_s_fc_m, out_s_metric,
            out_q_fc_m, out_q_metric,
        ]

      # record result
      result_dict.append(
          outer_step=epoch * i, inner_step=k,
          **ModelOutput.as_merged_dict(outs))
    ### end of inner steps (k) ###

    # append to the dataframe
    result_frame = result_frame.append_dict(
        result_dict.index_all(-1).mean_all(-1))

    # logging
    if k % log_steps == 0:
      # print info
      msg = Printer.step_info(
          epoch, mode, i, outer_steps, k, inner_steps, inner_lr)
      # msg += Printer.way_shot_query(epi)
      # print mask
      if not sig_1.is_active() and log_mask:
        # print outputs (loss, acc, etc.)
        msg += Printer.outputs(outs, True)
      print(msg)

    # to debug in the middle of running process.
    if k == inner_steps and sig_2.is_active():
      import pdb
      pdb.set_trace()

    # # tensorboard
    # if save_path and train:
    #   step = (epoch * (outer_steps - 1)) + i
    #   res = ResultFrame(result_frame[result_frame['outer_step'] == i])
    #   loss = res.get_best_loss().mean()
    #   acc = res.get_best_acc().mean()
    #   writer.add_scalars(
    #       'Loss/train', {n: loss[n] for n in loss.index}, step)
    #   writer.add_scalars('Acc/train', {n: acc[n] for n in acc.index}, step)
    #
    # # dump figures
    # if save_path and i % fig_epochs == 0:
    #   meta_s.s.save_fig(f'imgs/meta_support', save_path, i)
    #   meta_q.s.save_fig(f'imgs/meta_query', save_path, i)
    #   result_dict['ours_s_mask'].save_fig(f'imgs/masks', save_path, i)
    #   result_dict.get_items(
    #       ['ours_s_mask', 'ours_s_loss', 'ours_s_loss_masked',
    #        'b0_s_loss', 'b1_s_loss',
    #        'ours_q_loss', 'b0_q_loss', 'b1_q_loss']
    #   ).save_csv(f'classwise/{mode}', save_path, i)

    # distinguishable episodes
    if not i == outer_steps:
      print(f'Path for saving: {save_path}')
      print(f'End_of_episode: {i}')
  ### end of episode (i) ###

  print(f'End_of_{mode}.')
  # del metadata
  return sampler, result_frame
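

# Illustrative usage only -- not part of the original source. A hedged sketch of
# invoking this baseline-comparison loop in evaluation mode, based on the
# signature above. `build_data` is a hypothetical factory; `unroll_steps`,
# `sampler`, and `outer_optim` stay at their defaults since this variant only
# passes `sampler` through to its return value. Numbers are placeholders.
def _example_eval_call(build_data):
  _, frame = loop(
      mode='test',
      data=build_data(),
      outer_steps=10,
      inner_steps=100,
      log_steps=10,
      fig_epochs=100,
      inner_lr=0.01,
      save_path=None,
  )
  return frame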