Example No. 1
def get_dataset(args):
    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](
        args.dataroot, args.train_aug)
    if args.n_permutation > 0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset,
            val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset,
            val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)

    return train_dataset_splits, val_dataset_splits, task_output_space
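
A minimal usage sketch for get_dataset (the argument values below are illustrative assumptions, not taken from the example):

from argparse import Namespace
from torch.utils.data import DataLoader

# Hypothetical arguments -- only the fields that get_dataset reads are filled in.
args = Namespace(dataset='MNIST', dataroot='data', train_aug=False,
                 n_permutation=0, first_split_size=2, other_split_size=2,
                 rand_split=False, no_class_remap=False)

train_splits, val_splits, task_output_space = get_dataset(args)
for task_name in sorted(task_output_space.keys(), key=int):
    loader = DataLoader(train_splits[task_name], batch_size=128, shuffle=True)
    print(task_name, task_output_space[task_name], len(loader.dataset))
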
Example No. 2
def run(args):
    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](args.dataroot, args.train_aug)
    if args.n_permutation > 0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(train_dataset, val_dataset,
                                                                             args.n_permutation,
                                                                             remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset,
                                                                          first_split_sz=args.first_split_size,
                                                                          other_split_sz=args.other_split_size,
                                                                          rand_split=args.rand_split,
                                                                          remap_class=not args.no_class_remap)

    task_names = sorted(list(task_output_space.keys()), key=int)
    if len(args.eps_val) == 1:
        args.eps_val = [args.eps_val[0]] * len(task_names)
    if len(args.eps_max) == 1:
        args.eps_max = [args.eps_max[0]] * len(task_names)
    if len(args.eps_epoch) == 1:
        args.eps_epoch = [args.eps_epoch[0]] * len(task_names)
    if len(args.kappa_epoch) == 1:
        args.kappa_epoch = [args.kappa_epoch[0]] * len(task_names)
    if len(args.schedule) == 1:
        args.schedule = [args.schedule[0]] * len(task_names)

    # Prepare the Agent (model)
    agent_config = {'lr': args.lr, 'momentum': args.momentum, 'weight_decay': args.weight_decay,
                    'schedule': args.schedule,
                    'model_type': args.model_type, 'model_name': args.model_name,
                    'model_weights': args.model_weights,
                    'out_dim': {'All': args.force_out_dim} if args.force_out_dim > 0 else task_output_space,
                    'optimizer': args.optimizer,
                    'print_freq': args.print_freq, 'gpuid': args.gpuid,
                    'reg_coef': args.reg_coef,
                    'force_out_dim': args.force_out_dim,
                    'clipping': args.clipping,
                    'eps_per_model': args.eps_per_model,
                    'milestones': args.milestones,
                    'dataset_name': args.dataset }
    agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config)
    print(agent.model)
    print('#parameter of model:', agent.count_parameter())

    # Decide split ordering
    print('Task order:', task_names)
    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)

    acc_table = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline_training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(val_dataset_splits.values())
        train_loader = DataLoader(train_dataset_all, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.workers)
        val_loader = DataLoader(val_dataset_all, batch_size=args.batch_size,
                                shuffle=False, num_workers=args.workers)

        agent.learn_batch(train_loader, val_loader)

        acc_table['All'] = {}
        acc_table['All']['All'] = agent.validation(val_loader)

    else:  # Incremental learning
        # Feed data to agent and evaluate agent's performance
        for i in range(len(task_names)):
            train_name = task_names[i]
            agent.current_task = int(task_names[i])
            print('======================', train_name, '=======================')
            train_loader = DataLoader(train_dataset_splits[train_name], batch_size=args.batch_size,
                                      shuffle=True, num_workers=args.workers)
            val_loader = DataLoader(val_dataset_splits[train_name], batch_size=args.batch_size,
                                    shuffle=False, num_workers=args.workers)

            if args.incremental_class:
                agent.add_valid_output_dim(task_output_space[train_name])

            if args.eps_max:
                agent.eps_scheduler.set_end(args.eps_max[i])

            agent.kappa_scheduler.end = args.kappa_min
            iter_on_batch = len(train_loader)
            agent.kappa_scheduler.calc_coefficient(args.kappa_min-1, args.kappa_epoch[i], iter_on_batch)
            agent.eps_scheduler.calc_coefficient(args.eps_val[i], args.eps_epoch[i], iter_on_batch)
            agent.kappa_scheduler.current, agent.eps_scheduler.current = 1, 0

            if agent.multihead:
                agent.current_head = str(train_name)

            print(f"before batch eps: {agent.eps_scheduler.current}, kappa: {agent.kappa_scheduler.current}")
            agent.learn_batch(train_loader, val_loader)  # Learn
            print(f"after batch eps: {agent.eps_scheduler.current}, kappa: {agent.kappa_scheduler.current}")

            if args.clipping:
                agent.save_params()

            agent.model.print_eps(agent.current_head)
            agent.model.reset_importance()

            # Evaluate
            acc_table[train_name] = OrderedDict()
            for j in range(i+1):
                val_name = task_names[j]
                print('validation split name:', val_name)
                val_data = val_dataset_splits[val_name] if not args.eval_on_train_set else train_dataset_splits[val_name]
                val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
                acc_table[val_name][train_name] = agent.validation(val_loader)
                agent.validation_with_move_weights(val_loader)

            agent.tb.close()

    return acc_table, task_names
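
The eps/kappa scheduler objects used above (set_end, calc_coefficient, current) are not defined in this example; the sketch below is an assumption about a linear scheduler matching that interface, not the repository's actual implementation.

class LinearScheduler:
    # Hypothetical sketch of the scheduler interface used above; the real class may differ.
    def __init__(self, start=0.0):
        self.current = start
        self.end = start
        self.coefficient = 0.0

    def set_end(self, end):
        self.end = end

    def calc_coefficient(self, delta, epochs, iters_per_epoch):
        # Spread the total change `delta` evenly over epochs * iters_per_epoch steps.
        self.coefficient = delta / max(1, epochs * iters_per_epoch)

    def step(self):
        # Called once per batch to move `current` towards `end`.
        self.current += self.coefficient
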
Example No. 3
def run(args):
    if not os.path.exists('outputs'):
        os.mkdir('outputs')

    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](
        args.dataroot, args.train_aug)
    if args.n_permutation > 0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset,
            val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset,
            val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)

    # Prepare the Agent (model)
    agent_config = {
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'schedule': args.schedule,
        'model_type': args.model_type,
        'model_name': args.model_name,
        'model_weights': args.model_weights,
        'out_dim': {
            'All': args.force_out_dim
        } if args.force_out_dim > 0 else task_output_space,
        'optimizer': args.optimizer,
        'print_freq': args.print_freq,
        'gpuid': args.gpuid,
        'reg_coef': args.reg_coef
    }

    agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](
        agent_config)
    print(agent.model)
    print('#parameter of model:', agent.count_parameter())

    # Decide split ordering
    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:', task_names)
    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)

    acc_table = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline_training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(
            train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(
            val_dataset_splits.values())
        train_loader = torch.utils.data.DataLoader(train_dataset_all,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset_all,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers)

        agent.learn_batch(train_loader, val_loader)

        acc_table['All'] = {}
        acc_table['All']['All'] = agent.validation(val_loader)

    else:  # Incremental learning
        # Feed data to agent and evaluate agent's performance
        for i in range(len(task_names)):
            train_name = task_names[i]
            print('======================', train_name,
                  '=======================')
            train_loader = torch.utils.data.DataLoader(
                train_dataset_splits[train_name],
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers)
            val_loader = torch.utils.data.DataLoader(
                val_dataset_splits[train_name],
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers)

            if args.incremental_class:
                agent.add_valid_output_dim(task_output_space[train_name])

            # Learn
            agent.learn_batch(train_loader, val_loader)

            # Evaluate
            acc_table[train_name] = OrderedDict()
            for j in range(i + 1):
                val_name = task_names[j]
                print('validation split name:', val_name)
                val_data = val_dataset_splits[
                    val_name] if not args.eval_on_train_set else train_dataset_splits[
                        val_name]
                val_loader = torch.utils.data.DataLoader(
                    val_data,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=args.workers)
                acc_table[val_name][train_name] = agent.validation(val_loader)

    return acc_table, task_names
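
The returned acc_table is indexed as acc_table[val_task][train_task]: the accuracy on val_task measured right after training on train_task. A hedged sketch of how it could be summarized, assuming agent.validation returned plain numbers:

def average_final_accuracy(acc_table, task_names):
    # Accuracy on every seen task after the last task has been learned.
    last = task_names[-1]
    finals = [acc_table[t][last] for t in task_names if last in acc_table.get(t, {})]
    return sum(finals) / len(finals) if finals else float('nan')
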
Example No. 4
def run(args, rand_seed):
    print('Random seed', rand_seed)
    if args.benchmark:
        print('benchmarked')
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(rand_seed)
        np.random.seed(rand_seed)
        random.seed(rand_seed)

    # sparse_wts = args.sparse_wt
    # for sparse_wt in sparse_wts:
    if args.single_tasks:
        args.exp_name = f'{args.exp_name}_single_tasks'
    # args.sparse_wt = sparse_wt

    if not os.path.exists('outputs'):
        os.mkdir('outputs')

    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](
        args.dataroot, args.train_aug)

    # 5 sequence tasks
    if args.dataset == 'multidataset':
        train_dataset_splits, val_dataset_splits, task_output_space = get_splits(
            train_dataset, val_dataset)

    else:

        if args.n_permutation > 0:
            train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
                train_dataset,
                val_dataset,
                args.n_permutation,
                remap_class=not args.no_class_remap)

        else:
            train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
                train_dataset,
                val_dataset,
                first_split_sz=args.first_split_size,
                other_split_sz=args.other_split_size,
                rand_split=args.rand_split,
                remap_class=not args.no_class_remap)

    # Prepare the Agent (model)
    agent_config = {
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'schedule': args.schedule,
        'model_type': args.model_type,
        'model_name': args.model_name,
        'model_weights': args.model_weights,
        'out_dim': {
            'All': args.force_out_dim
        } if args.force_out_dim > 0 else task_output_space,
        'optimizer': args.optimizer,
        'print_freq': args.print_freq,
        'gpuid': args.gpuid,
        'reg_coef': args.reg_coef,
        'exp_name': args.exp_name,
        'nuclear_weight': args.nuclear_weight,
        'period': args.period,
        'threshold_trp': args.threshold_trp,
        'sparse_wt': args.sparse_wt,
        'perp_wt': args.perp_wt,
        'reg_type_svd': args.reg_type_svd,
        'energy_sv': args.energy_sv,
        'save_running_stats': args.save_running_stats,
        'e_search': args.e_search,
        'sp_wt_search': args.sp_wt_search,
        'single_tasks': args.single_tasks,
        'prev_sing': args.prev_sing,
        'debug': args.debug,
        'grow_network': args.grow_network,
        'ind_models': args.ind_models,
        'dataset': args.dataset
    }

    if args.ind_models:
        agent_config_lst = {}
        for key_, val in task_output_space.items():

            agent_config_lst[key_] = copy.deepcopy(agent_config)
            agent_config_lst[key_]['out_dim'] = {key_: val}
            agent_config_lst[key_]['ind_models'] = True
            # import pdb; pdb.set_trace()

        agents_lst = {}
        for key_, val in task_output_space.items():
            agents_lst[key_] = agents.__dict__[args.agent_type].__dict__[
                args.agent_name](agent_config_lst[key_])

    else:
        print(args.nuclear_weight, args.threshold_trp)
        agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](
            agent_config)
    # import pdb; pdb.set_trace()
    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:', task_names)

    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)

    save_dict = {}  # Initialized here so it is defined even in the offline-training branch
    acc_table = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline_training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(
            train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(
            val_dataset_splits.values())
        train_loader = torch.utils.data.DataLoader(train_dataset_all,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers)

        test_loader = torch.utils.data.DataLoader(val_dataset_all,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers)

        agent.learn_batch(train_loader, test_loader)

        acc_table['All'] = {}
        acc_table['All']['All'] = agent.validation(test_loader)

    else:  # Incremental learning
        save_dict = {}
        # Adhering to the Adversarial Continual Learning paper
        if args.dataset == 'miniImageNet':
            validation_split = 0.02

        else:
            validation_split = 0.15

        final_accs = OrderedDict()
        if args.loadmodel:
            agent.load_model(args.loadmodel)
            for i in range(len(task_names)):
                train_name = task_names[i]
                test_loader = torch.utils.data.DataLoader(
                    val_dataset_splits[train_name],
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=args.workers)

                final_accs[train_name] = agent.validation(
                    test_loader, train_name).avg
                print(
                    f'/CumAcc/Task{train_name}, {final_accs[train_name]}, {i}')
        else:

            for i in range(len(task_names)):

                train_name = task_names[i]
                if args.ind_models:
                    agent = agents_lst[train_name]

                writer = SummaryWriter(log_dir="runs/" + agent.exp_name)
                # # import pdb; pdb.set_trace()
                #print('Final split for ImageNet', int(np.floor(validation_split * len(train_dataset_splits[train_name]))))
                # split = int(np.floor(validation_split * len(train_dataset_splits[train_name])))
                # train_split, val_split = torch.utils.data.random_split(train_dataset_splits[train_name], [len(train_dataset_splits[train_name]) - split, split])
                # train_dataset_splits[train_name] = train_split
                print('====================== Task Num', i + 1,
                      '=======================')
                print('======================', train_name,
                      '=======================')
                train_loader = torch.utils.data.DataLoader(
                    train_dataset_splits[train_name],
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=args.workers)
                test_loader = torch.utils.data.DataLoader(
                    val_dataset_splits[train_name],
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=args.workers)

                if args.incremental_class:
                    agent.add_valid_output_dim(task_output_space[train_name])

                # Learn
                # import pdb; pdb.set_trace()
                agent.learn_batch(train_loader,
                                  test_loader,
                                  task_name=train_name)

                # if single task skip this step:
                # Evaluate
                if not (args.single_tasks or args.ind_models):
                    final_accs = OrderedDict()
                    acc_table[train_name] = OrderedDict()
                    for j in range(i + 1):
                        # import pdb; pdb.set_trace()
                        val_name = task_names[j]
                        print('validation split name:', val_name)
                        # import pdb; pdb.set_trace()
                        val_data = val_dataset_splits[
                            val_name] if not args.eval_on_train_set else train_dataset_splits[
                                val_name]
                        test_loader = torch.utils.data.DataLoader(
                            val_data,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
                        acc_table[val_name][train_name] = agent.validation(
                            test_loader, val_name)
                        final_accs[val_name] = acc_table[val_name][
                            train_name].avg
                        print(
                            f'/CumAcc/Task{val_name}, {acc_table[val_name][train_name].avg}, {i}'
                        )
                        writer.add_scalar('/CumAcc/Task' + val_name,
                                          acc_table[val_name][train_name].avg,
                                          i)

                    # writer.add_scalar('/CumLoss/Task' + val_name, loss_table[val_name][train_name].avg, i )
                elif args.single_tasks:
                    val_name = task_names[i]
                    final_accs[train_name] = agent.validation(
                        test_loader, val_name).avg

                else:
                    val_name = task_names[i]
                    final_accs[train_name] = agent.validation(
                        test_loader, val_name)[0].avg

            agent.save_model()

        print(final_accs)
        # Collect the channels used, the per-task accuracies, and the compression
        # ratio of the final model, then compute the model size.
        avg_acc = sum(list(final_accs.values())) / len(final_accs)
        # import pdb; pdb.set_trace()
        model_size = agent.mode_comp(
            agent.chann_used
        ) if not args.ind_models else len(agents_lst) * agent.mode_comp()
        save_dict['channels'] = agent.chann_used if not args.ind_models else []
        # import pdb; pdb.set_trace()
        save_dict['acc'] = final_accs
        save_dict['avg_acc'] = avg_acc
        print('Average accuracy is', avg_acc)
        print('Model size is', model_size)
        save_dict['all_rank'] = agent.all_rank if not args.ind_models else []
        save_dict['model_size'] = model_size

    return save_dict, task_names
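
The seeding block at the top of this run() can be factored into a small helper; a sketch under the assumption that only torch, numpy, and random need seeding (CUDA seeding added for completeness):

import random
import numpy as np
import torch

def seed_everything(seed):
    # Make cuDNN deterministic and seed the common RNGs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
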
Example No. 5
def run(args):
    if not os.path.exists('outputs'):
        os.mkdir('outputs')

    # Prepare dataloaders
    train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](args.dataroot, args.train_aug)
    if args.n_permutation>0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(train_dataset, val_dataset,
                                                                             args.n_permutation,
                                                                             remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset,
                                                                          first_split_sz=args.first_split_size,
                                                                          other_split_sz=args.other_split_size,
                                                                          rand_split=args.rand_split,
                                                                          remap_class=not args.no_class_remap)

    # Prepare the Agent (model)
    agent_config = {'lr': args.lr, 'momentum': args.momentum, 'weight_decay': args.weight_decay,'schedule': args.schedule,
                    'model_type':args.model_type, 'model_name': args.model_name, 'model_weights':args.model_weights,
                    'out_dim':{'All':args.force_out_dim} if args.force_out_dim > 0 else task_output_space,
                    'optimizer':args.optimizer,
                    'print_freq':args.print_freq, 'gpuid': args.gpuid,
                    'reg_coef':args.reg_coef, 'exp_name' : args.exp_name, 'warmup':args.warm_up, 'nesterov':args.nesterov, 'run_num' :args.run_num, 'freeze_core':args.freeze_core, 'reset_opt':args.reset_opt, 'noise_':args.noise_, 'add_extra_last':args.add_extra_last,
                    'batch_size':args.batch_size, 'reg': args.reg }
                    
    agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config)
    print(agent.model)
    print('#parameter of model:',agent.count_parameter())
    

    # Decide split ordering
    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:',task_names)
    #import pdb; pdb.set_trace()
    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)

    acc_table = OrderedDict()
    loss_table = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline_training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(val_dataset_splits.values())
        train_loader = torch.utils.data.DataLoader(train_dataset_all,
                                                   batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset_all,
                                                 batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        
        epochs = args.epochs[-1]
        agent.learn_batch(train_loader, val_loader, [0, epochs])

        acc_table['All'] = {}
        loss_table['All'] = {}
        acc_table['All']['All'], loss_table['All']['All'] = agent.validation(val_loader)

    else:  # Incremental learning
        # Feed data to agent and evaluate agent's performance
        for i in range(1):
            train_name = task_names[i]
            print('======================',train_name,'=======================')
            train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name],
                                                        batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
            val_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],
                                                      batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

            if args.incremental_class:
                agent.add_valid_output_dim(task_output_space[train_name])

            # Learn
            epochs = args.epochs[i] if len(args.epochs) > 1 else args.epochs[0]
            # Split the training into sub-runs of args.old_val_freq epochs and validate
            # after each one (e.g. 80 epochs with a frequency of 10 --> validate after
            # 10, 20, ..., 80 epochs). This helps us understand how degradation happens.
            for epoch_10 in range(int(epochs / args.old_val_freq)):
                agent.learn_batch(train_loader, val_loader,
                                  epochs=[epoch_10 * args.old_val_freq, (epoch_10 + 1) * args.old_val_freq],
                                  task_n=train_name)
                # Evaluate
                acc_table[train_name] = OrderedDict()
                loss_table[train_name] = OrderedDict()
                writer = SummaryWriter(log_dir="runs/" + agent.exp_name)
                for j in range(i + 1):
                    if i != j:
                        continue
                    val_name = task_names[j]
                    print('validation split name:', val_name)
                    val_data = val_dataset_splits[val_name] if not args.eval_on_train_set else train_dataset_splits[val_name]
                    val_loader = torch.utils.data.DataLoader(val_data,
                                                             batch_size=args.batch_size, shuffle=False,
                                                             num_workers=args.workers)
                    acc_table[val_name][train_name], loss_table[val_name][train_name] = agent.validation(val_loader, val_name)

                    # TensorBoard logging
                    # agent.writer.reopen()
                    print('logging for Task {} while training {}'.format(val_name, train_name))
                    print('logging', int(train_name) + (epoch_10 + 1) * 0.1)

                    writer.add_scalar('Run' + str(args.run_num) + '/CumAcc/Task' + val_name,
                                      acc_table[val_name][train_name].avg,
                                      int(train_name) * 100 + (epoch_10 + 1) * args.old_val_freq)
                    writer.add_scalar('Run' + str(args.run_num) + '/CumLoss/Task' + val_name,
                                      loss_table[val_name][train_name].avg,
                                      int(train_name) * 100 + (epoch_10 + 1) * args.old_val_freq)
                    writer.close()
            # if i == 1:
                #after the first task freeze some weights:

    return acc_table, task_names
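
A small worked example of the sub-epoch schedule used above (the values are illustrative, not from the example): with epochs = 80 and args.old_val_freq = 10, the loop runs 8 times and learn_batch receives the epoch windows below.

epochs, old_val_freq = 80, 10  # illustrative values
windows = [[k * old_val_freq, (k + 1) * old_val_freq]
           for k in range(int(epochs / old_val_freq))]
# windows == [[0, 10], [10, 20], ..., [70, 80]]; validation runs after each window.
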
Example No. 6
def run(args):
    if not os.path.exists('outputs'):
        os.mkdir('outputs')

    # Prepare dataloaders
    # train_dataset, val_dataset = dataloaders.base.__dict__[args.dataset](args.dataroot, args.train_aug)
    train_dataset, val_dataset = factory('dataloaders', 'base',
                                         args.dataset)(args.dataroot,
                                                       args.train_aug)
    if args.n_permutation > 0:
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset,
            val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
    else:
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset,
            val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)

    # Prepare the Agent (model)
    dataset_name = args.dataset + \
        '_{}'.format(args.first_split_size) + \
        '_{}'.format(args.other_split_size)
    agent_config = {
        'model_lr': args.model_lr,
        'momentum': args.momentum,
        'model_weight_decay': args.model_weight_decay,
        'schedule': args.schedule,
        'model_type': args.model_type,
        'model_name': args.model_name,
        'model_weights': args.model_weights,
        'out_dim': {
            'All': args.force_out_dim
        } if args.force_out_dim > 0 else task_output_space,
        'model_optimizer': args.model_optimizer,
        'print_freq': args.print_freq,
        'gpu': True if args.gpuid[0] >= 0 else False,
        'with_head': args.with_head,
        'reset_model_opt': args.reset_model_opt,
        'reg_coef': args.reg_coef,
        'head_lr': args.head_lr,
        'svd_lr': args.svd_lr,
        'bn_lr': args.bn_lr,
        'svd_thres': args.svd_thres,
        'gamma': args.gamma,
        'dataset_name': dataset_name
    }

    # agent = agents.__dict__[args.agent_type].__dict__[args.agent_name](agent_config)
    agent = factory('svd_agent', args.agent_type,
                    args.agent_name)(agent_config)

    # Decide split ordering
    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:', task_names)
    if args.rand_split_order:
        shuffle(task_names)
        print('Shuffled task order:', task_names)

    # task_names = ['2', '1', '3', '4', '5']
    acc_table = OrderedDict()
    acc_table_train = OrderedDict()
    if args.offline_training:  # Non-incremental learning / offline_training / measure the upper-bound performance
        task_names = ['All']
        train_dataset_all = torch.utils.data.ConcatDataset(
            train_dataset_splits.values())
        val_dataset_all = torch.utils.data.ConcatDataset(
            val_dataset_splits.values())
        train_loader = torch.utils.data.DataLoader(train_dataset_all,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset_all,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers)

        agent.learn_batch(train_loader, val_loader)

        acc_table['All'] = {}
        acc_table['All']['All'] = agent.validation(val_loader)

    else:  # Incremental learning
        # Feed data to agent and evaluate agent's performance
        for i in range(len(task_names)):
            train_name = task_names[i]
            print('======================', train_name,
                  '=======================')
            train_loader = torch.utils.data.DataLoader(
                train_dataset_splits[train_name],
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers)
            val_loader = torch.utils.data.DataLoader(
                val_dataset_splits[train_name],
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers)

            if args.incremental_class:
                agent.add_valid_output_dim(task_output_space[train_name])

            # Learn
            agent.train_task(train_loader, val_loader)
            torch.cuda.empty_cache()
            # Evaluate
            acc_table[train_name] = OrderedDict()
            acc_table_train[train_name] = OrderedDict()
            for j in range(i + 1):
                val_name = task_names[j]

                print('validation split name:', val_name)
                val_data = val_dataset_splits[
                    val_name] if not args.eval_on_train_set else train_dataset_splits[
                        val_name]
                val_loader = torch.utils.data.DataLoader(
                    val_data,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=args.workers)
                acc_table[val_name][train_name] = agent.validation(val_loader)

                print("**************************************************")
                print('training split name:', val_name)
                train_data = train_dataset_splits[val_name]
                train_loader = torch.utils.data.DataLoader(
                    train_data,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=args.workers)
                acc_table_train[val_name][train_name] = agent.validation(
                    train_loader)
                print("**************************************************")

    return acc_table, task_names
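
The factory helper used in this example is not shown; the sketch below is an assumption about what it might look like, equivalent to the commented-out __dict__ lookups (resolving an attribute of package.module by name):

import importlib

def factory(package, module, name):
    # Hypothetical sketch: factory('dataloaders', 'base', 'MNIST') would resolve
    # dataloaders.base.MNIST; the real helper may differ.
    return getattr(importlib.import_module(f'{package}.{module}'), name)
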
Example No. 7
def prepare_dataloaders(args):
    # Prepare dataloaders
    Dataset = dataloaders.base.__dict__[args.dataset]

    # SPLIT CUB
    if args.is_split_cub:
        print("running split -------------")
        from dataloaders.cub import CUB
        Dataset = CUB
        if args.train_aug:
            print("train aug not supported for cub")
            return
        train_dataset, val_dataset = Dataset(args.dataroot)
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset,
            val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)
        n_tasks = len(task_output_space.items())
    # Permuted MNIST
    elif args.n_permutation > 0:
        # TODO : CHECK subset_size
        train_dataset, val_dataset = Dataset(args.dataroot,
                                             args.train_aug,
                                             angle=0,
                                             subset_size=args.subset_size)
        print("Working with permuatations :) ")
        train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(
            train_dataset,
            val_dataset,
            args.n_permutation,
            remap_class=not args.no_class_remap)
        n_tasks = args.n_permutation
    # Rotated MNIST
    elif args.n_rotate > 0 or len(args.rotations) > 0:
        # TODO : Check subset size
        train_dataset_splits, val_dataset_splits, task_output_space = RotatedGen(
            Dataset=Dataset,
            dataroot=args.dataroot,
            train_aug=args.train_aug,
            n_rotate=args.n_rotate,
            rotate_step=args.rotate_step,
            remap_class=not args.no_class_remap,
            rotations=args.rotations,
            subset_size=args.subset_size)
        n_tasks = len(task_output_space.items())

    # Split MNIST
    else:
        print("running split -------------")
        # TODO : Check subset size
        train_dataset, val_dataset = Dataset(args.dataroot,
                                             args.train_aug,
                                             angle=0,
                                             subset_size=args.subset_size)
        train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
            train_dataset,
            val_dataset,
            first_split_sz=args.first_split_size,
            other_split_sz=args.other_split_size,
            rand_split=args.rand_split,
            remap_class=not args.no_class_remap)
        n_tasks = len(task_output_space.items())

    print(f"task_output_space {task_output_space}")

    return task_output_space, n_tasks, train_dataset_splits, val_dataset_splits
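
A hedged usage sketch for prepare_dataloaders (the field names follow the arguments read above; the values are assumptions for illustration):

from argparse import Namespace

args = Namespace(dataset='MNIST', dataroot='data', train_aug=False,
                 is_split_cub=False, n_permutation=0, n_rotate=0, rotations=[],
                 rotate_step=0, subset_size=None,
                 first_split_size=2, other_split_size=2,
                 rand_split=False, no_class_remap=False)

task_output_space, n_tasks, train_splits, val_splits = prepare_dataloaders(args)
print(n_tasks, sorted(task_output_space.keys(), key=int))
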
Example No. 8
    from tqdm import tqdm

    dataroot = "/tmp/datasets"
    first_split_size = 5
    other_split_size = 5
    rand_split = False
    no_class_remap = False

    # from dataloaders.base import MNIST
    # Dataset = MNIST
    Dataset = CUB
    train_dataset, val_dataset = Dataset(dataroot)
    train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
        train_dataset,
        val_dataset,
        first_split_sz=first_split_size,
        other_split_sz=other_split_size,
        rand_split=rand_split,
        remap_class=not no_class_remap)
    # __get__ returns img, target, self.name through AppendName

    n_tasks = len(task_output_space.items())
    task_names = sorted(list(task_output_space.keys()), key=int)

    batch_size = 32
    workers = 0

    for i in tqdm(range(len(task_names)), "task"):
        task_name = task_names[i]
        print('======================', task_name, '=======================')
        train_loader = torch.utils.data.DataLoader(