Example #1
0
    from approaches import ucb as approach

# Args -- Network: pick the backbone by dataset family.
# MNIST-family experiments get the MLP variant of the UCB network, everything
# else the ResNet variant.
if args.experiment == 'mnist2' or args.experiment == 'pmnist' or args.experiment == 'mnist5':
    from networks import mlp_ucb as network
else:
    from networks import resnet_ucb as network

########################################################################################################################
print()
print("Starting this run on :")
print(datetime.now().strftime("%Y-%m-%d %H:%M"))

# Load the continual-learning task sequence: per-task data, the
# (task, n_classes) list, and the input dimensions.
print('Load data...')
data, taskcla, inputsize = dataloader.get(data_path=args.data_path,
                                          seed=args.seed)
print('Input size =', inputsize, '\nTask info =', taskcla)
# Stash task metadata on args so the network/approach can read it from there.
args.num_tasks = len(taskcla)
args.inputsize, args.taskcla = inputsize, taskcla

# Inits: build the model and move it to the configured device.
print('Inits...')
model = network.Net(args).to(args.device)

print('-' * 100)
# Wrap the model in the continual-learning approach (the trainer object).
appr = approach.Appr(model, args=args)
print('-' * 100)

# args.output=os.path.join(args.results_path, datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))
print('-' * 100)
Example #2
0
        from networks import resnet101 as network
    elif args.approach == 'res50':
        from networks import resnet50 as network
    elif args.approach == 'joint':
        #net = models.resnet50()
        #print('torch model')
        from networks import resnet50 as network
        print('my model')
    else:
        from networks import resnet50 as network

########################################################################################################################

# Load the task sequence: per-task data, (task, n_classes) pairs, input size.
print('Load data...')
data, taskcla, inputsize = dataloader.get(seed=args.seed)
#print('*****data*****',data,type(data));sys.exit(0)
print('Input size =', inputsize, '\nTask info =', taskcla)

# Inits: build the network on the GPU and report its parameter counts.
print('Inits...')
net = network.Net(inputsize, taskcla).cuda()
#from torchsummary import summary
#summary(net.cuda(), inputsize)#summary
utils.print_model_report(net)

# Trainer wrapping the network; hyperparameters come from the CLI args.
appr = approach.Appr(net, nepochs=args.nepochs, lr=args.lr, args=args)
print(appr.criterion)
utils.print_optimizer_config(appr.optimizer)
print('-' * 100)
Example #3
0
        from networks import alexnet_hat_test as network
    elif args.approach=='kan'\
            or 'by-name' in args.note  or 'all-zero' in args.note  \
            or 'all-one' in args.note:
        from networks import AlexnetKan as network
    elif 'auto' in args.note:
        from networks import AlexnetMixTaskKan as network
    else:
        from networks import alexnet as network

########################################################################################################################

# Load the task sequence.  Sentiment experiments additionally return the
# vocabulary size and a pretrained embedding weight matrix; all other
# experiments return the standard (data, taskcla, inputsize) triple.
print('Load data...')
if 'sentiment' in args.experiment:
    data, taskcla, inputsize, voc_size, weights_matrix = dataloader.get(
        seed=args.seed, args=args)
else:
    data, taskcla, inputsize = dataloader.get(seed=args.seed, args=args)
print('Input size =', inputsize, '\nTask info =', taskcla)

print('taskcla: ', len(taskcla))
# Optionally drop the FE-MNIST tasks, rebuilding the task list/data with a
# fresh index counter c.
# NOTE(review): c is never incremented in the visible code; presumably that
# happens just past this excerpt -- confirm against the full file.
if 'nofemnist' in args.note:
    c = 0
    taskcla_new = []
    data_new = []
    for i, (t, ncla) in enumerate(taskcla):
        if 'fe-mnist' in data[t]['name']:
            continue
        else:
            taskcla_new.append((c, taskcla[i][1]))
            data_new.append(data[t])
