def __init__(self, base, optimizer, input_size, cfg, goal):
    super().__init__(base, optimizer, input_size, cfg, goal)
    # self.reservoir = {'x': np.zeros((self.mem_limit, input_size)),
    #                   'y': [None] * self.mem_limit,
    #                   'y_extra': [None] * self.mem_limit,
    #                   'x_origin': np.zeros((self.mem_limit, input_size)),
    #                   'x_edit_state': [None] * self.mem_limit,
    #                   'loss_stats': [None] * self.mem_limit,
    #                   'loss_stat_steps': [None] * self.mem_limit,
    #                   'forget': [None] * self.mem_limit,
    #                   'support': [None] * self.mem_limit
    #                   }
    self.itf_cnt = 0
    self.total_cnt = 0
    self.grad_iter = get_config_attr(cfg, 'EXTERNAL.OCL.GRAD_ITER', default=1)
    self.grad_stride = get_config_attr(cfg, 'EXTERNAL.OCL.GRAD_STRIDE', default=10.)
    self.edit_decay = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_DECAY', default=0.)
    self.no_write_back = get_config_attr(cfg, 'EXTERNAL.OCL.NO_WRITE_BACK', default=0)
    self.reservoir['age'] = np.zeros(self.mem_limit)
def get_rotated_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128,
                                 task_num=10, *args, **kwargs):
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    global _rmnist_loaders
    if not _rmnist_loaders:
        data = get_rotated_mnist(d)
        # train_loader, val_loader, test_loader = [CLDataLoader(elem, batch_size, train=t)
        #                                          for elem, t in zip(data, [True, False, False])]
        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t)
                                                 for elem, t in zip(data, [True, False, False])]
        _rmnist_loaders = train_loader, val_loader, test_loader
    else:
        train_loader, val_loader, test_loader = _rmnist_loaders
    if split == 'train':
        return train_loader[filter_obj[0]]
    elif split == 'val':
        return val_loader[filter_obj[0]]
    elif split == 'test':
        return test_loader[filter_obj[0]]
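# Minimal usage sketch (illustrative, not part of the original file): fetching the training
# loader for one rotated-MNIST task, assuming `cfg` is a CfgNode carrying the EXTERNAL.OCL
# options read above.
#
#   task0_loader = get_rotated_mnist_dataloader(cfg, split='train', filter_obj=[0],
#                                               batch_size=128)
#   for x, y in task0_loader:
#       ...  # x: rotated digit images for task 0, y: labels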
def main(args):
    if '%id' in args.name:
        exp_name = args.name.replace('%id', get_exp_id())
    else:
        exp_name = args.name
    combined_cfg = CfgNode(new_allowed=True)
    combined_cfg.merge_from_file(args.config)
    cfg = combined_cfg
    cfg.EXTERNAL.EXPERIMENT_NAME = exp_name
    cfg.SEED = args.seed
    cfg.DEBUG = args.debug
    set_cfg_from_args(args, cfg)
    output_dir = get_config_attr(cfg, 'OUTPUT_DIR', default='')
    if output_dir == '.':
        output_dir = 'runs/'
    cfg.OUTPUT_DIR = os.path.join(output_dir,
                                  '{}_{}'.format(cfg.EXTERNAL.EXPERIMENT_NAME, cfg.SEED))
    cfg.MODE = 'train'
    # cfg.freeze()
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if distributed:
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    output_dir = cfg.OUTPUT_DIR
    # save overloaded model config in the output directory
    model = train(cfg, local_rank, distributed, tune=args.tune)
    output_args_path = os.path.join(output_dir, 'args.txt')
    wf = open(output_args_path, 'w')
    wf.write(' '.join(sys.argv))
    wf.close()
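# Hypothetical command-line sketch (the argparse setup lives elsewhere in the repository);
# the flag names mirror the attributes accessed above (args.config, args.name, args.seed,
# args.debug, args.tune), and the script name and flag spellings are assumptions.
#
#   python main.py --config configs/split_cifar.yaml --name er_%id --seed 0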
def get_split_mini_imagenet_dataloader(cfg, split='train', filter_obj=None, batch_size=128,
                                       *args, **kwargs):
    global _cache_mini_imagenet
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    if not _cache_mini_imagenet:
        data = get_miniimagenet(d)
        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t)
                                                 for elem, t in zip(data, [True, False, False])]
        _cache_mini_imagenet = train_loader, val_loader, test_loader
    train_loader, val_loader, test_loader = _cache_mini_imagenet
    if split == 'train':
        return train_loader[filter_obj[0]]
    elif split == 'val':
        return val_loader[filter_obj[0]]
    elif split == 'test':
        return test_loader[filter_obj[0]]
def __init__(self, base, optimizer, input_size, cfg, goal, **kwargs):
    super().__init__()
    self.net = base
    self.optimizer = optimizer
    self.input_size = input_size
    self.cfg = cfg
    self.goal = goal
    if 'caption' in self.goal:
        self.clip_grad = True
    self.use_image_feat = get_config_attr(self.cfg, 'EXTERNAL.USE_IMAGE_FEAT', default=0)
    self.spatial_feat_shape = (2048, 7, 7)
    self.bbox_feat_shape = (100, 2048)
    self.bbox_shape = (100, 4)
    self.rev_optim = None
    if hasattr(self.net, 'rev_update_modules'):
        self.rev_optim = optim.Adam(
            lr=cfg.SOLVER.BASE_LR,
            betas=(0.9, 0.999),
            params=self.net.rev_update_modules.parameters())
def get_split_cifar100_dataloader(cfg, split='train', filter_obj=None, batch_size=128,
                                  *args, **kwargs):
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    global _cache_cifar100
    if not _cache_cifar100:
        data = get_split_cifar100(d, cfg)  # ds_cifar10and100(batch_size=batch_size, num_workers=0, cfg=cfg, **kwargs)
        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t)
                                                 for elem, t in zip(data, [True, False, False])]
        _cache_cifar100 = train_loader, val_loader, test_loader
    train_loader, val_loader, test_loader = _cache_cifar100
    if split == 'train':
        return train_loader[filter_obj[0]]
    elif split == 'val':
        return val_loader[filter_obj[0]]
    elif split == 'test':
        return test_loader[filter_obj[0]]
def train(cfg, local_rank, distributed, tune=False):
    is_ocl = hasattr(cfg.EXTERNAL.OCL, 'ALGO') and cfg.EXTERNAL.OCL.ALGO != 'PLAIN'
    task_incremental = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_INCREMENTAL', default=False)
    cfg.TUNE = tune
    algo = cfg.EXTERNAL.OCL.ALGO
    if hasattr(cfg, 'MNIST'):
        if cfg.MNIST.TASK == 'split':
            goal = 'split_mnist'
        elif cfg.MNIST.TASK == 'permute':
            goal = 'permute_mnist'
        elif cfg.MNIST.TASK == 'rotate':
            goal = 'rotated_mnist'
    if hasattr(cfg, 'CIFAR'):
        goal = 'split_cifar'
        if get_config_attr(cfg, 'CIFAR.DATASET', default="") == 'CIFAR100':
            goal = 'split_cifar100'
        if get_config_attr(cfg, 'CIFAR.MINI_IMAGENET', default=0):
            goal = 'split_mini_imagenet'

    if hasattr(cfg, 'MNIST'):
        num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
        num_of_classes = 10 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
        base_model = mnist_simple_net_400width_classlearning_1024input_10cls_1ds(
            num_of_datasets=num_of_datasets,
            num_of_classes=num_of_classes,
            task_incremental=task_incremental)
        base_model.cfg = cfg
    elif hasattr(cfg, 'CIFAR'):
        if goal == 'split_cifar':
            num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
            num_of_classes = 10 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
        elif goal == 'split_cifar100':
            num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
            num_of_classes = 100 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
        elif goal == 'split_mini_imagenet':
            num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
            num_of_classes = 100 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
        base_model = ResNetClassifier(cfg, depth='18', mlp=1, ignore_index=-100,
                                      num_of_datasets=num_of_datasets,
                                      num_of_classes=num_of_classes,
                                      task_incremental=task_incremental, goal=goal)
        base_model.cfg = cfg
    else:
        base_model = ResNetClassifier(cfg)

    device = torch.device(cfg.MODEL.DEVICE)
    base_model.to(device)

    if cfg.EXTERNAL.OPTIMIZER.ADAM:
        optimizer = torch.optim.Adam(
            filter(lambda x: x.requires_grad, base_model.parameters()),
            lr=cfg.SOLVER.BASE_LR, betas=(0.9, 0.999)
        )
    else:
        optimizer = torch.optim.SGD(
            filter(lambda x: x.requires_grad, base_model.parameters()),
            lr=cfg.SOLVER.BASE_LR
        )

    # algorithm specific model wrapper
    x_size = 3 * 2 * base_model.cfg.EXTERNAL.IMAGE_SIZE ** 2 if goal == 'classification' else \
        3 * base_model.cfg.EXTERNAL.IMAGE_SIZE ** 2
    if goal == 'split_mnist' or goal == 'permute_mnist' or goal == 'rotated_mnist':
        x_size = 28 * 28
    if goal == 'split_cifar' or goal == 'split_cifar100':
        x_size = 3 * 32 * 32
    if goal == 'split_mini_imagenet':
        x_size = 3 * 84 * 84

    if algo == 'ER':
        model = ExperienceReplay(base_model, optimizer, x_size, base_model.cfg, goal)
    elif algo == 'VERX':
        model = ExperienceEvolveApprox(base_model, optimizer, x_size, base_model.cfg, goal)
    elif algo == 'AGEM':
        model = AGEM(base_model, optimizer, x_size, base_model.cfg, goal)
    elif algo == 'naive':
        model = NaiveWrapper(base_model, optimizer, x_size, base_model.cfg, goal)
    model.to(device)

    use_mixed_precision = cfg.DTYPE == "float16"
    arguments = {"iteration": 0, "global_step": 0, "epoch": 0}
    output_dir = cfg.OUTPUT_DIR
    writer = None
    epoch_num = 1
    for e in range(epoch_num):
        print("epoch")
        arguments['iteration'] = 0
        epoch = arguments['epoch']
        if goal == 'split_mnist' or goal == 'permute_mnist' or goal == 'rotated_mnist':
            ocl_train_mnist(model, optimizer, None, device, arguments, writer, epoch, goal, tune=tune)
        elif goal == 'split_cifar' or goal == 'split_cifar100' or goal == 'split_mini_imagenet':
            ocl_train_cifar(model, optimizer, None, device, arguments, writer, epoch, goal, tune=tune)
        else:
            raise NotImplementedError
        arguments['epoch'] += 1
        with open(os.path.join(output_dir, 'model.bin'), 'wb') as wf:
            torch.save(model.state_dict(), wf)
        # else:
        #     break
    if is_ocl and hasattr(model, 'dump_reservoir') and args.dump_reservoir:
        model.dump_reservoir(os.path.join(cfg.OUTPUT_DIR, 'mem_dump.pkl'),
                             verbose=args.dump_reservoir_verbose)
    return model
def __init__(self, base, optimizer, input_size, cfg, goal):
    super(AGEM, self).__init__(base, optimizer, input_size, cfg, goal)
    # self.grads = self.grads.cuda()
    self.violation_count = 0
    self.agem_k = get_config_attr(cfg, 'EXTERNAL.OCL.AGEM_K', default=256)
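# Illustrative sketch (not the repository's implementation) of the core A-GEM step that
# self.agem_k parameterizes: compute a reference gradient g_ref on a memory batch of up to
# agem_k examples and, if the current gradient g conflicts with it (g . g_ref < 0), project
# g so that, to first order, the update no longer increases the memory-batch loss.
# Flat 1-D gradient vectors are assumed.
import torch

def agem_project(g, g_ref, eps=1e-12):
    """Return g projected onto the half-space of directions not conflicting with g_ref."""
    dot = torch.dot(g, g_ref)
    if dot < 0:
        g = g - (dot / (torch.dot(g_ref, g_ref) + eps)) * g_ref
    return g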
def __init__(self, base, optimizer, input_size, cfg, goal):
    super().__init__(base, optimizer, input_size, cfg, goal)
    self.net = base
    self.optimizer = optimizer
    self.mem_limit = cfg.EXTERNAL.REPLAY.MEM_LIMIT
    self.mem_bs = cfg.EXTERNAL.REPLAY.MEM_BS
    self.input_size = input_size
    self.reservoir, self.example_seen = None, None
    self.reset_mem()
    self.mem_occupied = {}
    self.seen_tasks = []
    self.balanced = False
    self.policy = get_config_attr(cfg, 'EXTERNAL.OCL.POLICY', default='reservoir', totype=str)
    self.mir_k = get_config_attr(cfg, 'EXTERNAL.REPLAY.MIR_K', default=10, totype=int)
    self.mir = get_config_attr(cfg, 'EXTERNAL.OCL.MIR', default=0, totype=int)
    self.mir_agg = get_config_attr(cfg, 'EXTERNAL.OCL.MIR_AGG', default='avg', totype=str)
    self.concat_replay = get_config_attr(cfg, 'EXTERNAL.OCL.CONCAT', default=0, totype=int)
    self.separate_replay = get_config_attr(cfg, 'EXTERNAL.OCL.SEPARATE', default=0, totype=int)
    self.mem_augment = get_config_attr(cfg, 'EXTERNAL.OCL.MEM_AUG', default=0, totype=int)
    self.legacy_aug = get_config_attr(cfg, 'EXTERNAL.OCL.LEGACY_AUG', default=0, totype=int)
    self.use_hflip_aug = get_config_attr(cfg, 'EXTERNAL.OCL.USE_HFLIP_AUG', default=1, totype=int)
    self.padding_aug = get_config_attr(cfg, 'EXTERNAL.OCL.PADDING_AUG', default=-1, totype=int)
    self.rot_aug = get_config_attr(cfg, 'EXTERNAL.OCL.ROT_AUG', default=-1, totype=int)
    self.lb_reservoir = get_config_attr(cfg, 'EXTERNAL.OCL.LB_RESERVOIR', default=0)
    self.cfg = cfg
    self.grad_dims = []
    for param in self.parameters():
        self.grad_dims.append(param.data.numel())
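# Illustrative sketch (not the repository's code) of the standard reservoir-sampling write
# policy that the default POLICY='reservoir' above refers to: the t-th observed example is
# written with probability mem_limit / t into a uniformly chosen slot, so every example seen
# so far remains in memory with equal probability. `rng` is assumed to be a numpy
# RandomState like the one used by sample_mem_batch below.
def reservoir_write(mem_x, mem_y, x, y, example_seen, mem_limit, rng):
    if example_seen < mem_limit:
        idx = example_seen                       # memory not full yet: always write
    else:
        idx = rng.randint(0, example_seen + 1)   # uniform draw over [0, example_seen]
        if idx >= mem_limit:
            return -1                            # drop the incoming example
    mem_x[idx] = x
    mem_y[idx] = y
    return idx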
def sample_mem_batch(self, device, return_indices=False, k=None, seed=1,
                     mir=False, input_x=None, input_y=None, input_task_ids=None,
                     mir_k=0, skip_task=None, mir_least=False):
    random_state = self.get_random(seed)
    if k is None:
        k = self.mem_bs
    if not self.balanced:
        # reservoir
        n_max = min(self.mem_limit, self.example_seen)
        available_indices = [_ for _ in range(n_max)]
        if skip_task is not None and get_config_attr(self.cfg, 'EXTERNAL.REPLAY.FILTER_SELF',
                                                     default=0, mute=True):
            available_indices = list(
                filter(lambda x: self.reservoir['y_extra'][x] != skip_task, available_indices))
        if not available_indices:
            if return_indices:
                return None, None, None
            else:
                return None, None, None
        elif len(available_indices) < k:
            indices = np.arange(n_max)
        else:
            indices = random_state.choice(available_indices, k, replace=False)
    else:
        available_index = self.get_available_index()
        if len(available_index) == 0:
            if return_indices:
                return None, None, None
            else:
                return None, None, None
        elif len(available_index) < k:
            indices = np.array(available_index)
        else:
            indices = random_state.choice(available_index, k, replace=False)
    x = self.reservoir['x'][indices]
    x = torch.from_numpy(x).to(device).float()
    y = index_select(self.reservoir['y'], indices, device)  # [ [...], [...] ]
    y_extra = index_select(self.reservoir['y_extra'], indices, device)
    if type(y[0]) not in [list, tuple]:
        y_pad = concat_with_padding(y)
    else:
        y_pad = [torch.stack(_).to(device) for _ in zip(*y)]
    y_extra = concat_with_padding(y_extra)
    if mir:
        x, y_pad, y_extra, indices = self.decide_mir_mem(
            input_x, input_y, input_task_ids, mir_k, x, y, y_extra, indices, mir_least)
    if not return_indices:
        return x, y_pad, y_extra
    else:
        return (x, indices), y_pad, y_extra
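# Hypothetical usage sketch (not the repository's observe()): how sample_mem_batch is
# typically consumed in an experience-replay update, assuming the wrapped network exposes a
# loss(x, y, task_ids) helper. All names below are illustrative only.
#
#   mem_x, mem_y, mem_task_ids = self.sample_mem_batch(x.device, skip_task=int(task_ids[0]))
#   loss = self.net.loss(x, y, task_ids)
#   if mem_x is not None:
#       loss = loss + self.net.loss(mem_x, mem_y, mem_task_ids)
#   self.optimizer.zero_grad()
#   loss.backward()
#   self.optimizer.step()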
def ocl_train_mnist(model, optimizer, checkpointer, device, arguments, writer, epoch,
                    goal='split', tune=False):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training @ epoch {:02d}".format(arguments['epoch']))
    model.train()
    cfg = model.cfg
    pbar = tqdm(position=0, desc='GPU: 0')
    num_instances = cfg.MNIST.INSTANCE_NUM
    if goal == 'split_mnist':
        task_num = 5
        loader_func = get_split_mnist_dataloader
    elif goal == 'permute_mnist':
        task_num = 10
        loader_func = get_permute_mnist_dataloader
    elif goal == 'rotated_mnist':
        task_num = 20
        loader_func = get_rotated_mnist_dataloader
    else:
        raise ValueError
    if tune:
        task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int, default=3)
    num_epoch = get_config_attr(cfg, 'EXTERNAL.EPOCH', totype=int, default=1)
    total_step = task_num * 1000
    base_lr = get_config_attr(cfg, 'SOLVER.BASE_LR', totype=float)
    # whether iid
    iid = not get_config_attr(cfg, 'EXTERNAL.OCL.ACTIVATED', totype=bool)
    do_exp_lr_decay = get_config_attr(cfg, 'EXTERNAL.OCL.EXP_LR_DECAY', 0)
    all_accs = []
    best_avg_accs = []
    step = 0
    for task_id in range(task_num):
        if iid:
            if task_id != 0:
                break
            data_loaders = [
                loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE,
                            max_instance=num_instances)
                for task_id in range(task_num)
            ]
            data_loader = DataLoader(IIDDataset(data_loaders), batch_size=cfg.EXTERNAL.BATCH_SIZE)
            num_instances *= task_num
        else:
            data_loader = loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE,
                                      max_instance=num_instances)
        best_avg_acc = -1
        # model.net.set_task(task_id)  # choose the classifier head if the model supports it
        for epoch in range(num_epoch):
            seen = 0
            for i, data in enumerate(data_loader):
                if seen >= num_instances:
                    break
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                task_ids = torch.LongTensor([task_id] * labels.size(0)).to(inputs.device)
                inputs = inputs.flatten(1)
                model.observe(inputs, labels, task_ids=task_ids)
                step += 1
                if do_exp_lr_decay:
                    exp_decay_lr(optimizer, step, total_step, base_lr)
                seen += labels.size(0)
            # run evaluation
            with torch.no_grad():
                if iid:
                    accs, _, avg_acc = inference_mnist(model, task_num, loader_func, device, tune=tune)
                else:
                    accs, _, avg_acc = inference_mnist(model, task_id + 1, loader_func, device, tune=tune)
            logger.info('Epoch {}\tTask {}\tAcc {}'.format(epoch, task_id, avg_acc))
            for i, acc in enumerate(accs):
                logger.info('::Val Task {}\t Acc {}'.format(i, acc))
            all_accs.append(accs)
            if avg_acc > best_avg_acc:
                best_avg_acc = avg_acc
            else:
                break
        best_avg_accs.append(best_avg_acc)
    file_name = 'result.json' if not tune else 'result_tune_k{}.json'.format(task_num)
    result_file = open(os.path.join(cfg.OUTPUT_DIR, file_name), 'w')
    json.dump({'all_accs': all_accs, 'avg_acc': avg_acc}, result_file, indent=4)
    result_file.close()
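# exp_decay_lr is called above but not defined in this excerpt. A minimal sketch of one
# plausible implementation, assuming an exponential schedule that decays base_lr by a fixed
# factor `final_ratio` over total_step steps; the codebase's actual schedule may differ.
def exp_decay_lr(optimizer, step, total_step, base_lr, final_ratio=0.1):
    lr = base_lr * (final_ratio ** (min(step, total_step) / float(total_step)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr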
def ocl_train_cifar(model, optimizer, checkpointer, device, arguments, writer, epoch,
                    goal='split_cifar', tune=False):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training @ epoch {:02d}".format(arguments['epoch']))
    model.train()
    cfg = model.cfg
    num_epoch = cfg.CIFAR.EPOCH
    if goal == 'split_cifar':
        loader_func = get_split_cifar_dataloader
        total_step = 4750
    elif goal == 'split_cifar100':
        loader_func = get_split_cifar100_dataloader
        total_step = 25000
    else:
        loader_func = get_split_mini_imagenet_dataloader
        total_step = 22500
    max_instance = cfg.CIFAR.INSTANCE_NUM if hasattr(cfg.CIFAR, 'INSTANCE_NUM') else 1e10
    if not tune:
        task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
    else:
        task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
    do_exp_lr_decay = get_config_attr(cfg, 'EXTERNAL.OCL.EXP_LR_DECAY', 0)
    base_lr = get_config_attr(cfg, 'SOLVER.BASE_LR', totype=float)
    step = 0
    num_epoch = get_config_attr(cfg, 'EXTERNAL.EPOCH', totype=int, default=1)
    all_accs = []
    best_avg_accs = []
    iid = not get_config_attr(cfg, 'EXTERNAL.OCL.ACTIVATED', totype=bool)
    for task_id in range(task_num):
        if iid:
            if task_id != 0:
                break
            data_loaders = [
                loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE,
                            max_instance=max_instance)
                for task_id in range(task_num)
            ]
            data_loader = DataLoader(IIDDataset(data_loaders), batch_size=cfg.EXTERNAL.BATCH_SIZE)
            max_instance *= task_num
        else:
            data_loader = loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE,
                                      max_instance=max_instance)
        pbar = tqdm(position=0, desc='GPU: 0', total=len(data_loader))
        best_avg_acc = -1
        for epoch in range(num_epoch):
            seen = 0
            for i, data in enumerate(data_loader):
                if seen >= max_instance:
                    break
                pbar.update(1)
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                inputs = inputs.flatten(1)
                task_ids = torch.LongTensor([task_id] * labels.size(0)).to(inputs.device)
                model.observe(inputs, labels, task_ids)
                seen += inputs.size(0)
                if do_exp_lr_decay:
                    exp_decay_lr(optimizer, step, total_step, base_lr)
                step += 1
            # run evaluation
            with torch.no_grad():
                if iid:
                    accs, _, avg_acc = inference_cifar(model, task_num, loader_func, device, goal, tune=tune)
                else:
                    accs, _, avg_acc = inference_cifar(model, task_id + 1, loader_func, device, goal, tune=tune)
            logger.info('Epoch {}\tTask {}\tAcc {}'.format(epoch, task_id, avg_acc))
            for i, acc in enumerate(accs):
                logger.info('::Val Task {}\t Acc {}'.format(i, acc))
            all_accs.append(accs)
            if avg_acc > best_avg_acc:
                best_avg_acc = avg_acc
            else:
                break
        best_avg_accs.append(best_avg_acc)
    file_name = 'result.json' if not tune else 'result_tune_k{}.json'.format(task_num)
    result_file = open(os.path.join(cfg.OUTPUT_DIR, file_name), 'w')
    json.dump({'all_accs': all_accs, 'avg_acc': avg_acc, 'best_avg_accs': best_avg_accs},
              result_file, indent=4)
    result_file.close()
    return best_avg_accs