def get_dataset(args):
    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](
        args.dataroot, args.train_aug)
    if args.n_permutation > 0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset, val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset, val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)
    return train_dataset_splits, val_dataset_splits, task_output_space
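
# Hypothetical usage sketch (not part of the original scripts): get_dataset() only
# reads a handful of attributes from `args`, so it can be driven with a hand-built
# argparse.Namespace. The dataset name and split sizes below are illustrative
# assumptions, not values prescribed by this repository.
def _example_get_dataset_usage():
    from argparse import Namespace
    example_args = Namespace(dataset='MNIST', dataroot='data', train_aug=False,
                             n_permutation=0, first_split_size=2, other_split_size=2,
                             rand_split=False, no_class_remap=False)
    train_splits, val_splits, task_output_space = get_dataset(example_args)
    # task_output_space maps task name -> output dimension (classes) for that task,
    # e.g. {'1': 2, '2': 2, ...} for a 2-way split.
    print(task_output_space)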
def run(args):
    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](args.dataroot, args.train_aug)
    if args.n_permutation > 0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset, val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset, val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)

    task_names = sorted(list(task_output_space.keys()), key=int)
    # Broadcast single-valued hyperparameter lists to one entry per task
    if len(args.eps_val) == 1:
        args.eps_val = [args.eps_val[0]] * len(task_names)
    if len(args.eps_max) == 1:
        args.eps_max = [args.eps_max[0]] * len(task_names)
    if len(args.eps_epoch) == 1:
        args.eps_epoch = [args.eps_epoch[0]] * len(task_names)
    if len(args.kappa_epoch) == 1:
        args.kappa_epoch = [args.kappa_epoch[0]] * len(task_names)
    if len(args.schedule) == 1:
        args.schedule = [args.schedule[0]] * len(task_names)

    # Prepare the Agent (model)
    agent_config = {'lr': args.lr,
                    'momentum': args.momentum,
                    'weight_decay': args.weight_decay,
                    'schedule': args.schedule,
                    'model_type': args.model_type,
                    'model_name': args.model_name,
                    'model_weights': args.model_weights,
                    'out_dim': {'All': args.force_out_dim} if args.force_out_dim > 0 else task_output_space,
                    'optimizer': args.optimizer,
                    'print_freq': args.print_freq,
                    'gpuid': args.gpuid,
                    'reg_coef': args.reg_coef,
                    'force_out_dim': args.force_out_dim,
                    'clipping': args.clipping,
                    'eps_per_model': args.eps_per_model,
                    'milestones': args.milestones,
                    'dataset_name': args.dataset}
    agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config)
    print(agent.model)
    print('#parameter of model:', agent.count_parameter())

    # Decide split ordering
    print('Task order:', task_names)
    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)

    acc_table = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(val_dataset_splits.values())
        train_loader = DataLoader(train_dataset_all, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
        val_loader = DataLoader(val_dataset_all, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        agent.learn_batch(train_loader, val_loader)
        acc_table['All'] = {}
        acc_table['All']['All'] = agent.validation(val_loader)
    else:  # Incremental learning
        # Feed data to agent and evaluate agent's performance
        for i in range(len(task_names)):
            train_name = task_names[i]
            agent.current_task = int(task_names[i])
            print('======================', train_name, '=======================')
            train_loader = DataLoader(train_dataset_splits[train_name], batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
            val_loader = DataLoader(val_dataset_splits[train_name], batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

            if args.incremental_class:
                agent.add_valid_output_dim(task_output_space[train_name])

            # Configure the eps/kappa ramps for this task
            if args.eps_max:
                agent.eps_scheduler.set_end(args.eps_max[i])
            agent.kappa_scheduler.end = args.kappa_min
            iter_on_batch = len(train_loader)
            agent.kappa_scheduler.calc_coefficient(args.kappa_min - 1, args.kappa_epoch[i], iter_on_batch)
            agent.eps_scheduler.calc_coefficient(args.eps_val[i], args.eps_epoch[i], iter_on_batch)
            agent.kappa_scheduler.current, agent.eps_scheduler.current = 1, 0

            if agent.multihead:
                agent.current_head = str(train_name)

            print(f"before batch eps: {agent.eps_scheduler.current}, kappa: {agent.kappa_scheduler.current}")
            agent.learn_batch(train_loader, val_loader)  # Learn
            print(f"after batch eps: {agent.eps_scheduler.current}, kappa: {agent.kappa_scheduler.current}")

            if args.clipping:
                agent.save_params()
            agent.model.print_eps(agent.current_head)
            agent.model.reset_importance()

            # Evaluate
            acc_table[train_name] = OrderedDict()
            for j in range(i + 1):
                val_name = task_names[j]
                print('validation split name:', val_name)
                val_data = val_dataset_splits[val_name] if not args.eval_on_train_set else train_dataset_splits[val_name]
                val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
                acc_table[val_name][train_name] = agent.validation(val_loader)
                agent.validation_with_move_weights(val_loader)

    agent.tb.close()
    return acc_table, task_names
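
# The eps/kappa schedulers used above are project-internal; the following is only a
# minimal sketch, under the assumption of a linear ramp, of what an object with the
# same interface (set_end, calc_coefficient, current) could look like. It is written
# here to illustrate how `current` could move from its start value to `end` over
# `epochs * iters_per_epoch` steps; the actual implementation in this repository may differ.
class _LinearSchedulerSketch:
    def __init__(self, start=0.0, end=0.0):
        self.current = start
        self.end = end
        self.coefficient = 0.0

    def set_end(self, end):
        self.end = end

    def calc_coefficient(self, delta, epochs, iters_per_epoch):
        # Per-iteration increment so that `current` changes by `delta` in total
        # over the given number of epochs.
        self.coefficient = delta / max(1, epochs * iters_per_epoch)

    def step(self):
        # Move toward `end` and clamp once it is reached.
        self.current += self.coefficient
        if (self.coefficient > 0 and self.current > self.end) or \
           (self.coefficient < 0 and self.current < self.end):
            self.current = self.end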
def run(args):
    if not os.path.exists('outputs'):
        os.mkdir('outputs')

    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](
        args.dataroot, args.train_aug)
    if args.n_permutation > 0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset, val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset, val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)

    # Prepare the Agent (model)
    agent_config = {
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'schedule': args.schedule,
        'model_type': args.model_type,
        'model_name': args.model_name,
        'model_weights': args.model_weights,
        'out_dim': {'All': args.force_out_dim} if args.force_out_dim > 0 else task_output_space,
        'optimizer': args.optimizer,
        'print_freq': args.print_freq,
        'gpuid': args.gpuid,
        'reg_coef': args.reg_coef
    }
    agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config)
    print(agent.model)
    print('#parameter of model:', agent.count_parameter())

    # Decide split ordering
    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:', task_names)
    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)

    acc_table = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(val_dataset_splits.values())
        train_loader = torch.utils.data.DataLoader(train_dataset_all,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset_all,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers)
        agent.learn_batch(train_loader, val_loader)
        acc_table['All'] = {}
        acc_table['All']['All'] = agent.validation(val_loader)
    else:  # Incremental learning
        # Feed data to agent and evaluate agent's performance
        for i in range(len(task_names)):
            train_name = task_names[i]
            print('======================', train_name, '=======================')
            train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name],
                                                       batch_size=args.batch_size,
                                                       shuffle=True,
                                                       num_workers=args.workers)
            val_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],
                                                     batch_size=args.batch_size,
                                                     shuffle=False,
                                                     num_workers=args.workers)

            if args.incremental_class:
                agent.add_valid_output_dim(task_output_space[train_name])

            # Learn
            agent.learn_batch(train_loader, val_loader)

            # Evaluate
            acc_table[train_name] = OrderedDict()
            for j in range(i + 1):
                val_name = task_names[j]
                print('validation split name:', val_name)
                val_data = val_dataset_splits[val_name] if not args.eval_on_train_set else train_dataset_splits[val_name]
                val_loader = torch.utils.data.DataLoader(val_data,
                                                         batch_size=args.batch_size,
                                                         shuffle=False,
                                                         num_workers=args.workers)
                acc_table[val_name][train_name] = agent.validation(val_loader)

    return acc_table, task_names
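
# Hypothetical helper (not part of the original script): acc_table is keyed as
# acc_table[eval_task][trained_after_task], so the lower-triangular accuracy matrix
# can be printed directly from the values returned by run(). Depending on the agent,
# entries may be plain floats or meter-like objects with an `.avg` field, hence the
# getattr fallback below.
def _print_acc_matrix(acc_table, task_names):
    for val_name in task_names:
        row = []
        for train_name in task_names:
            entry = acc_table.get(val_name, {}).get(train_name)
            if entry is None:
                row.append('   -  ')
            else:
                row.append('{:6.2f}'.format(getattr(entry, 'avg', entry)))
        print(val_name, ' '.join(row))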
def run(args, rand_seed):
    print('Random seed', rand_seed)
    if args.benchmark:
        print('benchmarked')
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    torch.manual_seed(rand_seed)
    np.random.seed(rand_seed)
    random.seed(rand_seed)

    if args.single_tasks:
        args.exp_name = f'{args.exp_name}_single_tasks'

    if not os.path.exists('outputs'):
        os.mkdir('outputs')

    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](
        args.dataroot, args.train_aug)

    # 5-sequence tasks
    if args.dataset == 'multidataset':
        train_dataset_splits, val_dataset_splits, task_output_space = get_splits(
            train_dataset, val_dataset)
    else:
        if args.n_permutation > 0:
            train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
                train_dataset, val_dataset,
                args.n_permutation,
                remap_class=not args.no_class_remap)
        else:
            train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
                train_dataset, val_dataset,
                first_split_sz=args.first_split_size,
                other_split_sz=args.other_split_size,
                rand_split=args.rand_split,
                remap_class=not args.no_class_remap)

    # Prepare the Agent (model)
    agent_config = {
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'schedule': args.schedule,
        'model_type': args.model_type,
        'model_name': args.model_name,
        'model_weights': args.model_weights,
        'out_dim': {'All': args.force_out_dim} if args.force_out_dim > 0 else task_output_space,
        'optimizer': args.optimizer,
        'print_freq': args.print_freq,
        'gpuid': args.gpuid,
        'reg_coef': args.reg_coef,
        'exp_name': args.exp_name,
        'nuclear_weight': args.nuclear_weight,
        'period': args.period,
        'threshold_trp': args.threshold_trp,
        'sparse_wt': args.sparse_wt,
        'perp_wt': args.perp_wt,
        'reg_type_svd': args.reg_type_svd,
        'energy_sv': args.energy_sv,
        'save_running_stats': args.save_running_stats,
        'e_search': args.e_search,
        'sp_wt_search': args.sp_wt_search,
        'single_tasks': args.single_tasks,
        'prev_sing': args.prev_sing,
        'debug': args.debug,
        'grow_network': args.grow_network,
        'ind_models': args.ind_models,
        'dataset': args.dataset
    }

    if args.ind_models:
        # One independent agent (and output head) per task
        agent_config_lst = {}
        for key_, val in task_output_space.items():
            agent_config_lst[key_] = copy.deepcopy(agent_config)
            agent_config_lst[key_]['out_dim'] = {key_: val}
            agent_config_lst[key_]['ind_models'] = True
        agents_lst = {}
        for key_, val in task_output_space.items():
            agents_lst[key_] = agents.__dict__[args.agent_type].__dict__[args.agent_name](
                agent_config_lst[key_])
    else:
        print(args.nuclear_weight, args.threshold_trp)
        agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config)

    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:', task_names)
    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)

    acc_table = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(val_dataset_splits.values())
        train_loader = torch.utils.data.DataLoader(train_dataset_all,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers)
        test_loader = torch.utils.data.DataLoader(val_dataset_all,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers)
        agent.learn_batch(train_loader, test_loader)
        acc_table['All'] = {}
        acc_table['All']['All'] = agent.validation(test_loader)
    else:  # Incremental learning
        save_dict = {}
        # Validation split sizes follow the Adversarial Continual Learning paper
        if args.dataset == 'miniImageNet':
            validation_split = 0.02
        else:
            validation_split = 0.15
        final_accs = OrderedDict()

        if args.loadmodel:
            # Evaluation-only path: load a trained model and validate on every task
            agent.load_model(args.loadmodel)
            for i in range(len(task_names)):
                train_name = task_names[i]
                test_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],
                                                          batch_size=args.batch_size,
                                                          shuffle=False,
                                                          num_workers=args.workers)
                final_accs[train_name] = agent.validation(test_loader, train_name).avg
                print(f'/CumAcc/Task{train_name}, {final_accs[train_name]}, {i}')
        else:
            for i in range(len(task_names)):
                train_name = task_names[i]
                if args.ind_models:
                    agent = agents_lst[train_name]
                writer = SummaryWriter(log_dir="runs/" + agent.exp_name)
                # split = int(np.floor(validation_split * len(train_dataset_splits[train_name])))
                # train_split, val_split = torch.utils.data.random_split(
                #     train_dataset_splits[train_name],
                #     [len(train_dataset_splits[train_name]) - split, split])
                # train_dataset_splits[train_name] = train_split
                print('====================== Task Num', i + 1, '=======================')
                print('======================', train_name, '=======================')
                train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name],
                                                           batch_size=args.batch_size,
                                                           shuffle=True,
                                                           num_workers=args.workers)
                test_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],
                                                          batch_size=args.batch_size,
                                                          shuffle=False,
                                                          num_workers=args.workers)

                if args.incremental_class:
                    agent.add_valid_output_dim(task_output_space[train_name])

                # Learn
                agent.learn_batch(train_loader, test_loader, task_name=train_name)

                # Evaluate (skip this step when training single or independent tasks)
                if not (args.single_tasks or args.ind_models):
                    final_accs = OrderedDict()
                    acc_table[train_name] = OrderedDict()
                    for j in range(i + 1):
                        val_name = task_names[j]
                        print('validation split name:', val_name)
                        val_data = val_dataset_splits[val_name] if not args.eval_on_train_set else train_dataset_splits[val_name]
                        test_loader = torch.utils.data.DataLoader(val_data,
                                                                  batch_size=args.batch_size,
                                                                  shuffle=False,
                                                                  num_workers=args.workers)
                        acc_table[val_name][train_name] = agent.validation(test_loader, val_name)
                        final_accs[val_name] = acc_table[val_name][train_name].avg
                        print(f'/CumAcc/Task{val_name}, {acc_table[val_name][train_name].avg}, {i}')
                        writer.add_scalar('/CumAcc/Task' + val_name,
                                          acc_table[val_name][train_name].avg, i)
                        # writer.add_scalar('/CumLoss/Task' + val_name,
                        #                   loss_table[val_name][train_name].avg, i)
                elif args.single_tasks:
                    val_name = task_names[i]
                    final_accs[train_name] = agent.validation(test_loader, val_name).avg
                else:
                    val_name = task_names[i]
                    final_accs[train_name] = agent.validation(test_loader, val_name)[0].avg
                agent.save_model()

        print(final_accs)
        # Collect the channels used, the per-task accuracies and the compression
        # ratio of the final model, and compute the model size.
        avg_acc = sum(list(final_accs.values())) / len(final_accs)
        model_size = agent.mode_comp(agent.chann_used) if not args.ind_models else len(agents_lst) * agent.mode_comp()
        save_dict['channels'] = agent.chann_used if not args.ind_models else []
        save_dict['acc'] = final_accs
        save_dict['avg_acc'] = avg_acc
        print('Average accuracy is', avg_acc)
        print('Model size is', model_size)
        save_dict['all_rank'] = agent.all_rank if not args.ind_models else []
        save_dict['model_size'] = model_size

    return save_dict, task_names
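
# Hypothetical driver (illustrative only): run() above takes an explicit random seed
# and returns a save_dict containing 'avg_acc', so averaging over several seeds is a
# natural way to report results. The seed list and the use of numpy here are
# assumptions, not part of the original script.
def _example_multi_seed(args, seeds=(0, 1, 2)):
    import numpy as np
    avg_accs = []
    for seed in seeds:
        save_dict, _task_names = run(args, seed)
        avg_accs.append(save_dict['avg_acc'])
    print('mean avg_acc over seeds:', np.mean(avg_accs), '+/-', np.std(avg_accs))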
def run(args):
    if not os.path.exists('outputs'):
        os.mkdir('outputs')

    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](args.dataroot, args.train_aug)
    if args.n_permutation > 0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset, val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset, val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)

    # Prepare the Agent (model)
    agent_config = {'lr': args.lr,
                    'momentum': args.momentum,
                    'weight_decay': args.weight_decay,
                    'schedule': args.schedule,
                    'model_type': args.model_type,
                    'model_name': args.model_name,
                    'model_weights': args.model_weights,
                    'out_dim': {'All': args.force_out_dim} if args.force_out_dim > 0 else task_output_space,
                    'optimizer': args.optimizer,
                    'print_freq': args.print_freq,
                    'gpuid': args.gpuid,
                    'reg_coef': args.reg_coef,
                    'exp_name': args.exp_name,
                    'warmup': args.warm_up,
                    'nesterov': args.nesterov,
                    'run_num': args.run_num,
                    'freeze_core': args.freeze_core,
                    'reset_opt': args.reset_opt,
                    'noise_': args.noise_,
                    'add_extra_last': args.add_extra_last,
                    'batch_size': args.batch_size,
                    'reg': args.reg}
    agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config)
    print(agent.model)
    print('#parameter of model:', agent.count_parameter())

    # Decide split ordering
    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:', task_names)
    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)

    acc_table = OrderedDict()
    loss_table = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(val_dataset_splits.values())
        train_loader = torch.utils.data.DataLoader(train_dataset_all,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset_all,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers)
        epochs = args.epochs[-1]
        agent.learn_batch(train_loader, val_loader, [0, epochs])
        acc_table['All'] = {}
        loss_table['All'] = {}
        acc_table['All']['All'], loss_table['All']['All'] = agent.validation(val_loader)
    else:  # Incremental learning
        # Feed data to agent and evaluate agent's performance (only the first task is trained here)
        for i in range(1):
            train_name = task_names[i]
            print('======================', train_name, '=======================')
            train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name],
                                                       batch_size=args.batch_size,
                                                       shuffle=True,
                                                       num_workers=args.workers)
            val_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],
                                                     batch_size=args.batch_size,
                                                     shuffle=False,
                                                     num_workers=args.workers)

            if args.incremental_class:
                agent.add_valid_output_dim(task_output_space[train_name])

            # Learn
            epochs = args.epochs[i] if len(args.epochs) - 1 else args.epochs[0]
            # Split the epoch budget into chunks of args.old_val_freq and validate after
            # each chunk (e.g. 80 epochs with old_val_freq=10 validates after epoch
            # 10, 20, ..., 80). This helps in understanding how degradation happens.
            for epoch_10 in range(int(epochs / args.old_val_freq)):
                agent.learn_batch(train_loader, val_loader,
                                  epochs=[epoch_10 * args.old_val_freq, (epoch_10 + 1) * args.old_val_freq],
                                  task_n=train_name)

                # Evaluate
                acc_table[train_name] = OrderedDict()
                loss_table[train_name] = OrderedDict()
                writer = SummaryWriter(log_dir="runs/" + agent.exp_name)
                for j in range(i + 1):
                    if i != j:
                        continue
                    val_name = task_names[j]
                    print('validation split name:', val_name)
                    val_data = val_dataset_splits[val_name] if not args.eval_on_train_set else train_dataset_splits[val_name]
                    val_loader = torch.utils.data.DataLoader(val_data,
                                                             batch_size=args.batch_size,
                                                             shuffle=False,
                                                             num_workers=args.workers)
                    acc_table[val_name][train_name], loss_table[val_name][train_name] = agent.validation(val_loader, val_name)

                    # TensorBoard logging
                    print('logging for Task {} while training {}'.format(val_name, train_name))
                    print('logging', int(train_name) + (epoch_10 + 1) * 0.1)
                    writer.add_scalar('Run' + str(args.run_num) + '/CumAcc/Task' + val_name,
                                      acc_table[val_name][train_name].avg,
                                      float(int(train_name)) * 100 + (epoch_10 + 1) * args.old_val_freq)
                    writer.add_scalar('Run' + str(args.run_num) + '/CumLoss/Task' + val_name,
                                      loss_table[val_name][train_name].avg,
                                      int(train_name) * 100 + (epoch_10 + 1) * args.old_val_freq)
                writer.close()

    return acc_table, task_names
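
# Small standalone illustration of the epoch chunking used above: with epochs=80 and
# old_val_freq=10 the loop hands learn_batch() the ranges [0, 10], [10, 20], ...,
# [70, 80], validating after each chunk. The concrete numbers are examples only.
def _example_epoch_chunks(epochs=80, old_val_freq=10):
    chunks = [[k * old_val_freq, (k + 1) * old_val_freq]
              for k in range(int(epochs / old_val_freq))]
    print(chunks)  # [[0, 10], [10, 20], ..., [70, 80]]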
def run(args):
    if not os.path.exists('outputs'):
        os.mkdir('outputs')

    # Prepare dataloaders
    # train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](args.dataroot, args.train_aug)
    train_dataset, val_dataset = factory('dataloaders', 'base', args.dataset)(args.dataroot, args.train_aug)
    if args.n_permutation > 0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset, val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset, val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)

    # Prepare the Agent (model)
    dataset_name = args.dataset + \
        '_{}'.format(args.first_split_size) + \
        '_{}'.format(args.other_split_size)
    agent_config = {
        'model_lr': args.model_lr,
        'momentum': args.momentum,
        'model_weight_decay': args.model_weight_decay,
        'schedule': args.schedule,
        'model_type': args.model_type,
        'model_name': args.model_name,
        'model_weights': args.model_weights,
        'out_dim': {'All': args.force_out_dim} if args.force_out_dim > 0 else task_output_space,
        'model_optimizer': args.model_optimizer,
        'print_freq': args.print_freq,
        'gpu': True if args.gpuid[0] >= 0 else False,
        'with_head': args.with_head,
        'reset_model_opt': args.reset_model_opt,
        'reg_coef': args.reg_coef,
        'head_lr': args.head_lr,
        'svd_lr': args.svd_lr,
        'bn_lr': args.bn_lr,
        'svd_thres': args.svd_thres,
        'gamma': args.gamma,
        'dataset_name': dataset_name
    }
    # agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config)
    agent = factory('svd_agent', args.agent_type, args.agent_name)(agent_config)

    # Decide split ordering
    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:', task_names)
    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)
    # task_names = ['2', '1', '3', '4', '5']

    acc_table = OrderedDict()
    acc_table_train = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(val_dataset_splits.values())
        train_loader = torch.utils.data.DataLoader(train_dataset_all,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset_all,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers)
        agent.learn_batch(train_loader, val_loader)
        acc_table['All'] = {}
        acc_table['All']['All'] = agent.validation(val_loader)
    else:  # Incremental learning
        # Feed data to agent and evaluate agent's performance
        for i in range(len(task_names)):
            train_name = task_names[i]
            print('======================', train_name, '=======================')
            train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name],
                                                       batch_size=args.batch_size,
                                                       shuffle=True,
                                                       num_workers=args.workers)
            val_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],
                                                     batch_size=args.batch_size,
                                                     shuffle=False,
                                                     num_workers=args.workers)

            if args.incremental_class:
                agent.add_valid_output_dim(task_output_space[train_name])

            # Learn
            agent.train_task(train_loader, val_loader)
            torch.cuda.empty_cache()

            # Evaluate on both the validation and the training split of every seen task
            acc_table[train_name] = OrderedDict()
            acc_table_train[train_name] = OrderedDict()
            for j in range(i + 1):
                val_name = task_names[j]
                print('validation split name:', val_name)
                val_data = val_dataset_splits[val_name] if not args.eval_on_train_set else train_dataset_splits[val_name]
                val_loader = torch.utils.data.DataLoader(val_data,
                                                         batch_size=args.batch_size,
                                                         shuffle=False,
                                                         num_workers=args.workers)
                acc_table[val_name][train_name] = agent.validation(val_loader)
                print("**************************************************")
                print('training split name:', val_name)
                train_data = train_dataset_splits[val_name]
                train_loader = torch.utils.data.DataLoader(train_data,
                                                           batch_size=args.batch_size,
                                                           shuffle=False,
                                                           num_workers=args.workers)
                acc_table_train[val_name][train_name] = agent.validation(train_loader)
                print("**************************************************")

    return acc_table, task_names
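
# Hypothetical post-processing (not part of the original scripts): given the
# acc_table[eval_task][trained_after_task] layout built above, the average final
# accuracy and a simple forgetting measure (best accuracy ever recorded on a task
# minus its accuracy after the last task) can be computed as below. Entries may be
# floats or meter-like objects with `.avg`, hence the getattr fallback; the sketch
# assumes the incremental-learning branch was used.
def _summarize_acc_table(acc_table, task_names):
    last = task_names[-1]
    val = lambda e: getattr(e, 'avg', e)
    final = [val(acc_table[t][last]) for t in task_names if last in acc_table.get(t, {})]
    avg_final_acc = sum(final) / len(final)
    forgetting = []
    for t in task_names[:-1]:
        history = [val(acc_table[t][s]) for s in task_names if s in acc_table.get(t, {})]
        forgetting.append(max(history) - val(acc_table[t][last]))
    print('average final accuracy:', avg_final_acc)
    print('average forgetting:', sum(forgetting) / max(1, len(forgetting)))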
def prepare_dataloaders(args):
    # Prepare dataloaders
    Dataset = dataloaders.base.__dict__[args.dataset]

    # Split CUB
    if args.is_split_cub:
        print("running split -------------")
        from dataloaders.cub import CUB
        Dataset = CUB
        if args.train_aug:
            print("train aug not supported for cub")
            return
        train_dataset, val_dataset = Dataset(args.dataroot)
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset, val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)
        n_tasks = len(task_output_space.items())

    # Permuted MNIST
    elif args.n_permutation > 0:
        # TODO: check subset_size
        train_dataset, val_dataset = Dataset(args.dataroot,
                                             args.train_aug,
                                             angle=0,
                                             subset_size=args.subset_size)
        print("Working with permutations :)")
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset, val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
        n_tasks = args.n_permutation

    # Rotated MNIST
    elif args.n_rotate > 0 or len(args.rotations) > 0:
        # TODO: check subset_size
        train_dataset_splits, val_dataset_splits, task_output_space = RotatedGen(
            Dataset=Dataset,
            dataroot=args.dataroot,
            train_aug=args.train_aug,
            n_rotate=args.n_rotate,
            rotate_step=args.rotate_step,
            remap_class=not args.no_class_remap,
            rotations=args.rotations,
            subset_size=args.subset_size)
        n_tasks = len(task_output_space.items())

    # Split MNIST
    else:
        print("running split -------------")
        # TODO: check subset_size
        train_dataset, val_dataset = Dataset(args.dataroot,
                                             args.train_aug,
                                             angle=0,
                                             subset_size=args.subset_size)
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset, val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)
        n_tasks = len(task_output_space.items())

    print(f"task_output_space {task_output_space}")
    return task_output_space, n_tasks, train_dataset_splits, val_dataset_splits
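
# Hypothetical usage sketch: the rotated-MNIST branch above is the only one that
# consumes n_rotate / rotate_step / rotations, so a minimal namespace exercising it
# could look like the following. All concrete values (4 rotations of 45 degrees,
# the dataset name, and the subset size, whose convention of count vs. fraction
# depends on the dataloader) are illustrative assumptions.
def _example_prepare_rotated():
    from argparse import Namespace
    example_args = Namespace(dataset='MNIST', dataroot='data', train_aug=False,
                             is_split_cub=False, n_permutation=0,
                             n_rotate=4, rotate_step=45, rotations=[],
                             subset_size=10000, first_split_size=2, other_split_size=2,
                             rand_split=False, no_class_remap=False)
    task_output_space, n_tasks, train_splits, val_splits = prepare_dataloaders(example_args)
    print(n_tasks, task_output_space)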
from tqdm import tqdm

dataroot = "/tmp/datasets"
first_split_size = 5
other_split_size = 5
rand_split = False
no_class_remap = False

# from dataloaders.base import MNIST
# Dataset = MNIST
Dataset = CUB

train_dataset, val_dataset = Dataset(dataroot)
train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
    train_dataset, val_dataset,
    first_split_sz=first_split_size,
    other_split_sz=other_split_size,
    rand_split=rand_split,
    remap_class=not no_class_remap)
# __getitem__ returns (img, target, self.name) through AppendName
n_tasks = len(task_output_space.items())
task_names = sorted(list(task_output_space.keys()), key=int)

batch_size = 32
workers = 0
for i in tqdm(range(len(task_names)), "task"):
    task_name = task_names[i]
    print('======================', task_name, '=======================')
    train_loader = torch.utils.data.DataLoader(train_dataset_splits[task_name],
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers)