Code Example #1
    def __init__(self, base, optimizer, input_size, cfg, goal):
        super().__init__(base, optimizer, input_size, cfg, goal)
        # The parent class allocates self.reservoir with per-slot entries such
        # as 'x', 'y', 'y_extra', 'x_origin', 'x_edit_state', 'loss_stats',
        # 'loss_stat_steps', 'forget' and 'support'.
        self.itf_cnt = 0
        self.total_cnt = 0
        # hyperparameters for gradient-based editing of memory examples
        self.grad_iter = get_config_attr(cfg,
                                         'EXTERNAL.OCL.GRAD_ITER',
                                         default=1)
        self.grad_stride = get_config_attr(cfg,
                                           'EXTERNAL.OCL.GRAD_STRIDE',
                                           default=10.)
        self.edit_decay = get_config_attr(cfg,
                                          'EXTERNAL.OCL.EDIT_DECAY',
                                          default=0.)
        self.no_write_back = get_config_attr(cfg,
                                             'EXTERNAL.OCL.NO_WRITE_BACK',
                                             default=0)
        # track the age of each memory slot
        self.reservoir['age'] = np.zeros(self.mem_limit)
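Every snippet in this listing reads hyperparameters through get_config_attr. Its implementation is not shown here; a minimal sketch of plausible semantics, assuming it resolves a dotted attribute path on a yacs-style config and falls back to a default (the project's real helper may differ):

def get_config_attr(cfg, path, default=None, totype=None, mute=False):
    # Sketch only: walk a dotted path such as 'EXTERNAL.OCL.GRAD_ITER'
    # and return `default` when any segment is missing.
    node = cfg
    for key in path.split('.'):
        if not hasattr(node, key):
            if not mute:
                print('config attr {} missing, using default {}'.format(path, default))
            return default
        node = getattr(node, key)
    return totype(node) if totype is not None else node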
Code Example #2
File: dataloader.py Project: INK-USC/GMED
def get_rotated_mnist_dataloader(cfg,
                                 split='train',
                                 filter_obj=None,
                                 batch_size=128,
                                 task_num=10,
                                 *args,
                                 **kwargs):
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    global _rmnist_loaders
    if not _rmnist_loaders:
        # build the loaders once and memoize them at module level
        data = get_rotated_mnist(d)
        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t)
                                                 for elem, t in zip(data, [True, False, False])]
        _rmnist_loaders = train_loader, val_loader, test_loader
    else:
        train_loader, val_loader, test_loader = _rmnist_loaders
    if split == 'train':
        return train_loader[filter_obj[0]]
    elif split == 'val':
        return val_loader[filter_obj[0]]
    elif split == 'test':
        return test_loader[filter_obj[0]]
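A plausible call to the loader above, assuming cfg has been built as in Code Example #3 (filter_obj picks the task index):

train_loader = get_rotated_mnist_dataloader(cfg, split='train', filter_obj=[0],
                                            batch_size=128)
for inputs, labels in train_loader:
    pass  # one rotated-MNIST task's training stream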
Code Example #3
def main(args):
    if '%id' in args.name:
        exp_name = args.name.replace('%id', get_exp_id())
    else:
        exp_name = args.name

    combined_cfg = CfgNode(new_allowed=True)
    combined_cfg.merge_from_file(args.config)
    cfg = combined_cfg
    cfg.EXTERNAL.EXPERIMENT_NAME = exp_name
    cfg.SEED = args.seed
    cfg.DEBUG = args.debug

    set_cfg_from_args(args, cfg)

    output_dir = get_config_attr(cfg, 'OUTPUT_DIR', default='')
    if output_dir == '.':
        output_dir = 'runs/'
    cfg.OUTPUT_DIR = os.path.join(output_dir,
                                  '{}_{}'.format(cfg.EXTERNAL.EXPERIMENT_NAME, cfg.SEED))
    cfg.MODE = 'train'

    # cfg.freeze()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    if distributed:
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )

    output_dir = cfg.OUTPUT_DIR

    model = train(cfg, local_rank, distributed, tune=args.tune)

    # record the exact command line in the output directory
    output_args_path = os.path.join(output_dir, 'args.txt')
    with open(output_args_path, 'w') as wf:
        wf.write(' '.join(sys.argv))
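CfgNode here is the yacs config node, so the merge logic in main() reduces to the pattern below (the YAML path is hypothetical):

from yacs.config import CfgNode

cfg = CfgNode(new_allowed=True)              # accept keys absent from the defaults
cfg.merge_from_file('configs/example.yaml')  # hypothetical config path
cfg.SEED = 0                                 # attributes can be set until cfg.freeze()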
Code Example #4
File: dataloader.py Project: INK-USC/GMED
def get_split_mini_imagenet_dataloader(cfg,
                                       split='train',
                                       filter_obj=None,
                                       batch_size=128,
                                       *args,
                                       **kwargs):
    global _cache_mini_imagenet
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    if not _cache_mini_imagenet:
        data = get_miniimagenet(d)
        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t)
                                                 for elem, t in zip(data, [True, False, False])]
        _cache_mini_imagenet = train_loader, val_loader, test_loader
    train_loader, val_loader, test_loader = _cache_mini_imagenet
    if split == 'train':
        return train_loader[filter_obj[0]]
    elif split == 'val':
        return val_loader[filter_obj[0]]
    elif split == 'test':
        return test_loader[filter_obj[0]]
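All of these dataloader helpers share one memoization idiom; stripped to its skeleton (names here are illustrative, not from the repository):

_cache = None

def build_loaders_once(build_fn):
    # Build the (train, val, test) loaders on the first call, then reuse them.
    global _cache
    if not _cache:
        _cache = build_fn()
    return _cache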
Code Example #5
    def __init__(self, base, optimizer, input_size, cfg, goal, **kwargs):
        super().__init__()
        self.net = base
        self.optimizer = optimizer
        self.input_size = input_size
        self.cfg = cfg
        self.goal = goal

        if 'caption' in self.goal:
            self.clip_grad = True
            self.use_image_feat = get_config_attr(self.cfg,
                                                  'EXTERNAL.USE_IMAGE_FEAT',
                                                  default=0)
            self.spatial_feat_shape = (2048, 7, 7)
            self.bbox_feat_shape = (100, 2048)
            self.bbox_shape = (100, 4)
        self.rev_optim = None
        if hasattr(self.net, 'rev_update_modules'):
            self.rev_optim = optim.Adam(
                lr=cfg.SOLVER.BASE_LR,
                betas=(0.9, 0.999),
                params=self.net.rev_update_modules.parameters())
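These wrappers all expose an observe(inputs, labels, task_ids) online-training step, which the trainers in Code Examples #11 and #12 call once per incoming batch. A minimal sketch of such a wrapper, assuming plain cross-entropy; this is not the project's actual NaiveWrapper:

import torch.nn as nn

class MinimalOCLWrapper(nn.Module):
    def __init__(self, base, optimizer):
        super().__init__()
        self.net = base
        self.optimizer = optimizer
        self.criterion = nn.CrossEntropyLoss()

    def observe(self, x, y, task_ids=None):
        # one online update on the incoming mini-batch
        self.optimizer.zero_grad()
        loss = self.criterion(self.net(x), y)
        loss.backward()
        self.optimizer.step()
        return loss.item()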
Code Example #6
File: dataloader.py Project: INK-USC/GMED
def get_split_cifar100_dataloader(cfg,
                                  split='train',
                                  filter_obj=None,
                                  batch_size=128,
                                  *args,
                                  **kwargs):
    d = DotDict()
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    global _cache_cifar100
    if not _cache_cifar100:
        data = get_split_cifar100(d, cfg)
        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t)
                                                 for elem, t in zip(data, [True, False, False])]
        _cache_cifar100 = train_loader, val_loader, test_loader
    train_loader, val_loader, test_loader = _cache_cifar100
    if split == 'train':
        return train_loader[filter_obj[0]]
    elif split == 'val':
        return val_loader[filter_obj[0]]
    elif split == 'test':
        return test_loader[filter_obj[0]]
Code Example #7
def train(cfg, local_rank, distributed, tune=False):
    is_ocl = hasattr(cfg.EXTERNAL.OCL, 'ALGO') and cfg.EXTERNAL.OCL.ALGO != 'PLAIN'
    task_incremental = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_INCREMENTAL', default=False)

    cfg.TUNE = tune

    algo = cfg.EXTERNAL.OCL.ALGO
    goal = 'classification'  # fallback when neither MNIST nor CIFAR is configured
    if hasattr(cfg, 'MNIST'):
        if cfg.MNIST.TASK == 'split':
            goal = 'split_mnist'
        elif cfg.MNIST.TASK == 'permute':
            goal = 'permute_mnist'
        elif cfg.MNIST.TASK == 'rotate':
            goal = 'rotated_mnist'
    if hasattr(cfg, 'CIFAR'):
        goal = 'split_cifar'
        if get_config_attr(cfg, 'CIFAR.DATASET', default="") == 'CIFAR100':
            goal = 'split_cifar100'
        if get_config_attr(cfg, 'CIFAR.MINI_IMAGENET', default=0):
            goal = 'split_mini_imagenet'

    if hasattr(cfg, 'MNIST'):
        num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
        num_of_classes = 10 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
        base_model = mnist_simple_net_400width_classlearning_1024input_10cls_1ds(num_of_datasets=num_of_datasets,
                                                                                 num_of_classes=num_of_classes,
                                                                                 task_incremental=task_incremental)

        base_model.cfg = cfg
    elif hasattr(cfg, 'CIFAR'):
        # CIFAR-10 variants have 10 classes in total; CIFAR-100 and
        # mini-ImageNet have 100
        num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
        default_classes = 10 if goal == 'split_cifar' else 100
        num_of_classes = default_classes if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
        base_model = ResNetClassifier(cfg, depth='18', mlp=1, ignore_index=-100, num_of_datasets=num_of_datasets,
                                      num_of_classes=num_of_classes, task_incremental=task_incremental, goal=goal)
        base_model.cfg = cfg
    else:
        base_model = ResNetClassifier(cfg)

    device = torch.device(cfg.MODEL.DEVICE)
    base_model.to(device)
    if cfg.EXTERNAL.OPTIMIZER.ADAM:
        optimizer = torch.optim.Adam(
            filter(lambda x: x.requires_grad, base_model.parameters()),
            lr=cfg.SOLVER.BASE_LR, betas=(0.9, 0.999)
        )
    else:
        optimizer = torch.optim.SGD(
            filter(lambda x: x.requires_grad, base_model.parameters()),
            lr=cfg.SOLVER.BASE_LR
        )

    # flattened input size for each benchmark
    if goal in ('split_mnist', 'permute_mnist', 'rotated_mnist'):
        x_size = 28 * 28
    elif goal in ('split_cifar', 'split_cifar100'):
        x_size = 3 * 32 * 32
    elif goal == 'split_mini_imagenet':
        x_size = 3 * 84 * 84
    elif goal == 'classification':
        x_size = 3 * 2 * base_model.cfg.EXTERNAL.IMAGE_SIZE ** 2
    else:
        x_size = 3 * base_model.cfg.EXTERNAL.IMAGE_SIZE ** 2

    # algorithm-specific model wrapper
    if algo == 'ER':
        model = ExperienceReplay(base_model, optimizer, x_size, base_model.cfg, goal)
    elif algo == 'VERX':
        model = ExperienceEvolveApprox(base_model, optimizer, x_size, base_model.cfg, goal)
    elif algo == 'AGEM':
        model = AGEM(base_model, optimizer, x_size, base_model.cfg, goal)
    elif algo == 'naive':
        model = NaiveWrapper(base_model, optimizer, x_size, base_model.cfg, goal)
    else:
        raise ValueError('unknown OCL algorithm: {}'.format(algo))
    model.to(device)

    use_mixed_precision = cfg.DTYPE == "float16"
    arguments = {"iteration": 0, "global_step": 0, "epoch": 0}
    output_dir = cfg.OUTPUT_DIR
    writer = None
    epoch_num = 1
    for e in range(epoch_num):
        print('epoch {}'.format(e))

        arguments['iteration'] = 0
        epoch = arguments['epoch']
        if goal == 'split_mnist' or goal == 'permute_mnist' or goal == 'rotated_mnist':
            ocl_train_mnist(model, optimizer, None, device, arguments, writer, epoch, goal, tune=tune)
        elif goal == 'split_cifar' or goal == 'split_cifar100' or goal == 'split_mini_imagenet':
            ocl_train_cifar(model, optimizer, None, device, arguments, writer, epoch, goal, tune=tune)
        else:
            raise NotImplementedError
        arguments['epoch'] += 1

        with open(os.path.join(output_dir, 'model.bin'), 'wb') as wf:
            torch.save(model.state_dict(), wf)
        if is_ocl and hasattr(model, 'dump_reservoir') and args.dump_reservoir:
            # note: `args` must be visible at module scope for this flag
            model.dump_reservoir(os.path.join(cfg.OUTPUT_DIR, 'mem_dump.pkl'), verbose=args.dump_reservoir_verbose)
    return model
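Putting the pieces together, a hedged sketch of the wiring train() performs for the ER algorithm on split CIFAR-100, reusing names from the snippets above:

base_model = ResNetClassifier(cfg, depth='18', mlp=1, ignore_index=-100,
                              num_of_datasets=1, num_of_classes=100,
                              task_incremental=False, goal='split_cifar100')
optimizer = torch.optim.SGD(base_model.parameters(), lr=cfg.SOLVER.BASE_LR)
model = ExperienceReplay(base_model, optimizer, 3 * 32 * 32, cfg, 'split_cifar100')
for inputs, labels in get_split_cifar100_dataloader(cfg, 'train', [0]):
    task_ids = torch.zeros(labels.size(0), dtype=torch.long)
    model.observe(inputs.flatten(1), labels, task_ids=task_ids)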
Code Example #8
File: agem.py Project: INK-USC/GMED
    def __init__(self, base, optimizer, input_size, cfg, goal):
        super(AGEM, self).__init__(base, optimizer, input_size, cfg, goal)

        # self.grads = self.grads.cuda()
        self.violation_count = 0
        self.agem_k = get_config_attr(cfg, 'EXTERNAL.OCL.AGEM_K', default=256)
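For context: A-GEM checks the current gradient g against a reference gradient g_ref computed on agem_k sampled memory examples and projects away any conflicting component. The standard projection rule from the A-GEM paper, sketched here independently of this class:

import torch

def agem_project(g, g_ref):
    # If g would increase loss on memory (negative dot product),
    # remove the component of g along g_ref.
    dot = torch.dot(g, g_ref)
    if dot < 0:
        g = g - (dot / torch.dot(g_ref, g_ref)) * g_ref
    return g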
Code Example #9
File: er.py Project: INK-USC/GMED
    def __init__(self, base, optimizer, input_size, cfg, goal):
        super().__init__(base, optimizer, input_size, cfg, goal)
        self.net = base
        self.optimizer = optimizer
        self.mem_limit = cfg.EXTERNAL.REPLAY.MEM_LIMIT
        self.mem_bs = cfg.EXTERNAL.REPLAY.MEM_BS
        self.input_size = input_size
        self.reservoir, self.example_seen = None, None
        self.reset_mem()
        self.mem_occupied = {}

        self.seen_tasks = []
        self.balanced = False
        # memory write policy ('reservoir' by default) and MIR settings
        self.policy = get_config_attr(cfg,
                                      'EXTERNAL.OCL.POLICY',
                                      default='reservoir',
                                      totype=str)
        self.mir_k = get_config_attr(cfg,
                                     'EXTERNAL.REPLAY.MIR_K',
                                     default=10,
                                     totype=int)
        self.mir = get_config_attr(cfg,
                                   'EXTERNAL.OCL.MIR',
                                   default=0,
                                   totype=int)

        self.mir_agg = get_config_attr(cfg,
                                       'EXTERNAL.OCL.MIR_AGG',
                                       default='avg',
                                       totype=str)

        self.concat_replay = get_config_attr(cfg,
                                             'EXTERNAL.OCL.CONCAT',
                                             default=0,
                                             totype=int)
        self.separate_replay = get_config_attr(cfg,
                                               'EXTERNAL.OCL.SEPARATE',
                                               default=0,
                                               totype=int)
        # data augmentation options for replayed examples
        self.mem_augment = get_config_attr(cfg,
                                           'EXTERNAL.OCL.MEM_AUG',
                                           default=0,
                                           totype=int)
        self.legacy_aug = get_config_attr(cfg,
                                          'EXTERNAL.OCL.LEGACY_AUG',
                                          default=0,
                                          totype=int)
        self.use_hflip_aug = get_config_attr(cfg,
                                             'EXTERNAL.OCL.USE_HFLIP_AUG',
                                             default=1,
                                             totype=int)
        self.padding_aug = get_config_attr(cfg,
                                           'EXTERNAL.OCL.PADDING_AUG',
                                           default=-1,
                                           totype=int)
        self.rot_aug = get_config_attr(cfg,
                                       'EXTERNAL.OCL.ROT_AUG',
                                       default=-1,
                                       totype=int)

        self.lb_reservoir = get_config_attr(cfg,
                                            'EXTERNAL.OCL.LB_RESERVOIR',
                                            default=0)

        self.cfg = cfg
        self.grad_dims = []

        for param in self.parameters():
            self.grad_dims.append(param.data.numel())
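The default write policy is 'reservoir'. As a reminder of what that means, the textbook reservoir-sampling write rule, sketched here independently of this class:

import numpy as np

def reservoir_write(buffer_x, x, example_seen, mem_limit, rng):
    # The first mem_limit examples fill the buffer; afterwards example i
    # overwrites a random slot with probability mem_limit / (i + 1).
    if example_seen < mem_limit:
        buffer_x[example_seen] = x
    else:
        j = rng.randint(0, example_seen + 1)
        if j < mem_limit:
            buffer_x[j] = x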
Code Example #10
File: er.py Project: INK-USC/GMED
    def sample_mem_batch(self,
                         device,
                         return_indices=False,
                         k=None,
                         seed=1,
                         mir=False,
                         input_x=None,
                         input_y=None,
                         input_task_ids=None,
                         mir_k=0,
                         skip_task=None,
                         mir_least=False):
        random_state = self.get_random(seed)
        if k is None:
            k = self.mem_bs

        if not self.balanced:
            # reservoir: sample uniformly from the filled part of the buffer
            n_max = min(self.mem_limit, self.example_seen)
            available_indices = list(range(n_max))
            if skip_task is not None and get_config_attr(
                    self.cfg, 'EXTERNAL.REPLAY.FILTER_SELF', default=0,
                    mute=True):
                available_indices = list(
                    filter(lambda x: self.reservoir['y_extra'][x] != skip_task,
                           available_indices))
            if not available_indices:
                return None, None, None
            elif len(available_indices) < k:
                # fewer candidates than requested: take all of them
                indices = np.array(available_indices)
            else:
                indices = random_state.choice(available_indices,
                                              k,
                                              replace=False)
        else:
            available_index = self.get_available_index()
            if len(available_index) == 0:
                return None, None, None
            elif len(available_index) < k:
                indices = np.array(available_index)
            else:
                indices = random_state.choice(available_index,
                                              k,
                                              replace=False)
        x = self.reservoir['x'][indices]
        x = torch.from_numpy(x).to(device).float()

        y = index_select(self.reservoir['y'], indices,
                         device)  # [  [...], [...] ]
        y_extra = index_select(self.reservoir['y_extra'], indices, device)
        if not isinstance(y[0], (list, tuple)):
            y_pad = concat_with_padding(y)
        else:
            y_pad = [torch.stack(_).to(device) for _ in zip(*y)]
        y_extra = concat_with_padding(y_extra)

        if mir:
            x, y_pad, y_extra, indices = self.decide_mir_mem(
                input_x, input_y, input_task_ids, mir_k, x, y, y_extra,
                indices, mir_least)

        if not return_indices:
            return x, y_pad, y_extra
        else:
            return (x, indices), y_pad, y_extra
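A plausible replay step built on this sampler (the surrounding code is hypothetical):

mem_x, mem_y, mem_task_ids = self.sample_mem_batch(device)
if mem_x is not None:
    # concatenate the replayed examples with the incoming batch
    # before taking the gradient step
    ...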
Code Example #11
File: trainer_benchmark.py Project: INK-USC/GMED
def ocl_train_mnist(model,
                    optimizer,
                    checkpointer,
                    device,
                    arguments,
                    writer,
                    epoch,
                    goal='split',
                    tune=False):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training @ epoch {:02d}".format(arguments['epoch']))
    model.train()
    cfg = model.cfg
    pbar = tqdm(position=0, desc='GPU: 0')

    num_instances = cfg.MNIST.INSTANCE_NUM

    if goal == 'split_mnist':
        task_num = 5
        loader_func = get_split_mnist_dataloader
    elif goal == 'permute_mnist':
        task_num = 10
        loader_func = get_permute_mnist_dataloader
    elif goal == 'rotated_mnist':
        task_num = 20
        loader_func = get_rotated_mnist_dataloader
    else:
        raise ValueError

    if tune:
        task_num = get_config_attr(cfg,
                                   'EXTERNAL.OCL.TASK_NUM',
                                   totype=int,
                                   default=3)

    num_epoch = get_config_attr(cfg, 'EXTERNAL.EPOCH', totype=int, default=1)
    total_step = task_num * 1000
    base_lr = get_config_attr(cfg, 'SOLVER.BASE_LR', totype=float)
    # whether iid
    iid = not get_config_attr(cfg, 'EXTERNAL.OCL.ACTIVATED', totype=bool)
    do_exp_lr_decay = get_config_attr(cfg, 'EXTERNAL.OCL.EXP_LR_DECAY', default=0)

    all_accs = []
    best_avg_accs = []
    step = 0
    for task_id in range(task_num):
        if iid:
            if task_id != 0: break
            data_loaders = [
                loader_func(cfg,
                            'train', [tid],
                            batch_size=cfg.EXTERNAL.BATCH_SIZE,
                            max_instance=num_instances)
                for tid in range(task_num)
            ]
            data_loader = DataLoader(IIDDataset(data_loaders),
                                     batch_size=cfg.EXTERNAL.BATCH_SIZE)
            num_instances *= task_num
        else:
            data_loader = loader_func(cfg,
                                      'train', [task_id],
                                      batch_size=cfg.EXTERNAL.BATCH_SIZE,
                                      max_instance=num_instances)

        best_avg_acc = -1
        #model.net.set_task(task_id) # choose the classifier head if the model supports
        for epoch in range(num_epoch):
            seen = 0
            for i, data in enumerate(data_loader):
                if seen >= num_instances: break
                inputs, labels = data
                inputs, labels = (inputs.to(device), labels.to(device))
                task_ids = torch.LongTensor([task_id] * labels.size(0)).to(
                    inputs.device)
                inputs = inputs.flatten(1)
                model.observe(inputs, labels, task_ids=task_ids)
                step += 1
                if do_exp_lr_decay:
                    exp_decay_lr(optimizer, step, total_step, base_lr)

                seen += labels.size(0)
            # run evaluation
            with torch.no_grad():
                if iid:
                    accs, _, avg_acc = inference_mnist(model,
                                                       task_num,
                                                       loader_func,
                                                       device,
                                                       tune=tune)
                else:
                    accs, _, avg_acc = inference_mnist(model,
                                                       task_id + 1,
                                                       loader_func,
                                                       device,
                                                       tune=tune)
            logger.info('Epoch {}\tTask {}\tAcc {}'.format(
                epoch, task_id, avg_acc))
            for i, acc in enumerate(accs):
                logger.info('::Val Task {}\t Acc {}'.format(i, acc))
            all_accs.append(accs)
            if avg_acc > best_avg_acc:
                best_avg_acc = avg_acc
            else:
                # stop early on this task once average accuracy stops improving
                break
        best_avg_accs.append(best_avg_acc)
    file_name = 'result.json' if not tune else 'result_tune_k{}.json'.format(task_num)
    with open(os.path.join(cfg.OUTPUT_DIR, file_name), 'w') as result_file:
        json.dump({'all_accs': all_accs, 'avg_acc': avg_acc},
                  result_file, indent=4)
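exp_decay_lr itself is not shown in these snippets; one plausible implementation consistent with its call sites (current step, total step budget, base learning rate) would be:

def exp_decay_lr(optimizer, step, total_step, base_lr, final_factor=0.01):
    # Sketch only: anneal exponentially from base_lr toward
    # final_factor * base_lr over total_step steps.
    factor = final_factor ** (min(step, total_step) / float(total_step))
    for group in optimizer.param_groups:
        group['lr'] = base_lr * factor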
Code Example #12
File: trainer_benchmark.py Project: INK-USC/GMED
def ocl_train_cifar(model,
                    optimizer,
                    checkpointer,
                    device,
                    arguments,
                    writer,
                    epoch,
                    goal='split_cifar',
                    tune=False):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training @ epoch {:02d}".format(arguments['epoch']))
    model.train()
    cfg = model.cfg

    if goal == 'split_cifar':
        loader_func = get_split_cifar_dataloader
        total_step = 4750
    elif goal == 'split_cifar100':
        loader_func = get_split_cifar100_dataloader
        total_step = 25000
    else:
        loader_func = get_split_mini_imagenet_dataloader
        total_step = 22500
    max_instance = cfg.CIFAR.INSTANCE_NUM if hasattr(cfg.CIFAR,
                                                     'INSTANCE_NUM') else 1e10
    task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)

    do_exp_lr_decay = get_config_attr(cfg, 'EXTERNAL.OCL.EXP_LR_DECAY', default=0)
    base_lr = get_config_attr(cfg, 'SOLVER.BASE_LR', totype=float)
    step = 0

    num_epoch = get_config_attr(cfg, 'EXTERNAL.EPOCH', totype=int, default=1)
    all_accs = []
    best_avg_accs = []
    iid = not get_config_attr(cfg, 'EXTERNAL.OCL.ACTIVATED', totype=bool)
    for task_id in range(task_num):
        if iid:
            if task_id != 0: break
            data_loaders = [
                loader_func(cfg,
                            'train', [tid],
                            batch_size=cfg.EXTERNAL.BATCH_SIZE,
                            max_instance=max_instance)
                for tid in range(task_num)
            ]
            data_loader = DataLoader(IIDDataset(data_loaders),
                                     batch_size=cfg.EXTERNAL.BATCH_SIZE)
            max_instance *= task_num
        else:
            data_loader = loader_func(cfg,
                                      'train', [task_id],
                                      batch_size=cfg.EXTERNAL.BATCH_SIZE,
                                      max_instance=max_instance)
        pbar = tqdm(position=0, desc='GPU: 0', total=len(data_loader))
        best_avg_acc = -1
        for epoch in range(num_epoch):
            seen = 0
            for i, data in enumerate(data_loader):
                if seen >= max_instance: break
                pbar.update(1)
                inputs, labels = data
                inputs, labels = (inputs.to(device), labels.to(device))
                inputs = inputs.flatten(1)
                task_ids = torch.LongTensor([task_id] * labels.size(0)).to(
                    inputs.device)
                model.observe(inputs, labels, task_ids)
                seen += inputs.size(0)
                if do_exp_lr_decay:
                    exp_decay_lr(optimizer, step, total_step, base_lr)
                step += 1
            # # run evaluation
            with torch.no_grad():
                if iid:
                    accs, _, avg_acc = inference_cifar(model,
                                                       task_num,
                                                       loader_func,
                                                       device,
                                                       goal,
                                                       tune=tune)
                else:
                    accs, _, avg_acc = inference_cifar(model,
                                                       task_id + 1,
                                                       loader_func,
                                                       device,
                                                       goal,
                                                       tune=tune)
            logger.info('Epoch {}\tTask {}\tAcc {}'.format(
                epoch, task_id, avg_acc))
            for i, acc in enumerate(accs):
                logger.info('::Val Task {}\t Acc {}'.format(i, acc))
            all_accs.append(accs)
            if avg_acc > best_avg_acc:
                best_avg_acc = avg_acc
            else:
                # stop early on this task once average accuracy stops improving
                break
        best_avg_accs.append(best_avg_acc)
    file_name = 'result.json' if not tune else 'result_tune_k{}.json'.format(task_num)
    with open(os.path.join(cfg.OUTPUT_DIR, file_name), 'w') as result_file:
        json.dump({'all_accs': all_accs,
                   'avg_acc': avg_acc,
                   'best_avg_accs': best_avg_accs},
                  result_file, indent=4)
    return best_avg_accs
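IIDDataset is used in both trainers to simulate iid training over all tasks at once; a hedged stand-in, assuming each task loader exposes its underlying dataset:

from torch.utils.data import Dataset

class IIDDatasetSketch(Dataset):
    # Flatten several per-task loaders into one pool so a single
    # DataLoader can shuffle across all tasks at once.
    def __init__(self, loaders):
        self.items = [ex for loader in loaders for ex in loader.dataset]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]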