Example #4
0
def main():
    """Entry point for BLIP continual learning.

    Parses command-line arguments, seeds all RNGs, selects the dataloader /
    approach / network for the requested experiment, then trains each task in
    sequence.  After each task the BLIP bit allocation is updated from the
    estimated Fisher information, the weights are quantized, and the model is
    evaluated on all tasks seen so far.
    """
    tstart = time.time()

    parser = argparse.ArgumentParser(description='BLIP Image Classification')

    # Data parameters
    parser.add_argument('--seed',
                        default=0,
                        type=int,
                        help='(default=%(default)d)')
    parser.add_argument('--device', default='cuda:0', type=str, help='gpu id')
    parser.add_argument('--experiment',
                        default='mnist5',
                        type=str,
                        help='experiment dataset',
                        required=True)
    # FIX: help text previously read 'gpu id' (copy-paste from --device).
    parser.add_argument('--data_path',
                        default='../data/',
                        type=str,
                        help='path to the dataset root directory')

    # Training parameters
    parser.add_argument('--approach',
                        default='blip',
                        type=str,
                        help='continual learning approach')
    parser.add_argument('--output', default='', type=str, help='output path')
    parser.add_argument('--checkpoint_dir',
                        default='../checkpoints/',
                        type=str,
                        help='directory to store model checkpoints')
    parser.add_argument('--nepochs',
                        default=200,
                        type=int,
                        help='number of training epochs per task')
    parser.add_argument('--sbatch', default=64, type=int, help='batch size')
    parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
    parser.add_argument('--momentum', default=0, type=float, help='SGD momentum')
    parser.add_argument('--weight-decay',
                        default=0.0,
                        type=float,
                        help='weight decay')
    parser.add_argument('--resume', default='no', type=str, help='resume?')
    parser.add_argument('--sti', default=0, type=int, help='starting task?')

    # Model parameters
    parser.add_argument('--ndim',
                        default=1200,
                        type=int,
                        help='hidden dimension for 2 layer MLP')
    parser.add_argument('--mul',
                        default=1.0,
                        type=float,
                        help='multiplier of model width')

    # BLIP specific parameters
    parser.add_argument('--max-bit',
                        default=20,
                        type=int,
                        help='maximum number of bits for each parameter')
    parser.add_argument('--F-prior',
                        default=1e-15,
                        type=float,
                        help='scaling factor of F_prior')

    args = parser.parse_args()
    utils.print_arguments(args)

    #####################################################################################

    # Seed every RNG for reproducibility and make cuDNN deterministic.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print('Using device:', args.device)
    checkpoint = utils.make_directories(args)
    args.checkpoint = checkpoint
    print()

    # Args -- Experiment: pick the dataloader module for the dataset.
    if args.experiment == 'mnist2':
        from dataloaders import mnist2 as dataloader
    elif args.experiment == 'mnist5':
        from dataloaders import mnist5 as dataloader
    elif args.experiment == 'pmnist':
        from dataloaders import pmnist as dataloader
    elif args.experiment == 'cifar':
        from dataloaders import cifar as dataloader
    elif args.experiment == 'mixture5':
        from dataloaders import mixture5 as dataloader
    else:
        raise NotImplementedError('dataset currently not implemented')

    # Args -- Approach
    if args.approach == 'blip':
        from approaches import blip as approach
    else:
        raise NotImplementedError('approach currently not implemented')

    # Args -- Network: quantized MLP for MNIST-family data, quantized AlexNet
    # otherwise.
    if args.experiment == 'mnist2' or args.experiment == 'pmnist' or args.experiment == 'mnist5':
        from networks import q_mlp as network
    else:
        from networks import q_alexnet as network

    ########################################################################################
    print()
    print("Starting this run on :")
    print(datetime.now().strftime("%Y-%m-%d %H:%M"))

    # Load the task sequence: per-task data, (task, n_classes) list, input size.
    print('Load data...')
    data, taskcla, inputsize = dataloader.get(data_path=args.data_path,
                                              seed=args.seed)
    print('Input size =', inputsize, '\nTask info =', taskcla)
    args.num_tasks = len(taskcla)
    args.inputsize, args.taskcla = inputsize, taskcla

    # Inits
    print('Inits...')
    model = network.Net(args).to(args.device)

    print('-' * 100)
    appr = approach.Appr(model, args=args)
    print('-' * 100)

    if args.resume == 'yes':
        # Restore the weights saved for the starting task.  The local is named
        # ckpt so it does not clobber the checkpoint directory path above.
        ckpt = torch.load(
            os.path.join(args.checkpoint, 'model_{}.pth.tar'.format(args.sti)))
        model.load_state_dict(ckpt['model_state_dict'])
        model = model.to(device=args.device)
    else:
        args.sti = 0

    # Loop tasks: acc[t, u] / lss[t, u] hold test accuracy / loss on task u
    # measured after training task t.
    acc = np.zeros((len(taskcla), len(taskcla)), dtype=np.float32)
    lss = np.zeros((len(taskcla), len(taskcla)), dtype=np.float32)
    for t, ncla in taskcla[args.sti:]:

        print('*' * 100)
        print('Task {:2d} ({:s})'.format(t, data[t]['name']))
        print('*' * 100)

        # Get data
        xtrain = data[t]['train']['x'].to(args.device)
        ytrain = data[t]['train']['y'].to(args.device)
        xvalid = data[t]['valid']['x'].to(args.device)
        yvalid = data[t]['valid']['y'].to(args.device)
        task = t

        # Train
        appr.train(task, xtrain, ytrain, xvalid, yvalid)
        print('-' * 100)

        # BLIP specifics post processing
        estimate_fisher(task, args.device, model, xtrain, ytrain)
        for m in model.features.modules():
            if isinstance(m, Linear_Q) or isinstance(m, Conv2d_Q):
                # update bits according to information gain
                m.update_bits(task=task, C=0.5 / math.log(2))
                # do quantization
                m.sync_weight()
                # update Fisher in the buffer
                m.update_fisher(task=task)

        # save the model after the update
        appr.save_model(task)
        # Test on every task seen so far (0..t).
        for u in range(t + 1):
            xtest = data[u]['test']['x'].to(args.device)
            ytest = data[u]['test']['y'].to(args.device)
            test_loss, test_acc = appr.eval(u, xtest, ytest, debug=True)
            print(
                '>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.3f}% <<<'
                .format(u, data[u]['name'], test_loss, 100 * test_acc))
            acc[t, u] = test_acc
            lss[t, u] = test_loss

        utils.used_capacity(model, args.max_bit)

        # Save the (partially filled) accuracy matrix after every task.
        print('Save at ' + args.checkpoint)
        np.savetxt(
            os.path.join(
                args.checkpoint,
                '{}_{}_{}.txt'.format(args.experiment, args.approach,
                                      args.seed)), acc, '%.5f')

    utils.print_log_acc_bwt(args, acc, lss)
    print('[Elapsed time = {:.1f} h]'.format(
        (time.time() - tstart) / (60 * 60)))
Example #5
0
        from networks import alexnet_hat as network
    elif args.approach == 'progressive':
        from networks import alexnet_progressive as network
    elif args.approach == 'pathnet':
        from networks import alexnet_pathnet as network
    elif args.approach == 'hat-test':
        from networks import alexnet_hat_test as network
    else:
        from networks import alexnet as network

########################################################################################################################

# Load the task sequence (optionally from a cached file, and optionally with
# a shuffled task order).
print('Load data...')
data, taskcla, inputsize = dataloader.get(seed=args.seed,
                                          load_from=args.load_from,
                                          shuffle_cl=args.shuffle)
print('Input size =', inputsize, '\nTask info =', taskcla)

# Inits: build the network, move it to the target device, report its size.
print('Inits...')
net = network.Net(inputsize, taskcla)
net.to(device)
utils.print_model_report(net)

# Trainer wrapping the network; hyperparameters come from the CLI args.
appr = approach.Appr(net,
                     nepochs=args.nepochs,
                     lr=args.lr,
                     args=args,
                     device=device,
                     lr_factor=args.lr_fact)
Example #6
0
# elif args.experiment =='fmnist5_singlehead':
#     from networks import mlp_ucb_mixture as network
elif args.experiment =='mnist2_singlehead' or args.experiment == 'mnist2_singlehead_ver1' or args.experiment =='mnist_5_singlehead' or args.experiment =='notmnist_singlehead':
	from networks import mlp_ucb_singlehead as network
else:
	from networks import resnet_ucb as network


########################################################################################################################
print()
print("Starting this run on :")
print(datetime.now().strftime("%Y-%m-%d %H:%M"))

# Load the continual-learning task sequence.
print('Load data...')
data,taskcla,inputsize=dataloader.get(data_path=args.data_path, seed=args.seed)
# data,taskcla,inputsize=pickle.load(open('mnist5_singlehead','rb'))
# Cache the loaded split to disk.  FIX: the original called dataloader.get()
# a second time here (same data_path/seed -> same data, but twice the load
# cost) and never closed the file handle; dump the tuple already loaded above
# inside a context manager instead.
with open("data_cifar", "wb") as cache_file:
    pickle.dump((data, taskcla, inputsize), cache_file)
print('Input size =',inputsize,'\nTask info =',taskcla)
# Stash task metadata on args so the network/approach can read it from there.
args.num_tasks=len(taskcla)
args.inputsize, args.taskcla = inputsize, taskcla
#
# t = data[0]
# t_train = t['train']['x']
# print(type(t_train), t_train.shape)
# t_y= t['train']['y']
# Inits: build the model and move it to the configured device.
print('Inits...')
model=network.Net(args).to(args.device)

Example #7
0
def main(seed=0,
         experiment='',
         approach='',
         output='',
         name='',
         nepochs=200,
         lr=0.05,
         weight_init=None,
         test_mode=None,
         log_path=None,
         **parameters):
    '''Trains an experiment given the current settings.

    Args:
        seed (int): Random seed
        experiment (str): Name of the experiment to load - choices: ['mnist2','pmnist','cifar','mixture']
        approach (str): Approach to take to training the experiment - choices: ['random','sgd','sgd-frozen','lwf','lfl','ewc','imm-mean','progressive','pathnet','imm-mode','sgd-restart','joint','hat','hat-test']
        output (str): Path to store the output under
        name (str): Additional experiment name for grid search
        nepochs (int): Number of epochs to iterate through
        lr (float): Learning Rate to apply
        weight_init (str): String that defines how the weights are initialized - it can be split (with `:`) between convolution (first) and Linear (second) layers. Options: ["xavier", "uniform", "normal", "ones", "zeros", "kaiming"]
        test_mode (int): Defines how many tasks to iterate through
        log_path (str): Path to store detailed logs
        parameters (str): Approach dependent parameters
    '''
    # check the output path
    if output == '':
        output = '../res/' + experiment + '_' + approach + '_' + str(seed) + (
            ("_" + name) if len(name) > 0 else "") + '.txt'
    print('=' * 100)
    print('Arguments =')
    #
    args = {
        **parameters, "seed": seed,
        "experiment": experiment,
        "approach": approach,
        "output": output,
        "nepochs": nepochs,
        "lr": lr,
        "weight_init": weight_init
    }
    for arg in args:
        print("\t{:15}: {}".format(arg, args[arg]))
    print('=' * 100)

    ########################################################################################################################

    # Seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # check if cuda available
    if torch.cuda.is_available(): torch.cuda.manual_seed(seed)
    else:
        print('[CUDA unavailable]')
        sys.exit()

    # Args -- Experiment
    if experiment == 'mnist2':
        from dataloaders import mnist2 as dataloader
    elif experiment == 'pmnist':
        from dataloaders import pmnist as dataloader
    elif experiment == 'cifar':
        from dataloaders import cifar as dataloader
    elif experiment == 'mixture':
        from dataloaders import mixture as dataloader

    # Args -- Approach
    if approach == 'random':
        from approaches import random as appr
    elif approach == 'sgd':
        from approaches import sgd as appr
    elif approach == 'sgd-restart':
        from approaches import sgd_restart as appr
    elif approach == 'sgd-frozen':
        from approaches import sgd_frozen as appr
    elif approach == 'lwf':
        from approaches import lwf as appr
    elif approach == 'lfl':
        from approaches import lfl as appr
    elif approach == 'ewc':
        from approaches import ewc as appr
    elif approach == 'imm-mean':
        from approaches import imm_mean as appr
    elif approach == 'imm-mode':
        from approaches import imm_mode as appr
    elif approach == 'progressive':
        from approaches import progressive as appr
    elif approach == 'pathnet':
        from approaches import pathnet as appr
    elif approach == 'hat-test':
        # FIX: was `from approaches import hat_test as approach`, which
        # shadowed the `approach` argument with a module (so the network
        # selection below could never match 'hat-test') and left `appr`
        # undefined, raising NameError at `appr.Appr(...)`.
        from approaches import hat_test as appr
    elif approach == 'hat':
        from approaches import hat as appr
    elif approach == 'joint':
        from approaches import joint as appr
    elif approach == 'dwa':
        from approaches import dwa as appr

    # Args -- Network
    if experiment in ['mnist2', 'pmnist']:
        if approach in ['hat', 'hat-test']:
            from networks import mlp_hat as network
        elif approach == 'dwa':
            from networks import mlp_dwa as network
        else:
            from networks import mlp as network
    else:
        if approach == 'lfl':
            from networks import alexnet_lfl as network
        elif approach == 'hat':
            from networks import alexnet_hat as network
        elif approach == 'progressive':
            from networks import alexnet_progressive as network
        elif approach == 'pathnet':
            from networks import alexnet_pathnet as network
        elif approach == 'hat-test':
            from networks import alexnet_hat_test as network
        elif approach == 'dwa':
            from networks import alexnet_dwa as network
        else:
            from networks import alexnet as network

    ########################################################################################################################

    # Load
    print('Load data...')
    data, taskcla, inputsize = dataloader.get(seed=seed)
    print('Input size =', inputsize, '\nTask info =', taskcla)

    # Init the network and put on gpu
    print('Inits...')
    # handle input parameters for dwa approaches
    if approach == "dwa":
        params = {}
        for key in parameters:
            if key in dwa_net_params:
                params[key] = parameters[key]
        net = network.Net(inputsize, taskcla, **params).cuda()
    else:
        net = network.Net(inputsize, taskcla).cuda()
    utils.print_model_report(net)

    # setup network weights
    if weight_init is not None:
        # retrieve init data: "conv_init,conv_bias:linear_init,linear_bias"
        inits = weight_init.split(":")
        conv_init = inits[0].split(",")
        conv_bias = conv_init[1] if len(conv_init) > 1 else "zeros"
        conv_init = conv_init[0]
        linear_init = inits[-1].split(",")
        linear_bias = linear_init[1] if len(linear_init) > 1 else "zeros"
        linear_init = linear_init[0]

        init_funcs = {
            "xavier":
            lambda x: torch.nn.init.xavier_uniform_(x, gain=1.0),
            "kaiming":
            lambda x: torch.nn.init.kaiming_normal_(
                x, nonlinearity="relu", mode='fan_in'),
            "normal":
            lambda x: torch.nn.init.normal_(x, mean=0., std=1.),
            "uniform":
            lambda x: torch.nn.init.uniform_(x, a=0., b=1.),
            "ones":
            lambda x: x.data.fill_(1.),
            "zeros":
            lambda x: x.data.fill_(0.)
        }

        print(
            "Init network weights:\n\tlinear weights: {}\n\tlinear bias: {}\n\tconv weights: {}\n\tconv bias: {}"
            .format(linear_init, linear_bias, conv_init, conv_bias))

        # setup init function
        def init_weights(m):
            if type(m) == torch.nn.Linear or type(m) == Linear_dwa:
                init_funcs[linear_init](m.weight)
                init_funcs[linear_bias](m.bias)
            if type(m) == torch.nn.Conv2d or type(m) == Conv2d_dwa:
                init_funcs[conv_init](m.weight)
                init_funcs[conv_bias](m.bias)
            # TODO: check for masks

        # apply to network
        net.apply(init_weights)

    # setup the approach: dwa keeps only the non-network parameters here
    params = parameters
    if approach == 'dwa':
        params = {}
        for key in parameters:
            if key not in dwa_net_params:
                params[key] = parameters[key]
    appr = appr.Appr(net, nepochs=nepochs, lr=lr, log_path=log_path, **params)
    print(appr.criterion)
    utils.print_optimizer_config(appr.optimizer)
    print('-' * 100)

    # Loop tasks: acc[t, u] / lss[t, u] hold test accuracy / loss on task u
    # measured after training task t.
    acc = np.zeros((len(taskcla), len(taskcla)), dtype=np.float32)
    lss = np.zeros((len(taskcla), len(taskcla)), dtype=np.float32)
    i = 0
    for t, ncla in taskcla:
        # check if in test mode and finish after 1 task
        i += 1
        if test_mode is not None and i > test_mode:
            print("INFO: In Test-Mode - breaking after Task {}".format(
                test_mode))
            break

        print('*' * 100)
        print('Task {:2d} ({:s})'.format(t, data[t]['name']))
        print('*' * 100)

        if approach == 'joint':
            # Get data. We do not put it to GPU
            if t == 0:
                xtrain = data[t]['train']['x']
                ytrain = data[t]['train']['y']
                xvalid = data[t]['valid']['x']
                yvalid = data[t]['valid']['y']
                task_t = t * torch.ones(xtrain.size(0)).int()
                task_v = t * torch.ones(xvalid.size(0)).int()
                task = [task_t, task_v]
            else:
                xtrain = torch.cat((xtrain, data[t]['train']['x']))
                ytrain = torch.cat((ytrain, data[t]['train']['y']))
                xvalid = torch.cat((xvalid, data[t]['valid']['x']))
                yvalid = torch.cat((yvalid, data[t]['valid']['y']))
                task_t = torch.cat(
                    (task_t,
                     t * torch.ones(data[t]['train']['y'].size(0)).int()))
                task_v = torch.cat(
                    (task_v,
                     t * torch.ones(data[t]['valid']['y'].size(0)).int()))
                task = [task_t, task_v]
        else:
            # Get data
            xtrain = data[t]['train']['x'].cuda()
            ytrain = data[t]['train']['y'].cuda()
            xvalid = data[t]['valid']['x'].cuda()
            yvalid = data[t]['valid']['y'].cuda()
            task = t

        # Train
        appr.train(task, xtrain, ytrain, xvalid, yvalid)
        print('-' * 100)

        # Free some cache
        print("INFO: Free cuda cache")
        torch.cuda.empty_cache()

        # Test on every task seen so far (0..t)
        for u in range(t + 1):
            xtest = data[u]['test']['x'].cuda()
            ytest = data[u]['test']['y'].cuda()
            test_loss, test_acc, metric_str = appr.eval(u, xtest, ytest)
            print(
                '>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.1f}%{} <<<'
                .format(u, data[u]['name'], test_loss, 100 * test_acc,
                        metric_str))
            acc[t, u] = test_acc
            lss[t, u] = test_loss

            # check for introspection method (and logs enabled)
            if hasattr(appr, 'introspect') and appr.logs is not None and (
                    t + 1 >= len(taskcla)):
                # randomly select from dataset
                idx = torch.randperm(xtest.size(0))
                xrand = xtest[idx[:10]]
                yrand = ytest[idx[:10]]

                # compute
                out = appr.introspect(u, xrand, yrand)

                # pickle ouptut
                print('Store task {} analytics'.format(data[u]['name']))
                with gzip.open(
                        os.path.join(
                            appr.logpath,
                            os.path.basename(output) +
                            ".task{}_{}.analysis".format(u, data[u]['name'])),
                        'wb') as intro_file:
                    pickle.dump(out, intro_file, pickle.HIGHEST_PROTOCOL)

        # check if result directory exists
        if not os.path.exists(os.path.dirname(output)):
            print("create output dir")
            os.makedirs(os.path.dirname(output))

        # Save
        print('Save at {}'.format(output))
        np.savetxt(output, acc, '%.4f')

    # Done
    print('*' * 100)
    print('Accuracies =')
    for i in range(acc.shape[0]):
        print('\t', end='')
        for j in range(acc.shape[1]):
            print('{:5.1f}% '.format(100 * acc[i, j]), end='')
        print()
    print('*' * 100)
    print('Done!')

    # NOTE(review): tstart is not defined in this function; presumably a
    # module-level global set at program start -- confirm.
    print('[Elapsed time = {:.1f} h]'.format(
        (time.time() - tstart) / (60 * 60)))

    # optionally: store logs
    if hasattr(appr, 'logs'):
        if appr.logs is not None:
            #save task names
            from copy import deepcopy
            appr.logs['task_name'] = {}
            appr.logs['test_acc'] = {}
            appr.logs['test_loss'] = {}
            for t, ncla in taskcla:
                appr.logs['task_name'][t] = deepcopy(data[t]['name'])
                appr.logs['test_acc'][t] = deepcopy(acc[t, :])
                appr.logs['test_loss'][t] = deepcopy(lss[t, :])
            #pickle
            with gzip.open(
                    os.path.join(appr.logpath,
                                 os.path.basename(output) + "_logs.gzip"),
                    'wb') as log_file:
                pickle.dump(appr.logs, log_file, pickle.HIGHEST_PROTOCOL)

    # store the model (full and light versions)
    model_file = os.path.join(appr.logpath,
                              os.path.basename(output) + ".model")
    torch.save(net, model_file)
    model_file = os.path.join(appr.logpath,
                              os.path.basename(output) + ".weights")
    torch.save(net.state_dict(), model_file)