def init_params(model, args):
    # - reinitialize all parameters according to default initialization
    model.apply(utils.weight_reset)
    # - initialize parameters according to chosen custom initialization (if requested)
    if hasattr(args, 'init_weight') and not args.init_weight == "standard":
        utils.weight_init(model, strategy="xavier_normal")
    if hasattr(args, 'init_bias') and not args.init_bias == "standard":
        utils.bias_init(model, strategy="constant", value=0.01)
    # - use pre-trained weights (either for full model or just in conv-layers)?
    if utils.checkattr(args, "pre_convE") and hasattr(model, 'depth') and model.depth > 0:
        load_name = model.convE.name if (
            not hasattr(args, 'convE_ltag') or args.convE_ltag == "none"
        ) else "{}-{}".format(model.convE.name, args.convE_ltag)
        utils.load_checkpoint(model.convE, model_dir=args.m_dir, name=load_name)
    if utils.checkattr(args, "pre_convD") and hasattr(model, 'convD') and model.depth > 0:
        utils.load_checkpoint(model.convD, model_dir=args.m_dir)
    return model

##-------------------------------------------------------------------------------------------------------------------##
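# Example usage of [init_params] (an illustrative sketch, not part of the pipeline):
# [args] is assumed to be this project's parsed argparse Namespace and [config] a
# dataset-config dict as used by [define_classifier] below.
#
#   model = define_classifier(args=args, config=config, device='cpu')
#   model = init_params(model, args)  # reset params, then load pre-trained conv-layers if requested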
def train(model, train_loader, iters, loss_cbs=list(), eval_cbs=list(), save_every=None,
          m_dir="./store/models", args=None):
    '''Train a model with a "train_a_batch" method for [iters] iterations on data from [train_loader].

    [model]             model to optimize
    [train_loader]      <dataloader> for training [model] on
    [iters]             <int> (max) number of iterations (i.e., batches) to train for
    [loss_cbs]          <list> of callback-<functions> to keep track of training progress
    [eval_cbs]          <list> of callback-<functions> to evaluate model on separate data-set'''

    device = model._device()
    # jd's change to measure time (empties the list of evaluation callbacks)
    eval_cbs = list()

    # Should convolutional layers be frozen?
    freeze_convE = (utils.checkattr(args, "freeze_convE") and hasattr(args, "depth") and args.depth > 0)

    # Create progress-bar (with manual control)
    bar = tqdm.tqdm(total=iters)

    iteration = epoch = 0
    while iteration < iters:
        epoch += 1

        # Loop over all batches of an epoch
        for batch_idx, (data, y) in enumerate(train_loader):
            iteration += 1

            # Perform training-step on this batch
            data, y = data.to(device), y.to(device)
            loss_dict = model.train_a_batch(data, y=y, freeze_convE=freeze_convE)

            # Fire training-callbacks (for visualization of training-progress)
            for loss_cb in loss_cbs:
                if loss_cb is not None:
                    loss_cb(bar, iteration, loss_dict, epoch=epoch)

            # Fire evaluation-callbacks (to be executed every [eval_log] iterations, as specified within the functions)
            for eval_cb in eval_cbs:
                if eval_cb is not None:
                    eval_cb(model, iteration, epoch=epoch)

            # Break if max-number of iterations is reached
            if iteration == iters:
                bar.close()
                break

            # Save checkpoint?
            if (save_every is not None) and (iteration % save_every) == 0:
                utils.save_checkpoint(model, model_dir=m_dir)
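# Example usage of [train] (illustrative sketch): assumes [model] implements
# "train_a_batch" and [train_loader] comes from utils.get_data_loader; the
# minimal loss-callback below only advances the progress-bar.
#
#   train(model, train_loader, iters=2000,
#         loss_cbs=[lambda bar, iteration, loss_dict, epoch=None: bar.update(1)],
#         save_every=500, m_dir="./store/models", args=args)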
def define_classifier(args, config, device):
    # -import required model
    from models.classifier import Classifier
    # -create model
    if hasattr(args, "depth") and args.depth > 0:
        model = Classifier(
            image_size=config['size'], image_channels=config['channels'], classes=config['classes'],
            # -conv-layers
            conv_type=args.conv_type, depth=args.depth, start_channels=args.channels,
            reducing_layers=args.rl, num_blocks=args.n_blocks, conv_bn=(args.conv_bn == "yes"),
            conv_nl=args.conv_nl, global_pooling=checkattr(args, 'gp'),
            # -fc-layers
            fc_layers=args.fc_lay, fc_units=args.fc_units, h_dim=args.h_dim, fc_drop=args.fc_drop,
            fc_bn=(args.fc_bn == "yes"), fc_nl=args.fc_nl, excit_buffer=True,
            # -training related parameters
            AGEM=utils.checkattr(args, 'agem'),
        ).to(device)
    else:
        model = Classifier(
            image_size=config['size'], image_channels=config['channels'], classes=config['classes'],
            # -fc-layers
            fc_layers=args.fc_lay, fc_units=args.fc_units, h_dim=args.h_dim, fc_drop=args.fc_drop,
            fc_bn=(args.fc_bn == "yes"), fc_nl=args.fc_nl, excit_buffer=True,
            # -training related parameters
            AGEM=utils.checkattr(args, 'agem'),
        ).to(device)
    # -return model
    return model
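# Example usage of [define_classifier] (illustrative sketch): a hand-written
# [config] mimicking what get_multitask_experiment returns for splitMNIST;
# the exact values here are assumptions.
#
#   config = {'size': 28, 'channels': 1, 'classes': 10}
#   model = define_classifier(args=args, config=config, device=torch.device('cpu'))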
def get_param_stamp_from_args(args):
    '''To get the param-stamp a bit quicker.'''
    from define_models import define_autoencoder, define_classifier

    # -get configuration of experiment
    config = get_multitask_experiment(
        name=args.experiment, scenario=args.scenario, tasks=args.tasks, data_dir=args.d_dir,
        only_config=True, normalize=args.normalize if hasattr(args, "normalize") else False, verbose=False,
    )

    # -get model architectures
    model = define_autoencoder(args=args, config=config, device='cpu') if checkattr(args, 'feedback') \
        else define_classifier(args=args, config=config, device='cpu')
    if checkattr(args, 'feedback'):
        model.lamda_pl = 1. if not hasattr(args, 'pl') else args.pl
    train_gen = (hasattr(args, 'replay') and args.replay == "generative" and not checkattr(args, 'feedback'))
    if train_gen:
        generator = define_autoencoder(
            args=args, config=config, device='cpu', generator=True,
            convE=model.convE if (hasattr(args, "hidden") and args.hidden) else None,
        )

    # -extract and return param-stamp
    model_name = model.name
    replay_model_name = generator.name if train_gen else None
    param_stamp = get_param_stamp(args, model_name,
                                  replay=(hasattr(args, "replay") and not args.replay == "none"),
                                  replay_model_name=replay_model_name, verbose=False)
    return param_stamp
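# Example usage of [get_param_stamp_from_args] (illustrative): checking whether
# results for this run already exist, based on the "prec-[param_stamp].txt"
# files written by [run] into args.r_dir.
#
#   param_stamp = get_param_stamp_from_args(args)
#   already_done = os.path.isfile("{}/prec-{}.txt".format(args.r_dir, param_stamp))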
def check_for_errors(args, single_task=False, **kwargs):
    # -errors in scenario-specification
    if not single_task:
        # -if scenario is "class" and XdG is selected, give error
        if args.scenario == "class" and checkattr(args, 'xdg') and args.xdg_prop > 0:
            raise ValueError("Having scenario=[class] with 'XdG' does not make sense")
        # -if scenario is "domain" and XdG is selected, give warning
        if args.scenario == "domain" and checkattr(args, 'xdg') and args.xdg_prop > 0:
            print("Although scenario=[domain], with 'XdG' task identity is nevertheless always required")
        # -if XdG is selected together with replay of any kind, give error
        if checkattr(args, 'xdg') and args.xdg_prop > 0 and (not args.replay == "none"):
            raise NotImplementedError("XdG is not supported with '{}' replay.".format(args.replay))
            # --> problem is that applying different task-masks interferes with gradient calculation
            #     (should be possible to overcome by calculating each gradient before applying the next mask)
        # -if 'only_last' is selected with replay, EWC or SI, give error
        if checkattr(args, 'only_last') and (not args.replay == "none"):
            raise NotImplementedError("Option 'only_last' is not supported with '{}' replay.".format(args.replay))
        if checkattr(args, 'only_last') and (checkattr(args, 'ewc') and args.ewc_lambda > 0):
            raise NotImplementedError("Option 'only_last' is not supported with EWC.")
        if checkattr(args, 'only_last') and (checkattr(args, 'si') and args.si_c > 0):
            raise NotImplementedError("Option 'only_last' is not supported with SI.")
    # -error in type of reconstruction loss
    if checkattr(args, "normalize") and hasattr(args, "recon_loss") and args.recon_loss == "BCE":
        raise ValueError("'BCE' is not a valid reconstruction loss with normalized images")
def init_params(model, args):
    # - reinitialize all parameters according to default initialization
    model.apply(utils.weight_reset)
    # - initialize parameters according to chosen custom initialization (if requested)
    if hasattr(args, 'init_weight') and not args.init_weight == "standard":
        utils.weight_init(model, strategy="xavier_normal")
    if hasattr(args, 'init_bias') and not args.init_bias == "standard":
        utils.bias_init(model, strategy="constant", value=0.01)
    # - use pre-trained weights in conv-layers
    load_name = "{}-e100".format(model.convE.name)
    utils.load_checkpoint(model.convE, model_dir='./conv_layers', name=load_name)
    # - freeze weights of conv-layers?
    if utils.checkattr(args, "bir"):
        for param in model.convE.parameters():
            param.requires_grad = False
    return model
def show_reconstruction(model, dataset, config, pdf=None, visdom=None, size=32, epoch=None, task=None,
                        no_task_mask=False):
    '''Plot reconstructed examples by an auto-encoder [model] on [dataset], in [pdf] and/or in [visdom].'''

    # Get device-type / using cuda?
    cuda = model._is_on_cuda()
    device = model._device()

    # Set model to evaluation-mode
    model.eval()

    # Get data
    data_loader = utils.get_data_loader(dataset, size, cuda=cuda)
    (data, labels) = next(iter(data_loader))

    # If needed, apply the correct task-specific mask (for fully-connected hidden layers in the encoder)
    if hasattr(model, "mask_dict") and model.mask_dict is not None:
        if no_task_mask:
            model.reset_XdGmask()
        else:
            model.apply_XdGmask(task=task)

    # Evaluate model
    data, labels = data.to(device), labels.to(device)
    with torch.no_grad():
        gate_input = (
            torch.tensor(np.repeat(task - 1, size)).to(device) if model.dg_type == "task" else labels
        ) if (utils.checkattr(model, 'dg_gates') and model.dg_prop > 0) else None
        recon_output = model(data, gate_input=gate_input, full=True, reparameterize=False)
        recon_batch = recon_output[0]

    # Plot original and reconstructed images
    # -number of rows
    nrow = int(np.ceil(np.sqrt(size * 2)))
    # -collect and arrange pixel-values
    comparison = torch.cat([
        data.view(-1, config['channels'], config['size'], config['size'])[:size],
        recon_batch.view(-1, config['channels'], config['size'], config['size'])[:size]
    ]).cpu()
    image_tensor = comparison.view(-1, config['channels'], config['size'], config['size'])
    # -denormalize images if needed
    if config['normalize']:
        image_tensor = config['denormalize'](image_tensor).clamp(min=0, max=1)
    # -make plots
    if pdf is not None:
        epoch_stm = "" if epoch is None else " after epoch {}".format(epoch)
        task_stm = "" if task is None else " (task {})".format(task)
        visual.plt.plot_images_from_tensor(image_tensor, pdf, nrow=nrow,
                                           title="Reconstructions" + task_stm + epoch_stm)
    if visdom is not None:
        visual.visdom.visualize_images(
            tensor=image_tensor, title='Reconstructions ({})'.format(visdom["graph"]),
            env=visdom["env"], nrow=nrow,
        )
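# Example usage of [show_reconstruction] (illustrative sketch): plotting 32
# reconstructions of task 1 into a pdf; [vae] and [test_dataset] are assumed
# to follow the conventions used elsewhere in this file.
#
#   pp = visual.plt.open_pdf("./reconstructions.pdf")
#   show_reconstruction(vae, test_dataset, config, pdf=pp, size=32, task=1)
#   pp.close()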
def set_defaults(args, only_MNIST=False, single_task=False, generative=True, compare_code='none', **kwargs):
    # -if 'brain-inspired' is selected, select corresponding defaults
    if checkattr(args, 'brain_inspired'):
        if hasattr(args, "replay") and not args.replay == "generative":
            raise Warning("To run with brain-inspired replay, select both '--brain-inspired' and '--replay=generative'")
        args.feedback = True      #--> replay-through-feedback
        args.prior = 'GMM'        #--> conditional replay
        args.per_class = True     #--> conditional replay
        args.dg_gates = True      #--> gating based on internal context (has hyper-param 'dg_prop')
        args.hidden = True        #--> internal replay
        args.pre_convE = True     #--> internal replay
        args.freeze_convE = True  #--> internal replay
        args.distill = True       #--> distillation

    # -set default-values for certain arguments based on chosen experiment
    args.normalize = args.normalize if args.experiment in [
        "CIFAR100", "ASL", "fruits360", "chars74k", "gtsrb"] else False
    args.augment = args.augment if args.experiment in ('CIFAR10', 'CIFAR100') else False
    if hasattr(args, "depth"):
        args.depth = (5 if args.experiment in ["CIFAR100", "ASL", "fruits360", "chars74k", "gtsrb"] else 0) \
            if args.depth is None else args.depth
    if hasattr(args, "recon_loss"):
        args.recon_loss = ("MSE" if args.experiment in ["CIFAR100", "ASL", "fruits360", "chars74k", "gtsrb"]
                           else "BCE") if args.recon_loss is None else args.recon_loss
    if hasattr(args, "dg_type"):
        args.dg_type = ("task" if args.experiment == 'permMNIST' else "class") if args.dg_type is None else args.dg_type
    if not single_task:
        args.tasks = (5 if args.experiment in ['splitMNIST', 'fmnist'] else (
            10 if args.experiment in ["CIFAR100", "ASL", "fruits360", "chars74k", "gtsrb"] else 100
        )) if args.tasks is None else args.tasks
        args.iters = (5000 if args.experiment in ["CIFAR100", "ASL", "fruits360", "chars74k", "gtsrb"]
                      else 2000) if args.iters is None else args.iters
        args.lr = (0.001 if args.experiment in ['splitMNIST', 'fmnist'] else 0.0001) if args.lr is None else args.lr
        args.batch = (128 if args.experiment in ['splitMNIST', 'fmnist'] else 256) if args.batch is None else args.batch
        args.fc_units = (400 if args.experiment in ['splitMNIST', 'fmnist'] else 2000) \
            if args.fc_units is None else args.fc_units

    # -set hyper-parameter values (typically found by grid-search) based on chosen experiment & scenario
    if not single_task and compare_code not in ('hyper', 'bir'):
        if args.experiment in ['splitMNIST', 'fmnist']:
            args.xdg_prop = 0.9 if args.scenario == "task" and args.xdg_prop is None else args.xdg_prop
            args.si_c = (10. if args.scenario == 'task' else 0.1) if args.si_c is None else args.si_c
            args.ewc_lambda = (1000000000. if args.scenario == 'task' else 100000.) \
                if args.ewc_lambda is None else args.ewc_lambda
            args.gamma = 1. if args.gamma is None else args.gamma
            if hasattr(args, 'dg_prop'):
                args.dg_prop = 0.8 if args.dg_prop is None else args.dg_prop
        elif args.experiment in ["CIFAR100", "ASL", "fruits360", "chars74k", "gtsrb"]:
            args.xdg_prop = 0.7 if args.scenario == "task" and args.xdg_prop is None else args.xdg_prop
            args.si_c = (100. if args.scenario == 'task' else 1.) if args.si_c is None else args.si_c
            args.ewc_lambda = (1000. if args.scenario == 'task' else 1.) if args.ewc_lambda is None else args.ewc_lambda
            args.gamma = 1 if args.gamma is None else args.gamma
            args.dg_prop = (0. if args.scenario == "task" else 0.7) if args.dg_prop is None else args.dg_prop
            if compare_code == "all":
                args.dg_si_prop = 0.6 if args.dg_si_prop is None else args.dg_si_prop
                args.dg_c = 100000000. if args.dg_c is None else args.dg_c
        elif args.experiment == 'permMNIST':
            args.si_c = 10. if args.si_c is None else args.si_c
            args.ewc_lambda = 1. if args.ewc_lambda is None else args.ewc_lambda
            if hasattr(args, 'o_lambda'):
                args.o_lambda = 1. if args.o_lambda is None else args.o_lambda
            args.gamma = 1. if args.gamma is None else args.gamma
            args.dg_prop = 0.8 if args.dg_prop is None else args.dg_prop
            if compare_code == "all":
                args.dg_si_prop = 0.8 if args.dg_si_prop is None else args.dg_si_prop
                args.dg_c = 1. if args.dg_c is None else args.dg_c

    # -for other unselected options, set default values (not specific to chosen scenario / experiment)
    args.h_dim = args.fc_units if args.h_dim is None else args.h_dim
    if hasattr(args, "lr_gen"):
        args.lr_gen = args.lr if args.lr_gen is None else args.lr_gen
    if hasattr(args, "rl"):
        args.rl = args.depth - 1 if args.rl is None else args.rl
    if generative and not single_task:
        if hasattr(args, 'g_iters'):
            args.g_iters = args.iters if args.g_iters is None else args.g_iters
        if hasattr(args, 'g_depth') and not only_MNIST:
            args.g_depth = args.depth if args.g_depth is None else args.g_depth
        if hasattr(args, 'g_fc_lay'):
            args.g_fc_lay = args.fc_lay if args.g_fc_lay is None else args.g_fc_lay
        if hasattr(args, 'g_fc_uni'):
            args.g_fc_uni = args.fc_units if args.g_fc_uni is None else args.g_fc_uni
        if hasattr(args, "g_h_dim"):
            args.g_h_dim = args.g_fc_uni if args.g_h_dim is None else args.g_h_dim
    if (not single_task) and (compare_code not in ('hyper',)):
        args.xdg_prop = 0. if args.scenario == "task" and args.xdg_prop is None else args.xdg_prop

    # -if [log_per_task] (which is default for comparison-scripts), reset all logs
    if not single_task:
        args.log_per_task = True if (not compare_code == "none") else args.log_per_task
    if checkattr(args, 'log_per_task'):
        args.prec_log = args.iters
        args.loss_log = args.iters
        args.sample_log = args.iters

    return args
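# Example usage of [set_defaults] (illustrative): the parser is assumed to leave
# unchosen options at None, which this function then fills in.
#
#   args = parser.parse_args()
#   args = set_defaults(args)  # e.g. splitMNIST -> tasks=5, iters=2000, lr=0.001, batch=128
#   check_for_errors(args)     # afterwards, validate the resulting combination of options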
def run(args, verbose=False):

    # Create plots- and results-directories if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    # If only the param-stamp is wanted, get it and exit
    if args.get_stamp:
        from param_stamp import get_param_stamp_from_args
        print(get_param_stamp_from_args(args=args))
        exit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")
    # Report whether cuda is used
    if verbose:
        print("CUDA is {}used".format("" if cuda else "NOT(!!) "))

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    #-------------------------------------------------------------------------------------------------#

    #----------------#
    #----- DATA -----#
    #----------------#

    # Prepare data for chosen experiment
    if verbose:
        print("\nPreparing the data...")
    (train_datasets, test_datasets), config, classes_per_task = get_multitask_experiment(
        name=args.experiment, scenario=args.scenario, tasks=args.tasks, data_dir=args.d_dir,
        normalize=utils.checkattr(args, "normalize"), augment=utils.checkattr(args, "augment"),
        verbose=verbose, exception=(args.seed < 10), only_test=(not args.train),
    )

    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- MAIN MODEL -----#
    #----------------------#

    # Define main model (i.e., classifier; if requested, with feedback connections)
    if verbose and (utils.checkattr(args, "pre_convE") or utils.checkattr(args, "pre_convD")) and \
            (hasattr(args, "depth") and args.depth > 0):
        print("\nDefining the model...")
    if utils.checkattr(args, 'feedback'):
        model = define.define_autoencoder(args=args, config=config, device=device)
    else:
        model = define.define_classifier(args=args, config=config, device=device)

    # Initialize / use pre-trained / freeze model-parameters
    # - initialize (pre-trained) parameters
    model = define.init_params(model, args)
    # - freeze weights of conv-layers?
    if utils.checkattr(args, "freeze_convE"):
        for param in model.convE.parameters():
            param.requires_grad = False
    if utils.checkattr(args, 'feedback') and utils.checkattr(args, "freeze_convD"):
        for param in model.convD.parameters():
            param.requires_grad = False

    # Define optimizer (only optimize parameters that "require grad")
    model.optim_list = [
        {'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': args.lr},
    ]
    model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))

    #-------------------------------------------------------------------------------------------------#

    #----------------------------------------------------#
    #----- CL-STRATEGY: REGULARIZATION / ALLOCATION -----#
    #----------------------------------------------------#

    # Elastic Weight Consolidation (EWC)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'ewc'):
        model.ewc_lambda = args.ewc_lambda if args.ewc else 0
        model.fisher_n = args.fisher_n
        model.online = utils.checkattr(args, 'online')
        if model.online:
            model.gamma = args.gamma

    # Synaptic Intelligence (SI)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'si'):
        model.si_c = args.si_c if args.si else 0
        model.epsilon = args.epsilon

    # XdG: create for every task a "mask" for each hidden fully-connected layer
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'xdg') and args.xdg_prop > 0:
        model.define_XdGmask(gating_prop=args.xdg_prop, n_tasks=args.tasks)

    #-------------------------------------------------------------------------------------------------#

    #-------------------------------#
    #----- CL-STRATEGY: REPLAY -----#
    #-------------------------------#

    # Use distillation loss (i.e., soft targets) for replayed data? (and set temperature)
    if isinstance(model, ContinualLearner) and hasattr(args, 'replay') and not args.replay == "none":
        model.replay_targets = "soft" if args.distill else "hard"
        model.KD_temp = args.temp

    # If needed, specify separate model for the generator
    train_gen = (hasattr(args, 'replay') and args.replay == "generative" and not utils.checkattr(args, 'feedback'))
    if train_gen:
        # Specify architecture
        generator = define.define_autoencoder(args, config, device, generator=True)
        # Initialize parameters
        generator = define.init_params(generator, args)
        # -freeze weights of conv-layers?
        if utils.checkattr(args, "freeze_convE"):
            for param in generator.convE.parameters():
                param.requires_grad = False
        if utils.checkattr(args, "freeze_convD"):
            for param in generator.convD.parameters():
                param.requires_grad = False
        # Set optimizer(s)
        generator.optim_list = [
            {'params': filter(lambda p: p.requires_grad, generator.parameters()),
             'lr': args.lr_gen if hasattr(args, 'lr_gen') else args.lr},
        ]
        generator.optimizer = optim.Adam(generator.optim_list, betas=(0.9, 0.999))
    else:
        generator = None

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp (and print on screen)
    if verbose:
        print("\nParameter-stamp...")
    param_stamp = get_param_stamp(
        args, model.name, verbose=verbose,
        replay=(hasattr(args, 'replay') and not args.replay == "none"),
        replay_model_name=generator.name if (
            hasattr(args, 'replay') and args.replay == "generative" and not utils.checkattr(args, 'feedback')
        ) else None,
    )

    # Print some model-characteristics on the screen
    if verbose:
        # -main model
        utils.print_model_info(model, title="MAIN MODEL")
        # -generator
        if generator is not None:
            utils.print_model_info(generator, title="GENERATOR")

    # Define [precision_dict] to keep track of performance during training (for storing & later plotting in pdf)
    precision_dict = evaluate.initiate_precision_dict(args.tasks)

    # Prepare for plotting in visdom
    visdom = None
    if args.visdom:
        env_name = "{exp}{tasks}-{scenario}".format(exp=args.experiment, tasks=args.tasks, scenario=args.scenario)
        replay_statement = "{mode}{fb}{con}{gat}{int}{dis}{b}{u}".format(
            mode=args.replay,
            fb="Rtf" if utils.checkattr(args, "feedback") else "",
            con="Con" if (hasattr(args, "prior") and args.prior == "GMM"
                          and utils.checkattr(args, "per_class")) else "",
            gat="Gat{}".format(args.dg_prop) if (
                utils.checkattr(args, "dg_gates") and hasattr(args, "dg_prop") and args.dg_prop > 0
            ) else "",
            int="Int" if utils.checkattr(args, "hidden") else "",
            dis="Dis" if args.replay == "generative" and args.distill else "",
            b="" if (args.batch_replay is None or args.batch_replay == args.batch)
              else "-br{}".format(args.batch_replay),
            u="" if args.g_fc_uni == args.fc_units else "-gu{}".format(args.g_fc_uni),
        ) if (hasattr(args, "replay") and not args.replay == "none") else "NR"
        graph_name = "{replay}{syn}{ewc}{xdg}".format(
            replay=replay_statement,
            syn="-si{}".format(args.si_c) if utils.checkattr(args, 'si') else "",
            ewc="-ewc{}{}".format(
                args.ewc_lambda,
                "-O{}".format(args.gamma) if utils.checkattr(args, "online") else ""
            ) if utils.checkattr(args, 'ewc') else "",
            xdg="" if (not utils.checkattr(args, 'xdg')) or args.xdg_prop == 0
                else "-XdG{}".format(args.xdg_prop),
        )
        visdom = {'env': env_name, 'graph': graph_name}

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    g_iters = args.g_iters if hasattr(args, 'g_iters') else args.iters

    # Callbacks for reporting on and visualizing loss
    generator_loss_cbs = [
        cb._VAE_loss_cb(log=args.loss_log, visdom=visdom,
                        replay=(hasattr(args, "replay") and not args.replay == "none"),
                        model=model if utils.checkattr(args, 'feedback') else generator, tasks=args.tasks,
                        iters_per_task=args.iters if utils.checkattr(args, 'feedback') else g_iters)
    ] if (train_gen or utils.checkattr(args, 'feedback')) else [None]
    solver_loss_cbs = [
        cb._solver_loss_cb(log=args.loss_log, visdom=visdom, model=model, iters_per_task=args.iters,
                           tasks=args.tasks, replay=(hasattr(args, "replay") and not args.replay == "none"))
    ] if (not utils.checkattr(args, 'feedback')) else [None]

    # Callbacks for evaluating and plotting generated / reconstructed samples
    no_samples = (utils.checkattr(args, "no_samples")
                  or (utils.checkattr(args, "hidden") and hasattr(args, 'depth') and args.depth > 0))
    sample_cbs = [
        cb._sample_cb(log=args.sample_log, visdom=visdom, config=config, test_datasets=test_datasets,
                      sample_size=args.sample_n, iters_per_task=g_iters)
    ] if ((train_gen or utils.checkattr(args, 'feedback')) and not no_samples) else [None]

    # Callbacks for reporting and visualizing accuracy, and visualizing the representation of the main model
    # -visdom (i.e., after each [prec_log] iterations)
    eval_cb = cb._eval_cb(
        log=args.prec_log, test_datasets=test_datasets, visdom=visdom, precision_dict=None,
        iters_per_task=args.iters, test_size=args.prec_n, classes_per_task=classes_per_task,
        scenario=args.scenario,
    )
    # -pdf / reporting: summary plots (i.e., only after each task)
    eval_cb_full = cb._eval_cb(
        log=args.iters, test_datasets=test_datasets, precision_dict=precision_dict,
        iters_per_task=args.iters, classes_per_task=classes_per_task, scenario=args.scenario,
    )
    # -visualize feature space
    latent_space_cb = cb._latent_space_cb(
        log=args.iters, datasets=test_datasets, visdom=visdom, iters_per_task=args.iters, sample_size=400,
    )
    # -collect them in <lists>
    eval_cbs = [eval_cb, eval_cb_full, latent_space_cb]

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- TRAINING -----#
    #--------------------#

    if args.train:
        if verbose:
            print("\nTraining...")
        # Train model
        train_cl(
            model, train_datasets,
            replay_mode=args.replay if hasattr(args, 'replay') else "none",
            scenario=args.scenario, classes_per_task=classes_per_task, iters=args.iters,
            batch_size=args.batch,
            batch_size_replay=args.batch_replay if hasattr(args, 'batch_replay') else None,
            generator=generator, gen_iters=g_iters, gen_loss_cbs=generator_loss_cbs,
            feedback=utils.checkattr(args, 'feedback'), sample_cbs=sample_cbs, eval_cbs=eval_cbs,
            loss_cbs=generator_loss_cbs if utils.checkattr(args, 'feedback') else solver_loss_cbs,
            args=args, reinit=utils.checkattr(args, 'reinit'), only_last=utils.checkattr(args, 'only_last'),
        )
        # Save evaluation metrics measured throughout training
        file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
        utils.save_object(precision_dict, file_name)
        # Save trained model(s), if requested
        if args.save:
            save_name = "mM-{}".format(param_stamp) if (
                not hasattr(args, 'full_stag') or args.full_stag == "none"
            ) else "{}-{}".format(model.name, args.full_stag)
            utils.save_checkpoint(model, args.m_dir, name=save_name, verbose=verbose)
            if generator is not None:
                save_name = "gM-{}".format(param_stamp) if (
                    not hasattr(args, 'full_stag') or args.full_stag == "none"
                ) else "{}-{}".format(generator.name, args.full_stag)
                utils.save_checkpoint(generator, args.m_dir, name=save_name, verbose=verbose)
    else:
        # Load previously trained model(s) (if the goal is to only evaluate a previously trained model)
        if verbose:
            print("\nLoading parameters of the previously trained models...")
        load_name = "mM-{}".format(param_stamp) if (
            not hasattr(args, 'full_ltag') or args.full_ltag == "none"
        ) else "{}-{}".format(model.name, args.full_ltag)
        utils.load_checkpoint(
            model, args.m_dir, name=load_name, verbose=verbose,
            add_si_buffers=(isinstance(model, ContinualLearner) and utils.checkattr(args, 'si')),
        )
        if generator is not None:
            load_name = "gM-{}".format(param_stamp) if (
                not hasattr(args, 'full_ltag') or args.full_ltag == "none"
            ) else "{}-{}".format(generator.name, args.full_ltag)
            utils.load_checkpoint(generator, args.m_dir, name=load_name, verbose=verbose)

    #-------------------------------------------------------------------------------------------------#

    #------------------------------------#
    #----- EVALUATION of CLASSIFIER -----#
    #------------------------------------#

    if verbose:
        print("\n\nEVALUATION RESULTS:")

    # Evaluate precision of final model on full test-set
    precs = [
        evaluate.validate(
            model, test_datasets[i], verbose=False, test_size=None, task=i + 1,
            allowed_classes=list(range(classes_per_task * i, classes_per_task * (i + 1)))
            if args.scenario == "task" else None,
        ) for i in range(args.tasks)
    ]
    average_precs = sum(precs) / args.tasks
    # -print on screen
    if verbose:
        print("\n Accuracy of final model on test-set:")
        for i in range(args.tasks):
            print(" - {} {}: {:.4f}".format(
                "For classes from task" if args.scenario == "class" else "Task", i + 1, precs[i]))
        print('=> Average accuracy over all {} {}: {:.4f}\n'.format(
            args.tasks * classes_per_task if args.scenario == "class" else args.tasks,
            "classes" if args.scenario == "class" else "tasks", average_precs))
    # -write out to text file
    output_file = open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w')
    output_file.write('{}\n'.format(average_precs))
    output_file.close()

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------#
    #----- EVALUATION of GENERATOR -----#
    #-----------------------------------#

    if (utils.checkattr(args, 'feedback') or train_gen) and args.experiment == "CIFAR100" \
            and args.scenario == "class":
        # Dataset and model to be used
        test_set = ConcatDataset(test_datasets)
        gen_model = model if utils.checkattr(args, 'feedback') else generator
        gen_model.eval()

        # Evaluate log-likelihood of generative model on combined test-set (with S=100 importance samples per datapoint)
        ll_per_datapoint = gen_model.estimate_loglikelihood(test_set, S=100, batch_size=args.batch)
        if verbose:
            print('=> Log-likelihood on test set: {:.4f} +/- {:.4f}\n'.format(
                np.mean(ll_per_datapoint), np.sqrt(np.var(ll_per_datapoint))))
        # -write out to text file
        output_file = open("{}/ll-{}.txt".format(args.r_dir, param_stamp), 'w')
        output_file.write('{}\n'.format(np.mean(ll_per_datapoint)))
        output_file.close()

        # Evaluate reconstruction error (averaged over number of input units)
        re_per_datapoint = gen_model.calculate_recon_error(test_set, batch_size=args.batch, average=True)
        if verbose:
            print('=> Reconstruction error (per input unit) on test set: {:.4f} +/- {:.4f}\n'.format(
                np.mean(re_per_datapoint), np.sqrt(np.var(re_per_datapoint))))
        # -write out to text file
        output_file = open("{}/re-{}.txt".format(args.r_dir, param_stamp), 'w')
        output_file.write('{}\n'.format(np.mean(re_per_datapoint)))
        output_file.close()

        # Try loading the classifier (our substitute for InceptionNet) for calculating IS, FID and Recall & Precision
        # -define model
        config['classes'] = 100
        pretrained_classifier = define.define_classifier(args=args, config=config, device=device)
        pretrained_classifier.hidden = False
        # -load pretrained weights
        eval_tag = "" if args.eval_tag == "none" else "-{}".format(args.eval_tag)
        try:
            utils.load_checkpoint(pretrained_classifier, args.m_dir, verbose=True,
                                  name="{}{}".format(pretrained_classifier.name, eval_tag))
            FileFound = True
        except FileNotFoundError:
            if verbose:
                print("= Could not find model {}{} in {}".format(pretrained_classifier.name, eval_tag, args.m_dir))
                print("= IS, FID and Precision & Recall not computed!")
            FileFound = False
        pretrained_classifier.eval()

        # Only continue with computing these measures if the requested classifier network (using --eval-tag) was found
        if FileFound:
            # Preparations
            total_n = len(test_set)
            n_repeats = int(np.ceil(total_n / args.batch))
            # -sample data from generator (for IS, FID and Precision & Recall)
            gen_x = gen_model.sample(size=total_n, only_x=True)
            # -generate predictions for generated data (for IS)
            gen_pred = []
            for i in range(n_repeats):
                x = gen_x[(i * args.batch):int(min(((i + 1) * args.batch), total_n))]
                with torch.no_grad():
                    gen_pred.append(F.softmax(
                        pretrained_classifier.hidden_to_output(x) if args.hidden else pretrained_classifier(x),
                        dim=1).cpu().numpy())
            gen_pred = np.concatenate(gen_pred)
            # -generate embeddings for generated data (for FID and Precision & Recall)
            gen_emb = []
            for i in range(n_repeats):
                with torch.no_grad():
                    gen_emb.append(pretrained_classifier.feature_extractor(
                        gen_x[(i * args.batch):int(min(((i + 1) * args.batch), total_n))],
                        from_hidden=args.hidden).cpu().numpy())
            gen_emb = np.concatenate(gen_emb)
            # -generate embeddings for test data (for FID and Precision & Recall)
            data_loader = utils.get_data_loader(test_set, batch_size=args.batch, cuda=cuda)
            real_emb = []
            for real_x, _ in data_loader:
                with torch.no_grad():
                    real_emb.append(pretrained_classifier.feature_extractor(real_x.to(device)).cpu().numpy())
            real_emb = np.concatenate(real_emb)

            # Calculate "Inception Score" (IS)
            py = gen_pred.mean(axis=0)
            is_per_datapoint = []
            for i in range(len(gen_pred)):
                pyx = gen_pred[i, :]
                is_per_datapoint.append(entropy(pyx, py))
            IS = np.exp(np.mean(is_per_datapoint))
            if verbose:
                print('=> Inception Score = {:.4f}\n'.format(IS))
            # -write out to text file
            output_file = open("{}/is{}-{}.txt".format(args.r_dir, eval_tag, param_stamp), 'w')
            output_file.write('{}\n'.format(IS))
            output_file.close()

            # Calculate "Frechet Inception Distance" (FID)
            FID = fid.calculate_fid_from_embedding(gen_emb, real_emb)
            if verbose:
                print('=> Frechet Inception Distance = {:.4f}\n'.format(FID))
            # -write out to text file
            output_file = open("{}/fid{}-{}.txt".format(args.r_dir, eval_tag, param_stamp), 'w')
            output_file.write('{}\n'.format(FID))
            output_file.close()

            # Calculate "Precision & Recall"-curves
            precision, recall = pr.compute_prd_from_embedding(gen_emb, real_emb)
            # -write out to text files
            file_name = "{}/precision{}-{}.txt".format(args.r_dir, eval_tag, param_stamp)
            with open(file_name, 'w') as f:
                for item in precision:
                    f.write("%s\n" % item)
            file_name = "{}/recall{}-{}.txt".format(args.r_dir, eval_tag, param_stamp)
            with open(file_name, 'w') as f:
                for item in recall:
                    f.write("%s\n" % item)

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # If requested, generate pdf
    if args.pdf:
        # -open pdf
        plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp)
        pp = evaluate.visual.plt.open_pdf(plot_name)
        # -show metrics reflecting progression during training
        if args.train and (not utils.checkattr(args, 'only_last')):
            # -create list to store all figures to be plotted
            figure_list = []
            # -generate figures (and store them in [figure_list])
            figure = evaluate.visual.plt.plot_lines(
                precision_dict["all_tasks"],
                x_axes=[i * classes_per_task for i in precision_dict["x_task"]]
                if args.scenario == "class" else precision_dict["x_task"],
                line_names=['{} {}'.format(
                    "episode / task" if args.scenario == "class" else "task", i + 1
                ) for i in range(args.tasks)],
                xlabel="# of {}s so far".format("classe" if args.scenario == "class" else "task"),
                ylabel="Test accuracy")
            figure_list.append(figure)
            figure = evaluate.visual.plt.plot_lines(
                [precision_dict["average"]],
                x_axes=[i * classes_per_task for i in precision_dict["x_task"]]
                if args.scenario == "class" else precision_dict["x_task"],
                line_names=['Average based on all {}s so far'.format(
                    ("digit" if args.experiment == "splitMNIST" else "classe")
                    if args.scenario == "class" else "task")],
                xlabel="# of {}s so far".format("classe" if args.scenario == "class" else "task"),
                ylabel="Test accuracy")
            figure_list.append(figure)
            # -add figures to pdf
            for figure in figure_list:
                pp.savefig(figure)
        gen_eval = (utils.checkattr(args, 'feedback') or train_gen)
        # -show samples (from main model or separate generator)
        if gen_eval and not no_samples:
            evaluate.show_samples(model if utils.checkattr(args, 'feedback') else generator, config,
                                  size=args.sample_n, pdf=pp, title="Generated samples (by final model)")
        # -plot "Precision & Recall"-curve
        if gen_eval and args.experiment == "CIFAR100" and args.scenario == "class" and FileFound:
            figure = evaluate.visual.plt.plot_pr_curves([[precision]], [[recall]])
            pp.savefig(figure)
        # -close pdf
        pp.close()
        # -print name of generated plot on screen
        if verbose:
            print("\nGenerated plot: {}\n".format(plot_name))
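# Example driver for [run] (illustrative sketch): this function is assumed to be
# called from a top-level script along the following lines; [parser] is the
# project's argparse parser.
#
#   if __name__ == '__main__':
#       args = parser.parse_args()
#       set_defaults(args)
#       check_for_errors(args)
#       run(args, verbose=True)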
def train_cl(model, train_datasets, replay_mode="none", scenario="task", rnt=None, classes_per_task=None, iters=2000, batch_size=32, batch_size_replay=None, loss_cbs=list(), eval_cbs=list(), sample_cbs=list(), generator=None, gen_iters=0, gen_loss_cbs=list(), feedback=False, reinit=False, args=None, only_last=False, sample_method='random', curated_multiplier=4): '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode]. [model] <nn.Module> main model to optimize across all tasks [train_datasets] <list> with for each task the training <DataSet> [replay_mode] <str>, choice from "generative", "current", "offline" and "none" [scenario] <str>, choice from "task", "domain", "class" and "all" [classes_per_task] <int>, # classes per task; only 1st task has [classes_per_task]*[first_task_class_boost] classes [rnt] <float>, indicating relative importance of new task (if None, relative to # old tasks) [iters] <int>, # optimization-steps (=batches) per task; 1st task has [first_task_iter_boost] steps more [batch_size_replay] <int>, number of samples to replay per batch [generator] None or <nn.Module>, if a seperate generative model should be trained (for [gen_iters] per task) [feedback] <bool>, if True and [replay_mode]="generative", the main model is used for generating replay [only_last] <bool>, only train on final task / episode [*_cbs] <list> of call-back functions to evaluate training-progress [sample_method] <str> indicating the sample method, choices: 'random', 'uniform', 'curated', 'softmax', 'interfered', 'misclassified' [curated_multiplier]<int> choose curated samples out of size curated_multiplier * mutiply batch_size_replay ''' # Should convolutional layers be frozen? freeze_convE = (utils.checkattr(args, "freeze_convE") and hasattr(args, "depth") and args.depth>0) # Use cuda? device = model._device() cuda = model._is_on_cuda() # Set default-values if not specified batch_size_replay = batch_size if batch_size_replay is None else batch_size_replay # Initiate indicators for replay (no replay for 1st task) Generative = Current = Offline_TaskIL = False previous_model = None # Register starting param-values (needed for "intelligent synapses"). if isinstance(model, ContinualLearner) and model.si_c>0: for n, p in model.named_parameters(): if p.requires_grad: n = n.replace('.', '__') model.register_buffer('{}_SI_prev_task'.format(n), p.detach().clone()) # Loop over all tasks. 
for task, train_dataset in enumerate(train_datasets, 1): # If offline replay-setting, create large database of all tasks so far if replay_mode=="offline" and (not scenario=="task"): train_dataset = ConcatDataset(train_datasets[:task]) # -but if "offline"+"task": all tasks so far should be visited separately (i.e., separate data-loader per task) if replay_mode=="offline" and scenario=="task": Offline_TaskIL = True data_loader = [None]*task # Initialize # iters left on data-loader(s) iters_left = 1 if (not Offline_TaskIL) else [1]*task # Prepare <dicts> to store running importance estimates and parameter-values before update if isinstance(model, ContinualLearner) and model.si_c>0: W = {} p_old = {} for n, p in model.named_parameters(): if p.requires_grad: n = n.replace('.', '__') W[n] = p.data.clone().zero_() p_old[n] = p.data.clone() # Find [active_classes] (=classes in current task) active_classes = None #-> for "domain"- or "all"-scenarios, always all classes are active if scenario=="task": # -for "task"-scenario, create <list> with for all tasks so far a <list> with the active classes active_classes = [list(range(classes_per_task*i, classes_per_task*(i+1))) for i in range(task)] elif scenario=="class": # -for "class"-scenario, create one <list> with active classes of all tasks so far active_classes = list(range(classes_per_task*task)) # Reinitialize the model's parameters (if requested) if reinit: from define_models import init_params init_params(model, args) if generator is not None: init_params(generator, args) # Define a tqdm progress bar(s) iters_main = iters progress = tqdm.tqdm(range(1, iters_main+1)) if generator is not None: iters_gen = gen_iters progress_gen = tqdm.tqdm(range(1, iters_gen+1)) # Loop over all iterations iters_to_use = (iters_main if (generator is None) else max(iters_main, iters_gen)) # -if only the final task should be trained on: if only_last and not task==len(train_datasets): iters_to_use = 0 # This helps w/ speeding up curated_classVariety mask = None if (sample_method=="curated_classVariety" and (task-1)>0): sampleAmt = batch_size_replay * curated_multiplier classNum = classes_per_task*(task-1) indexList = [[idx for idx in range(sampleAmt) if (idx%classNum) == (rowIdx%classNum)] for rowIdx in range(sampleAmt)] mask = [] for rowIdxList in indexList: curRow = [0] * sampleAmt for idx in rowIdxList: curRow[idx] = 1 mask.append(curRow) mask = torch.tensor(mask, dtype=torch.float).to(device) for batch_index in range(1, iters_to_use+1): # Update # iters left on current data-loader(s) and, if needed, create new one(s) if not Offline_TaskIL: iters_left -= 1 if iters_left==0: data_loader = iter(utils.get_data_loader(train_dataset, batch_size, cuda=cuda, drop_last=True)) iters_left = len(data_loader) else: # -with "offline replay" in Task-IL scenario, there is a separate data-loader for each task batch_size_to_use = int(np.ceil(batch_size/task)) for task_id in range(task): iters_left[task_id] -= 1 if iters_left[task_id]==0: data_loader[task_id] = iter(utils.get_data_loader( train_datasets[task_id], batch_size_to_use, cuda=cuda, drop_last=True )) iters_left[task_id] = len(data_loader[task_id]) #-----------------Collect data------------------# #####-----CURRENT BATCH-----##### if not Offline_TaskIL: x, y = next(data_loader) #--> sample training data of current task y = y-classes_per_task*(task-1) if scenario=="task" else y #--> ITL: adjust y-targets to 'active range' x, y = x.to(device), y.to(device) #--> transfer them to correct device #y = y.expand(1) if 
len(y.size())==1 else y #--> hack for if batch-size is 1 else: x = y = task_used = None #--> all tasks are "treated as replay" # -sample training data for all tasks so far, move to correct device and store in lists x_, y_ = list(), list() for task_id in range(task): x_temp, y_temp = next(data_loader[task_id]) x_.append(x_temp.to(device)) y_temp = y_temp - (classes_per_task * task_id) #--> adjust y-targets to 'active range' if batch_size_to_use == 1: y_temp = torch.tensor([y_temp]) #--> correct dimensions if batch-size is 1 y_.append(y_temp.to(device)) #####-----REPLAYED BATCH-----##### if not Offline_TaskIL and not Generative and not Current: x_ = y_ = scores_ = task_used = None #-> if no replay #--------------------------------------------INPUTS----------------------------------------------------# ##-->> Current Replay <<--## if Current: x_ = x[:batch_size_replay] #--> use current task inputs task_used = None ##-->> Generative Replay <<--## if Generative: #---> Only with generative replay, the resulting [x_] will be at the "hidden"-level conditional_gen = True if ( (previous_generator.per_class and previous_generator.prior=="GMM") or utils.checkattr(previous_generator, 'dg_gates') ) else False # Sample [x_] if conditional_gen and scenario=="task": # -if a conditional generator is used with task-IL scenario, generate data per previous task x_ = list() task_used = list() for task_id in range(task-1): allowed_classes = list(range(classes_per_task*task_id, classes_per_task*(task_id+1))) batch_size_replay_to_use = int(np.ceil(batch_size_replay / (task-1))) x_temp_ = previous_generator.sample(batch_size_replay_to_use, allowed_classes=allowed_classes, only_x=False) x_.append(x_temp_[0]) task_used.append(x_temp_[2]) else: # -which classes are allowed to be generated? (relevant if conditional generator / decoder-gates) allowed_classes = None if scenario=="domain" else list(range(classes_per_task*(task-1))) # -which tasks/domains are allowed to be generated? 
(only relevant if "Domain-IL" with task-gates) allowed_domains = list(range(task-1)) # -generate inputs representative of previous tasks # --- SAMPLE METHOD CHOICES: softmax, random, uniform, curated --- # --- Softmax sampling: use previous model to score images from this new task, generate those classes if sample_method == 'softmax': with torch.no_grad(): curTaskID = task - 2 newScores_og = previous_model.classify(previous_model.input_to_hidden(x), not_hidden=False if Generative else True) newScores = newScores_og[:, :(classes_per_task * (curTaskID + 1))] softmax = torch.nn.Softmax(dim=1) newHardScores = nn.Softmax(dim=1)(newScores) avgError = torch.mean(newHardScores, dim=0) sampleProbs = torch.zeros(newScores_og.shape[1]) sampleProbs[:(classes_per_task * (curTaskID + 1))] = avgError[ :(classes_per_task * (curTaskID + 1))] x_, y_used, task_used = previous_generator.sample( batch_size_replay, allowed_classes=allowed_classes, allowed_domains=allowed_domains, only_x=False, class_probs=sampleProbs,uniform_sampling=False) # --- Uniformly random sampling (baseline) --- elif sample_method == 'random': x_, y_used, task_used = previous_generator.sample( batch_size_replay, allowed_classes=allowed_classes, allowed_domains=allowed_domains, only_x=False, class_probs=None, uniform_sampling=False) # --- Uniform sampling: balanced numbers of samples from each class --- elif sample_method == 'uniform': x_, y_used, task_used = previous_generator.sample( batch_size_replay, allowed_classes=allowed_classes, allowed_domains=allowed_domains, only_x=False, class_probs=None, uniform_sampling=True) # --- Uniform sample curation: pick the best samples to show (by some metric), balance uniformly --- else: if (sample_method == "curated_variety"): # Generate x times as many samples as we need to then pick the best of x_, y_used, task_used, varietyVector = previous_generator.sample( batch_size_replay * curated_multiplier, allowed_classes=allowed_classes, allowed_domains=allowed_domains, only_x=False, class_probs=None, uniform_sampling=False, varietyVector=True) # CURATED USING CLASS VARIETY (i.e., generating batch_size_reply*curated_multipler / len(allowed_classes) samples # per class, where each sample is the "most different" sample based off our variety calculation elif(sample_method == "curated_classVariety"): x_, y_used, task_used, varietyVector = previous_generator.sample( batch_size_replay * curated_multiplier, allowed_classes=allowed_classes, allowed_domains=allowed_domains, only_x=False, class_probs=None, uniform_sampling=True, varietyVector=True, classVariety=True, classVarietyMask=mask) elif(sample_method == "curated_softmax"): with torch.no_grad(): curTaskID = task - 2 newScores_og = previous_model.classify(previous_model.input_to_hidden(x), not_hidden=False if Generative else True) newScores = newScores_og[:, :(classes_per_task * (curTaskID + 1))] softmax = torch.nn.Softmax(dim=1) newHardScores = nn.Softmax(dim=1)(newScores) avgError = torch.mean(newHardScores, dim=0) sampleProbs = torch.zeros(newScores_og.shape[1]) sampleProbs[:(classes_per_task * (curTaskID + 1))] = avgError[ :(classes_per_task * (curTaskID + 1))] # Generate x times as many samples as we need to then pick the best of x_, y_used, task_used = previous_generator.sample( batch_size_replay * curated_multiplier, allowed_classes=allowed_classes, allowed_domains=allowed_domains, only_x=False, class_probs=sampleProbs, uniform_sampling=False) else: # Generate x times as many samples as we need to then pick the best of x_, y_used, task_used = 
previous_generator.sample( batch_size_replay * curated_multiplier, allowed_classes=allowed_classes, allowed_domains=allowed_domains, only_x=False, class_probs=None, uniform_sampling=False) # --- Measure the performance of each of these samples on the current model --- # Use the previous model to score the generated images (code taken from Trevor's softmax above) with torch.no_grad(): curTaskID = task - 2 newScores_og = model.classify(x_, not_hidden=False if Generative else True) newScores = newScores_og[:, :(classes_per_task * (curTaskID + 1))] # Logits that don't sum to 1 newHardScores = nn.Softmax(dim=1)(newScores) # Makes the scores sum to 1 (probabilities) cross_entropy = nn.CrossEntropyLoss(reduction='none') y_used = torch.tensor(y_used, dtype=torch.long).to(device) cross_entropy_loss = cross_entropy(newHardScores, y_used) # --- Copy the model and perform an update on just the new incoming data (no replayed data) --- # This will lead to catastrophic forgetting, as it has no replays to prevent this from happening model_tmp = copy.deepcopy(model) # NOTE: Can train multiple batches if needed, but it would be on the same data, so any changes will just be exacerbated _ = model_tmp.train_a_batch(x, y=y, x_=None, y_=None, scores_=None, tasks_=task_used, active_classes=active_classes, task=task, rnt=( 1. if task==1 else 1./task ) if rnt is None else rnt, freeze_convE=freeze_convE, replay_not_hidden=False if Generative else True) # --- Measure the performance of each of the generated samples on this updated model --- # This can tell us how much the model 'forgets' each of these samples, we will replay the worst ones with torch.no_grad(): curTaskID = task - 2 newScores_og = model_tmp.classify(x_, not_hidden=False if Generative else True) newScores = newScores_og[:, :(classes_per_task * (curTaskID + 2))] # Logits that don't sum to 1 newHardScores2 = nn.Softmax(dim=1)(newScores) # Makes the scores sum to 1 (probabilities) # --- Measure the difference in cross entropy loss for predictions before and after --- if sample_method == 'curated' or sample_method == "curated_softmax": cross_entropy = nn.CrossEntropyLoss(reduction='none') # Per-example cross entropy (not avg) cross_entropy_loss2 = cross_entropy(newHardScores2, y_used) # Amount that the loss changes between the model updating diff = cross_entropy_loss2 - cross_entropy_loss metric = diff # TREVOR'S NEW METHOD - This tries to take into account the variety of the samples elif sample_method == "curated_variety" or sample_method == "curated_classVariety": cross_entropy = nn.CrossEntropyLoss(reduction='none') # Per-example cross entropy (not avg) cross_entropy_loss2 = cross_entropy(newHardScores2, y_used) # Amount that the loss changes between the model updating diff = cross_entropy_loss2 - cross_entropy_loss # Softmaxing diff and the variety vector varietyWeight = 0.5 diff_SM = nn.Softmax(dim=0)(diff) variety_SM = nn.Softmax(dim=0)(varietyVector) metric = ((1-varietyWeight) * diff_SM) + (varietyWeight * variety_SM) # Multiply the misclassification error (cross entropy) by the amount that this changes between the model updating # metric = cross_entropy_loss2 * diff # --- Measure KL Divergence between predictions before and predictions afterwards --- # Maximally Interfered Retrieval uses a linear combination of KL, entropy, and 'variance' # This ensures the samples are not too close together, but we do not currently measure that elif sample_method == 'interfered': KLDiv = nn.KLDivLoss(reduction='none')(newHardScores, newHardScores2) 
print(KLDiv.shape) print(cross_entropy_loss.shape) # Test code to compute KL divergence for every example individually, above code is (512, 2) rather than (512, 1) for some reason #KLDiv = [ nn.KLDivLoss()(newHardScores[i], newHardScores2[i]) for i in range(len(newHardScores))] # Note from Trevor: When I tried to run this method, there was a size mismatch. # KLDiv.shape = (1024, 10), whereas cross_entropy_loss.shape = (1024,), and it said # that they needed to be equal on dim 1. Sooo: I tried to just transpose the KLDiv matrix, # and it worked. Honestly, I'm too tired to try and decipher what I did mathematically lol metric = torch.tensor(KLDiv.T) - 0 * cross_entropy_loss # --- New idea: use the examples which the new model misclassifies the most as one of the new classes # This the opposite approach to softmax, where softmax takes the current model and calculates # Which classes does it confuse the new data for the most, this trains on the new data and then # Tries to find generated examples which it confuses for the new data classes the most elif sample_method == 'misclassified' or sample_method == 'uniform_large' or sample_method == 'random_large': metric = newHardScores2[:, -1] + newHardScores2[:, -1] # --- Sort based on some metric, then divide up by classes (afterwards) --- sorted, indices = torch.sort(metric, descending=True) # Descending order, pick first 100 # Shuffle indices around to test choosing from this larger pool of generated samples randomly if sample_method == 'uniform_large' or sample_method == 'random_large': indices2 = indices.cpu().numpy() np.random.shuffle(indices2) indices = torch.from_numpy(indices2).to(device) if sample_method != 'random_large' and sample_method != 'curated_softmax': # --- Calculate how many examples for each class should be generated to divide up uniformly --- # Uniform dist will be [0, 1, 2, 3, 0, 1, 2] for allowed classes=4 and batch_size_replay=7 uniform_dist = torch.arange(batch_size_replay) % len(allowed_classes) counts_each_class = torch.unique(uniform_dist, return_counts=True)[1] # --- Optional: Calculate unbalanced indices to replay, results in poor performance --- # If we added a variation term to ensure samples are different from each other, this could # be a simpler way to do things, but variance would be pretty complicated to calculate #indices_to_replay = indices[:batch_size_replay] # --- Select the top k_i indices for each class i, where k_i is the number of examples for that class --- # Top x most affected of the generated samples for each class (ensures it is balanced, slightly more computation than unbalanced) indices_to_replay = torch.cat(( [ indices[y_used[indices]==i][:counts_each_class[i]] for i in range(len(allowed_classes)) ] )) x_ = x_[indices_to_replay] else: # Uniformly randomly choose from the 400 samples generated x_ = x_[indices] #--------------------------------------------OUTPUTS----------------------------------------------------# if Generative or Current: # Get target scores & possibly labels (i.e., [scores_] / [y_]) -- use previous model, with no_grad() if scenario in ("domain", "class") and previous_model.mask_dict is None: # -if replay does not need to be evaluated for each task (ie, not Task-IL and no task-specific mask) with torch.no_grad(): all_scores_ = previous_model.classify(x_, not_hidden=False if Generative else True) scores_ = all_scores_[:, :(classes_per_task*(task-1))] if ( scenario=="class" ) else all_scores_ # -> when scenario=="class", zero probs will be added in [loss_fn_kd]-function # -also get 
the 'hard target' _, y_ = torch.max(scores_, dim=1) else: # -[x_] needs to be evaluated according to each previous task, so make list with entry per task scores_ = list() y_ = list() # -if no task-mask and no conditional generator, all scores can be calculated in one go if previous_model.mask_dict is None and not type(x_)==list: with torch.no_grad(): all_scores_ = previous_model.classify(x_, not_hidden=False if Generative else True) for task_id in range(task-1): # -if there is a task-mask (i.e., XdG is used), obtain predicted scores for each task separately if previous_model.mask_dict is not None: previous_model.apply_XdGmask(task=task_id+1) if previous_model.mask_dict is not None or type(x_)==list: with torch.no_grad(): all_scores_ = previous_model.classify(x_[task_id] if type(x_)==list else x_, not_hidden=False if Generative else True) if scenario=="domain": # NOTE: if scenario=domain with task-mask, it's of course actually the Task-IL scenario! # this can be used as trick to run the Task-IL scenario with singlehead output layer temp_scores_ = all_scores_ else: temp_scores_ = all_scores_[:, (classes_per_task*task_id):(classes_per_task*(task_id+1))] scores_.append(temp_scores_) # - also get hard target _, temp_y_ = torch.max(temp_scores_, dim=1) y_.append(temp_y_) # -only keep predicted y_/scores_ if required (as otherwise unnecessary computations will be done) y_ = y_ if (model.replay_targets=="hard") else None scores_ = scores_ if (model.replay_targets=="soft") else None #-----------------Train model(s)------------------# #---> Train MAIN MODEL if batch_index <= iters_main: # Train the main model with this batch loss_dict = model.train_a_batch(x, y=y, x_=x_, y_=y_, scores_=scores_, tasks_=task_used, active_classes=active_classes, task=task, rnt=( 1. if task==1 else 1./task ) if rnt is None else rnt, freeze_convE=freeze_convE, replay_not_hidden=False if Generative else True) # UNIFORM SAMPLE CURATION: loss_dict has a "predL_r" key that contains the individual prediction # losses # Update running parameter importance estimates in W if isinstance(model, ContinualLearner) and model.si_c>0: for n, p in model.convE.named_parameters(): if p.requires_grad: n = "convE."+n n = n.replace('.', '__') if p.grad is not None: W[n].add_(-p.grad*(p.detach()-p_old[n])) p_old[n] = p.detach().clone() for n, p in model.fcE.named_parameters(): if p.requires_grad: n = "fcE."+n n = n.replace('.', '__') if p.grad is not None: W[n].add_(-p.grad * (p.detach() - p_old[n])) p_old[n] = p.detach().clone() for n, p in model.classifier.named_parameters(): if p.requires_grad: n = "classifier."+n n = n.replace('.', '__') if p.grad is not None: W[n].add_(-p.grad * (p.detach() - p_old[n])) p_old[n] = p.detach().clone() # Fire callbacks (for visualization of training-progress / evaluating performance after each task) for loss_cb in loss_cbs: if loss_cb is not None: loss_cb(progress, batch_index, loss_dict, task=task) for eval_cb in eval_cbs: if eval_cb is not None: eval_cb(model, batch_index, task=task) if model.label=="VAE": for sample_cb in sample_cbs: if sample_cb is not None: sample_cb(model, batch_index, task=task, allowed_classes=None if ( scenario=="domain" ) else list(range(classes_per_task*task))) #---> Train GENERATOR if generator is not None and batch_index <= iters_gen: loss_dict = generator.train_a_batch(x, y=y, x_=x_, y_=y_, scores_=scores_, tasks_=task_used, active_classes=active_classes, rnt=( 1. 
if task==1 else 1./task ) if rnt is None else rnt, task=task, freeze_convE=freeze_convE, replay_not_hidden=False if Generative else True) # Fire callbacks on each iteration for loss_cb in gen_loss_cbs: if loss_cb is not None: loss_cb(progress_gen, batch_index, loss_dict, task=task) for sample_cb in sample_cbs: if sample_cb is not None: sample_cb(generator, batch_index, task=task, allowed_classes=None if ( scenario=="domain" ) else list(range(classes_per_task*task))) # Close progress-bar(s) progress.close() if generator is not None: progress_gen.close() ##----------> UPON FINISHING EACH TASK... # EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty if isinstance(model, ContinualLearner) and model.ewc_lambda>0: # -find allowed classes allowed_classes = list( range(classes_per_task*(task-1), classes_per_task*task) ) if scenario=="task" else (list(range(classes_per_task*task)) if scenario=="class" else None) # -if needed, apply correct task-specific mask if model.mask_dict is not None: model.apply_XdGmask(task=task) # -estimate FI-matrix model.estimate_fisher(train_dataset, allowed_classes=allowed_classes) # SI: calculate and update the normalized path integral if isinstance(model, ContinualLearner) and model.si_c>0: model.update_omega(W, model.epsilon) # REPLAY: update source for replay previous_model = copy.deepcopy(model).eval() if replay_mode=="generative": Generative = True previous_generator = previous_model if feedback else copy.deepcopy(generator).eval() elif replay_mode=='current': Current = True
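# NOTE: a minimal, self-contained sketch (not part of this code-base) of how a per-example KL divergence can be
# computed with matching shapes, addressing the size mismatch flagged in the sample-curation comments above.
# The names [logits_a] and [logits_b] are illustrative assumptions.
import torch.nn.functional as F

def per_example_kl(logits_a, logits_b):
    '''Return KL( softmax(logits_b) || softmax(logits_a) ) separately for each example.

    [logits_a] and [logits_b] are tensors of shape (batch, classes); the result has shape (batch,),
    so it broadcasts cleanly against a per-example cross-entropy of shape (batch,).'''
    log_p_a = F.log_softmax(logits_a, dim=1)  # log-probabilities under the first model
    p_b = F.softmax(logits_b, dim=1)          # probabilities under the second model
    # reduction='none' keeps the per-class terms; summing over the class-dim gives one value per example
    return F.kl_div(log_p_a, p_b, reduction='none').sum(dim=1)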
def get_param_stamp(args, model_name, verbose=True, replay=False, replay_model_name=None): '''Based on the input-arguments, produce a "parameter-stamp".''' # -for task multi_n_stamp = "{n}-{set}".format( n=args.tasks, set=args.scenario) if hasattr(args, "tasks") else "" task_stamp = "{exp}{multi_n}".format(exp=args.experiment, multi_n=multi_n_stamp) if verbose: print(" --> task: " + task_stamp) # -for model model_stamp = model_name if verbose: print(" --> model: " + model_stamp) # -for hyper-parameters hyper_stamp = "{i_e}{num}{epo}-lr{lr}{lrg}{lrt}-b{bsz}-{optim}".format( i_e="e" if args.iters is None else "i", num=args.epochs if args.iters is None else args.iters, lr=args.lr, epo='' if args.epochs == 1 else f'-e{args.epochs}', lrg=("" if args.lr == args.lr_gen else "-lrG{}".format(args.lr_gen)) if (hasattr(args, "lr_gen") and args.replay == 'generative') else "", lrt=("" if args.teacher_lr == args.lr else "-lrT{}".format( args.teacher_lr)) if (utils.checkattr(args, 'use_teacher') and args.replay == 'online') else "", bsz=args.batch, optim=args.optimizer, ) if verbose: print(" --> hyper-params: " + hyper_stamp) # -for EWC / SI if hasattr(args, 'ewc') and ((args.ewc_lambda > 0 and args.ewc) or (args.si_c > 0 and args.si)): ewc_stamp = "EWC{l}-{fi}{o}".format( l=args.ewc_lambda, fi="{}{}".format("N" if args.fisher_n is None else args.fisher_n, "E" if args.emp_fi else ""), o="-O{}".format(args.gamma) if args.online else "", ) if (args.ewc_lambda > 0 and args.ewc) else "" si_stamp = "SI{c}-{eps}".format(c=args.si_c, eps=args.epsilon) if ( args.si_c > 0 and args.si) else "" both = "--" if (args.ewc_lambda > 0 and args.ewc) and (args.si_c > 0 and args.si) else "" if verbose and args.ewc_lambda > 0 and args.ewc: print(" --> EWC: " + ewc_stamp) if verbose and args.si_c > 0 and args.si: print(" --> SI: " + si_stamp) ewc_stamp = "--{}{}{}".format( ewc_stamp, both, si_stamp) if (hasattr(args, 'ewc') and ((args.ewc_lambda > 0 and args.ewc) or (args.si_c > 0 and args.si))) else "" # -for XdG xdg_stamp = "" if (hasattr(args, 'xdg') and args.xdg) and (hasattr(args, "gating_prop") and args.gating_prop > 0): xdg_stamp = "--XdG{}".format(args.gating_prop) if verbose: print(" --> XdG: " + "gating = {}".format(args.gating_prop)) # -for replay if replay: replay_stamp = "{rep}{KD}{agem}{model}{gi}".format( rep=args.replay, KD="-KD{}".format(args.temp) if args.distill else "", agem="-aGEM" if args.agem else "", model="" if (replay_model_name is None) else "-{}".format(replay_model_name), gi="-gi{}".format(args.gen_iters) if (hasattr(args, "gen_iters") and (replay_model_name is not None) and (not args.iters == args.gen_iters)) else "") if args.replay == 'online': distill = '' if hasattr(args, 'use_teacher') and args.use_teacher: distill = '-distill-{}{}'.format( args.distill_type, '-A' if utils.checkattr(args, 'use_augment') else '') teacher_stamp = '{}{}{}{}{}'.format( '' if args.teacher_epochs == 100 else f'-e{args.teacher_epochs}', '' if args.teacher_split == 0.8 else f'-s{args.teacher_split}', '' if args.teacher_loss == 'CE' else f'-{args.teacher_loss}', '' if args.teacher_opt == 'Adam' else f'-{args.teacher_opt}', '' if not args.use_scheduler else '-useSche') embeds = '-embeds' if args.use_embeddings else '' selection = '' if args.triplet_selection == 'HP-HN-1' else f'({args.triplet_selection})' replay_stamp = '{}-b{}{}{}{}{}{}{}'.format( replay_stamp, args.budget, '' if not args.multi_negative else 'MN', f'{distill}', f'{teacher_stamp}', f'{embeds}', f'{selection}', '-addEx' if args.add_exemplars else '') 
if verbose: print(" --> replay: " + replay_stamp) replay_stamp = "--{}".format(replay_stamp) if replay else "" # -for exemplars / iCaRL exemplar_stamp = "" if args.replay == "exemplars" or (args.add_exemplars or args.use_exemplars) or utils.checkattr( args, 'icarl'): exemplar_opts = "b{}{}{}{}".format( args.budget, "H" if args.herding else "", "N" if args.norm_exemplars else "", "-online" if args.mem_online else "") use = "{}{}{}".format("addEx-" if args.add_exemplars else "", "useEx-" if args.use_exemplars else "", "OTR-" if args.otr_exemplars else "") exemplar_stamp = "--{}{}".format(use, exemplar_opts) if verbose: print(" --> exemplars: " + "{}{}".format(use, exemplar_opts)) # -for binary classification loss binLoss_stamp = "" # if hasattr(args, 'bce') and args.bce: # if not ((hasattr(args, 'otr') and args.otr) or (hasattr(args, 'otr_distill') and args.otr_distill)): # binLoss_stamp = '--BCE_dist' if (args.bce_distill and args.scenario == "class") else '--BCE' # --> combine param_stamp = "{}--{}--{}{}{}{}{}{}{}".format( task_stamp, model_stamp, hyper_stamp, ewc_stamp, xdg_stamp, replay_stamp, exemplar_stamp, binLoss_stamp, "-s{}".format(args.seed) if not args.seed == 0 else "", ) ## Print param-stamp on screen and return if verbose: print(param_stamp) return param_stamp
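# NOTE: the parameter-stamp returned above serves as a deterministic run-identifier; elsewhere in this code-base
# it is used to name model-checkpoints and results-files. Illustrative usage only (assuming an [args] namespace
# as in the run()-functions):
#   param_stamp = get_param_stamp(args, model.name, verbose=False, replay=(not args.replay == "none"))
#   utils.save_checkpoint(model, args.m_dir, name="mM-{}".format(param_stamp))
#   metrics_file = "{}/dict-{}".format(args.r_dir, param_stamp)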
def define_autoencoder(args, config, device, generator=False, convE=None): # -import required model from models.vae import AutoEncoder # -create model if (hasattr(args, "depth") and args.depth > 0): model = AutoEncoder( image_size=config['size'], image_channels=config['channels'], classes=config['classes'], # -conv-layers conv_type=args.conv_type, depth=args.g_depth if generator and hasattr(args, 'g_depth') else args.depth, start_channels=args.channels, reducing_layers=args.rl, conv_bn=(args.conv_bn == "yes"), conv_nl=args.conv_nl, num_blocks=args.n_blocks, convE=convE, global_pooling=False if generator else checkattr(args, 'gp'), # -fc-layers fc_layers=args.g_fc_lay if generator and hasattr(args, 'g_fc_lay') else args.fc_lay, fc_units=args.g_fc_uni if generator and hasattr(args, 'g_fc_uni') else args.fc_units, h_dim=args.g_h_dim if generator and hasattr(args, 'g_h_dim') else args.h_dim, fc_drop=0 if generator else args.fc_drop, fc_bn=(args.fc_bn == "yes"), fc_nl=args.fc_nl, excit_buffer=True, # -prior prior=args.prior if hasattr(args, "prior") else "standard", n_modes=args.n_modes if hasattr(args, "prior") else 1, per_class=args.per_class if hasattr(args, "prior") else False, z_dim=args.g_z_dim if generator and hasattr(args, 'g_z_dim') else args.z_dim, # -decoder hidden=checkattr(args, 'hidden'), recon_loss=args.recon_loss, network_output="none" if checkattr(args, "normalize") else "sigmoid", deconv_type=args.deconv_type if hasattr(args, "deconv_type") else "standard", dg_gates=utils.checkattr(args, 'dg_gates'), dg_type=args.dg_type if hasattr(args, 'dg_type') else "task", dg_prop=args.dg_prop if hasattr(args, 'dg_prop') else 0., tasks=args.tasks if hasattr(args, 'tasks') else None, scenario=args.scenario if hasattr(args, 'scenario') else None, device=device, # -classifier classifier=False if generator else True, classify_opt=args.classify if hasattr(args, "classify") else "beforeZ", # -training-specific components lamda_rcl=1. if not hasattr(args, 'rcl') else args.rcl, lamda_vl=1. if not hasattr(args, 'vl') else args.vl, lamda_pl=(0. if generator else 1.) if not hasattr(args, 'pl') else args.pl, ).to(device) else: model = AutoEncoder( image_size=config['size'], image_channels=config['channels'], classes=config['classes'], # -fc-layers fc_layers=args.g_fc_lay if generator and hasattr(args, 'g_fc_lay') else args.fc_lay, fc_units=args.g_fc_uni if generator and hasattr(args, 'g_fc_uni') else args.fc_units, h_dim=args.g_h_dim if generator and hasattr(args, 'g_h_dim') else args.h_dim, fc_drop=0 if generator else args.fc_drop, fc_bn=(args.fc_bn == "yes"), fc_nl=args.fc_nl, excit_buffer=True, # -prior prior=args.prior if hasattr(args, "prior") else "standard", n_modes=args.n_modes if hasattr(args, "prior") else 1, per_class=args.per_class if hasattr(args, "prior") else False, z_dim=args.g_z_dim if generator and hasattr(args, 'g_z_dim') else args.z_dim, # -decoder recon_loss=args.recon_loss, network_output="none" if checkattr(args, "normalize") else "sigmoid", deconv_type=args.deconv_type if hasattr(args, "deconv_type") else "standard", dg_gates=utils.checkattr(args, 'dg_gates'), dg_type=args.dg_type if hasattr(args, 'dg_type') else "task", dg_prop=args.dg_prop if hasattr(args, 'dg_prop') else 0., tasks=args.tasks if hasattr(args, 'tasks') else None, scenario=args.scenario if hasattr(args, 'scenario') else None, device=device, # -classifier classifier=False if generator else True, classify_opt=args.classify if hasattr(args, "classify") else "beforeZ", # -training-specific components lamda_rcl=1. 
if not hasattr(args, 'rcl') else args.rcl, lamda_vl=1. if not hasattr(args, 'vl') else args.vl, lamda_pl=(0. if generator else 1.) if not hasattr(args, 'pl') else args.pl, ).to(device) # -return model return model
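# NOTE: the argument-selection pattern above ("use the generator-specific value when building a separate
# generator and that attribute exists, else fall back to the main model's value") repeats for g_fc_lay/fc_lay,
# g_fc_uni/fc_units, g_h_dim/h_dim and g_z_dim/z_dim. A small helper expressing the same logic; purely
# illustrative and not part of this code-base:
def _gen_or_default(args, gen_attr, default_attr, generator):
    '''Return [args.<gen_attr>] if a separate generator is being built and that attribute exists,
    otherwise return [args.<default_attr>].'''
    if generator and hasattr(args, gen_attr):
        return getattr(args, gen_attr)
    return getattr(args, default_attr)

# e.g.:  fc_layers = _gen_or_default(args, 'g_fc_lay', 'fc_lay', generator)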
def run(args, model_name, shift, slot, verbose=False): # Create plots- and results-directories if needed if not os.path.isdir(args.r_dir): os.mkdir(args.r_dir) if args.pdf and not os.path.isdir(args.p_dir): os.mkdir(args.p_dir) # If only want param-stamp, get it and exit if args.get_stamp: from param_stamp import get_param_stamp_from_args print(get_param_stamp_from_args(args=args)) exit() # Use cuda? cuda = torch.cuda.is_available() and args.cuda device = torch.device("cuda" if cuda else "cpu") # Report whether cuda is used if verbose: print("CUDA is {}used".format("" if cuda else "NOT(!!) ")) # Set random seeds np.random.seed(args.seed) torch.manual_seed(args.seed) if cuda: torch.cuda.manual_seed(args.seed) #-------------------------------------------------------------------------------------------------# #----------------# #----- DATA -----# #----------------# # Prepare data for chosen experiment if verbose: print("\nPreparing the data...") (train_datasets, test_datasets), config, classes_per_task = get_multitask_experiment( name=args.experiment, tasks=args.tasks, slot=args.slot, shift=args.shift, data_dir=args.d_dir, normalize=True if utils.checkattr(args, "normalize") else False, augment=True if utils.checkattr(args, "augment") else False, verbose=verbose, exception=True if args.seed < 10 else False, only_test=(not args.train), max_samples=args.max_samples) #-------------------------------------------------------------------------------------------------# #----------------------# #----- MAIN MODEL -----# #----------------------# # Define main model (i.e., classifier, if requested with feedback connections) if verbose and utils.checkattr( args, "pre_convE") and (hasattr(args, "depth") and args.depth > 0): print("\nDefining the model...") model = define.define_classifier(args=args, config=config, device=device) # Initialize / use pre-trained / freeze model-parameters # - initialize (pre-trained) parameters model = define.init_params(model, args) # - freeze weights of conv-layers? 
if utils.checkattr(args, "freeze_convE"): for param in model.convE.parameters(): param.requires_grad = False # Define optimizer (only optimize parameters that "requires_grad") model.optim_list = [ { 'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': args.lr }, ] model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999)) #-------------------------------------------------------------------------------------------------# #----------------------------------# #----- CL-STRATEGY: EXEMPLARS -----# #----------------------------------# # Store in model whether, how many and in what way to store exemplars if isinstance(model, ExemplarHandler) and (args.use_exemplars or args.replay == "exemplars"): model.memory_budget = args.budget model.herding = args.herding model.norm_exemplars = args.norm_exemplars #-------------------------------------------------------------------------------------------------# #----------------------------------------------------# #----- CL-STRATEGY: REGULARIZATION / ALLOCATION -----# #----------------------------------------------------# # Elastic Weight Consolidation (EWC) if isinstance(model, ContinualLearner) and utils.checkattr(args, 'ewc'): model.ewc_lambda = args.ewc_lambda if args.ewc else 0 model.fisher_n = args.fisher_n model.online = utils.checkattr(args, 'online') if model.online: model.gamma = args.gamma # Synaptic Intelligence (SI) if isinstance(model, ContinualLearner) and utils.checkattr(args, 'si'): model.si_c = args.si_c if args.si else 0 model.epsilon = args.epsilon # XdG: create for every task a "mask" for each hidden fully connected layer if isinstance(model, ContinualLearner) and utils.checkattr( args, 'xdg') and args.xdg_prop > 0: model.define_XdGmask(gating_prop=args.xdg_prop, n_tasks=args.tasks) #-------------------------------------------------------------------------------------------------# #-------------------------------# #----- CL-STRATEGY: REPLAY -----# #-------------------------------# # Use distillation loss (i.e., soft targets) for replayed data? 
(and set temperature) if isinstance(model, ContinualLearner) and hasattr( args, 'replay') and not args.replay == "none": model.replay_targets = "soft" if args.distill else "hard" model.KD_temp = args.temp #-------------------------------------------------------------------------------------------------# #---------------------# #----- REPORTING -----# #---------------------# # Get parameter-stamp (and print on screen) if verbose: print("\nParameter-stamp...") param_stamp, reinit_param_stamp = get_param_stamp( args, model.name, verbose=verbose, replay=True if (hasattr(args, 'replay') and not args.replay == "none") else False, ) # Print some model-characteristics on the screen if verbose: # -main model utils.print_model_info(model, title="MAIN MODEL") # Prepare for keeping track of statistics required for metrics (also used for plotting in pdf) if args.pdf or args.metrics: # -define [metrics_dict] to keep track of performance during training for storing & for later plotting in pdf metrics_dict = evaluate.initiate_metrics_dict(n_tasks=args.tasks) # -evaluate randomly initialized model on all tasks & store accuracies in [metrics_dict] (for calculating metrics) if not args.use_exemplars: metrics_dict = evaluate.intial_accuracy( model, test_datasets, metrics_dict, no_task_mask=False, classes_per_task=classes_per_task, test_size=None) else: metrics_dict = None # Prepare for plotting in visdom visdom = None if args.visdom: env_name = "{exp}-{tasks}".format(exp=args.experiment, tasks=args.tasks) replay_statement = "{mode}{b}".format( mode=args.replay, b="" if (args.batch_replay is None or args.batch_replay == args.batch) else "-br{}".format(args.batch_replay), ) if (hasattr(args, "replay") and not args.replay == "none") else "NR" graph_name = "{replay}{syn}{ewc}{xdg}".format( replay=replay_statement, syn="-si{}".format(args.si_c) if utils.checkattr(args, 'si') else "", ewc="-ewc{}{}".format( args.ewc_lambda, "-O{}".format(args.gamma) if utils.checkattr(args, "online") else "") if utils.checkattr( args, 'ewc') else "", xdg="" if (not utils.checkattr(args, 'xdg')) or args.xdg_prop == 0 else "-XdG{}".format(args.xdg_prop), ) visdom = {'env': env_name, 'graph': graph_name} #-------------------------------------------------------------------------------------------------# #---------------------# #----- CALLBACKS -----# #---------------------# # Callbacks for reporting on and visualizing loss solver_loss_cbs = [ cb._solver_loss_cb(log=args.loss_log, visdom=visdom, model=model, iters_per_task=args.iters, tasks=args.tasks, replay=(hasattr(args, "replay") and not args.replay == "none")) ] # Callbacks for reporting and visualizing accuracy # -visdom (i.e., after each [prec_log] iterations) eval_cbs = [ cb._eval_cb(log=args.prec_log, test_datasets=test_datasets, visdom=visdom, iters_per_task=args.iters, test_size=args.prec_n, classes_per_task=classes_per_task, with_exemplars=False) ] if (not args.use_exemplars) else [None] #--> during training on a task, evaluation cannot be with exemplars as those are only selected after training # (instead, evaluation for visdom is only done after each task, by including callback-function into [metric_cbs]) # Callbacks for calculating statistics required for metrics # -pdf / reporting: summary plots (i.e., only after each task) (when using exemplars, also for visdom) metric_cbs = [ cb._metric_cb(log=args.iters, test_datasets=test_datasets, classes_per_task=classes_per_task, metrics_dict=metrics_dict, iters_per_task=args.iters, with_exemplars=args.use_exemplars), cb._eval_cb(log=args.iters, 
test_datasets=test_datasets, visdom=visdom, iters_per_task=args.iters, test_size=args.prec_n, classes_per_task=classes_per_task, with_exemplars=True) if args.use_exemplars else None ] #-------------------------------------------------------------------------------------------------# #--------------------# #----- TRAINING -----# #--------------------# if args.train: if verbose: print("\nTraining...") # Train model train_cl( model, train_datasets, model_name=model_name, shift=shift, slot=slot, replay_mode=args.replay if hasattr(args, 'replay') else "none", classes_per_task=classes_per_task, iters=args.iters, args=args, batch_size=args.batch, batch_size_replay=args.batch_replay if hasattr( args, 'batch_replay') else None, eval_cbs=eval_cbs, loss_cbs=solver_loss_cbs, reinit=utils.checkattr(args, 'reinit'), only_last=utils.checkattr(args, 'only_last'), metric_cbs=metric_cbs, use_exemplars=args.use_exemplars, ) # Save trained model(s), if requested if args.save: save_name = "mM-{}".format(param_stamp) if ( not hasattr(args, 'full_stag') or args.full_stag == "none") else "{}-{}".format( model.name, args.full_stag) utils.save_checkpoint(model, args.m_dir, name=save_name, verbose=verbose) else: # Load previously trained model(s) (if goal is to only evaluate previously trained model) if verbose: print("\nLoading parameters of the previously trained models...") load_name = "mM-{}".format(param_stamp) if ( not hasattr(args, 'full_ltag') or args.full_ltag == "none") else "{}-{}".format( model.name, args.full_ltag) utils.load_checkpoint( model, args.m_dir, name=load_name, verbose=verbose, add_si_buffers=(isinstance(model, ContinualLearner) and utils.checkattr(args, 'si'))) # Load previously created metrics-dict file_name = "{}/dict-{}".format(args.r_dir, param_stamp) metrics_dict = utils.load_object(file_name) #-------------------------------------------------------------------------------------------------# #-----------------------------------# #----- EVALUATION of CLASSIFIER-----# #-----------------------------------# if verbose: print("\n\nEVALUATION RESULTS:") # Evaluate precision of final model on full test-set precs = [ evaluate.validate(model, test_datasets[i], verbose=False, test_size=None, task=i + 1, with_exemplars=False, allowed_classes=list( range(classes_per_task * i, classes_per_task * (i + 1)))) for i in range(args.tasks) ] average_precs = sum(precs) / args.tasks # -print on screen if verbose: print("\n Precision on test-set{}:".format( " (softmax classification)" if args.use_exemplars else "")) for i in range(args.tasks): print(" - Task {}: {:.4f}".format(i + 1, precs[i])) print('=> Average precision over all {} tasks: {:.4f}\n'.format( args.tasks, average_precs)) # -with exemplars if args.use_exemplars: precs = [ evaluate.validate(model, test_datasets[i], verbose=False, test_size=None, task=i + 1, with_exemplars=True, allowed_classes=list( range(classes_per_task * i, classes_per_task * (i + 1)))) for i in range(args.tasks) ] average_precs_ex = sum(precs) / args.tasks # -print on screen if verbose: print(" Precision on test-set (classification using exemplars):") for i in range(args.tasks): print(" - Task {}: {:.4f}".format(i + 1, precs[i])) print('=> Average precision over all {} tasks: {:.4f}\n'.format( args.tasks, average_precs_ex)) # If requested, compute metrics '''if args.metrics:
def get_param_stamp(args, model_name, verbose=True, replay=False, replay_model_name=None): '''Based on the input-arguments, produce a "parameter-stamp".''' # -for task multi_n_stamp = "{n}-{set}{of}".format( n=args.tasks, set=args.scenario, of="OL" if checkattr(args, 'only_last') else "") if hasattr( args, "tasks") else "" task_stamp = "{exp}{norm}{aug}{multi_n}".format( exp=args.experiment, norm="-N" if hasattr(args, 'normalize') and args.normalize else "", aug="+" if hasattr(args, "augment") and args.augment else "", multi_n=multi_n_stamp) if verbose: print(" --> task: " + task_stamp) # -for model model_stamp = model_name if verbose: print(" --> model: " + model_stamp) # -for hyper-parameters pre_conv = "" if (checkattr(args, "pre_convE") or checkattr(args, "pre_convD")) and (hasattr(args, 'depth') and args.depth > 0): ltag = "" if not hasattr( args, "convE_ltag") or args.convE_ltag == "none" else "-{}".format( args.convE_ltag) pre_conv = "-pCvE{}".format(ltag) if args.pre_convE else "-pCvD" pre_conv = "-pConv{}".format(ltag) if args.pre_convE and checkattr( args, "pre_convD") else pre_conv freeze_conv = "" if (checkattr(args, "freeze_convD") or checkattr(args, "freeze_convE")) and hasattr( args, 'depth') and args.depth > 0: freeze_conv = "-fCvE" if checkattr(args, "freeze_convE") else "-fCvD" freeze_conv = "-fConv" if checkattr( args, "freeze_convE") and checkattr( args, "freeze_convD") else freeze_conv hyper_stamp = "{i_e}{num}-lr{lr}{lrg}-b{bsz}{pretr}{freeze}{reinit}".format( i_e="e" if args.iters is None else "i", num=args.epochs if args.iters is None else args.iters, lr=args.lr, lrg=("" if args.lr == args.lr_gen else "-lrG{}".format(args.lr_gen)) if (hasattr(args, "lr_gen") and hasattr(args, "replay") and args.replay == "generative" and (not checkattr(args, "feedback"))) else "", bsz=args.batch, pretr=pre_conv, freeze=freeze_conv, reinit="-R" if checkattr(args, 'reinit') else "") if verbose: print(" --> hyper-params: " + hyper_stamp) # -for EWC / SI if (checkattr(args, 'ewc') and args.ewc_lambda > 0) or (checkattr(args, 'si') and args.si_c > 0): ewc_stamp = "EWC{l}-{fi}{o}".format( l=args.ewc_lambda, fi="{}".format("N" if args.fisher_n is None else args.fisher_n), o="-O{}".format(args.gamma) if checkattr(args, 'online') else "", ) if (checkattr(args, 'ewc') and args.ewc_lambda > 0) else "" si_stamp = "SI{c}-{eps}".format(c=args.si_c, eps=args.epsilon) if ( checkattr(args, 'si') and args.si_c > 0) else "" both = "--" if (checkattr(args, 'ewc') and args.ewc_lambda > 0) and ( checkattr(args, 'si') and args.si_c > 0) else "" if verbose and checkattr(args, 'ewc') and args.ewc_lambda > 0: print(" --> EWC: " + ewc_stamp) if verbose and checkattr(args, 'si') and args.si_c > 0: print(" --> SI: " + si_stamp) ewc_stamp = "--{}{}{}".format(ewc_stamp, both, si_stamp) if ( (checkattr(args, 'ewc') and args.ewc_lambda > 0) or (checkattr(args, 'si') and args.si_c > 0)) else "" # -for XdG xdg_stamp = "" if (checkattr(args, "xdg") and args.xdg_prop > 0): xdg_stamp = "--XdG{}".format(args.xdg_prop) if verbose: print(" --> XdG: " + "gating = {}".format(args.xdg_prop)) # -for replay if replay: replay_stamp = "{H}{rep}{bat}{distil}{model}{gi}".format( H="" if not args.replay == "generative" else ("H" if (checkattr(args, "hidden") and hasattr(args, 'depth') and args.depth > 0) else ""), rep="gen" if args.replay == "generative" else args.replay, bat="" if ((not hasattr(args, 'batch_replay')) or (args.batch_replay is None) or args.batch_replay == args.batch) else "-br{}".format(args.batch_replay), 
distil="-Di{}".format(args.temp) if args.distill else "", model="" if (replay_model_name is None) else "-{}".format(replay_model_name), gi="-gi{}".format(args.g_iters) if (hasattr(args, "g_iters") and (replay_model_name is not None) and (not args.iters == args.g_iters)) else "", ) if verbose: print(" --> replay: " + replay_stamp) replay_stamp = "--{}".format(replay_stamp) if replay else "" # -for choices regarding reconstruction loss if checkattr(args, "feedback"): recon_stamp = "--{}{}".format( "H_" if checkattr(args, "hidden") and hasattr(args, 'depth') and args.depth > 0 else "", args.recon_loss) elif hasattr(args, "replay") and args.replay == "generative": recon_stamp = "--{}".format(args.recon_loss) else: recon_stamp = "" # --> combine param_stamp = "{}--{}--{}{}{}{}{}{}".format( task_stamp, model_stamp, hyper_stamp, ewc_stamp, xdg_stamp, replay_stamp, recon_stamp, "-s{}".format(args.seed) if not args.seed == 0 else "", ) ## Print param-stamp on screen and return if verbose: print(param_stamp) return param_stamp
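# NOTE: for orientation, an example of the stamp built above for a hypothetical splitMNIST run (5 tasks,
# class-IL scenario, 2000 iterations, lr=0.001, batch-size 128, seed 0, no normalization/augmentation, no
# EWC/SI/XdG, no replay and no pre-trained or frozen conv-layers); values are illustrative assumptions:
#   splitMNIST5-class--<model_name>--i2000-lr0.001-b128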
def train_cl(model, train_datasets, replay_mode="none", rnt=None, classes_per_task=None, iters=2000, batch_size=32, batch_size_replay=None, loss_cbs=list(), eval_cbs=list(), reinit=False, args=None, only_last=False, use_exemplars=False, metric_cbs=list()): '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode]. [model] <nn.Module> main model to optimize across all tasks [train_datasets] <list> with for each task the training <DataSet> [replay_mode] <str>, choice from "current", "offline" and "none" [classes_per_task] <int>, # classes per task; only 1st task has [classes_per_task]*[first_task_class_boost] classes [rnt] <float>, indicating relative importance of new task (if None, relative to # old tasks) [iters] <int>, # optimization-steps (=batches) per task; 1st task has [first_task_iter_boost] steps more [batch_size_replay] <int>, number of samples to replay per batch [only_last] <bool>, only train on final task / episode [*_cbs] <list> of call-back functions to evaluate training-progress''' # Should convolutional layers be frozen? freeze_convE = (utils.checkattr(args, "freeze_convE") and hasattr(args, "depth") and args.depth > 0) # Use cuda? device = model._device() cuda = model._is_on_cuda() # Set default-values if not specified batch_size_replay = batch_size if batch_size_replay is None else batch_size_replay # Initiate indicators for replay (no replay for 1st task) Exact = Current = Offline_TaskIL = False previous_model = None # Register starting param-values (needed for "intelligent synapses"). if isinstance(model, ContinualLearner) and model.si_c > 0: for n, p in model.named_parameters(): if p.requires_grad: n = n.replace('.', '__') model.register_buffer('{}_SI_prev_task'.format(n), p.detach().clone()) # Loop over all tasks. 
for task, train_dataset in enumerate(train_datasets, 1): # In offline replay-setting, all tasks so far should be visited separately (i.e., separate data-loader per task) if replay_mode == "offline": Offline_TaskIL = True data_loader = [None] * task # Initialize # iters left on data-loader(s) iters_left = 1 if (not Offline_TaskIL) else [1] * task if Exact: iters_left_previous = [1] * (task - 1) data_loader_previous = [None] * (task - 1) # Prepare <dicts> to store running importance estimates and parameter-values before update if isinstance(model, ContinualLearner) and model.si_c > 0: W = {} p_old = {} for n, p in model.named_parameters(): if p.requires_grad: n = n.replace('.', '__') W[n] = p.data.clone().zero_() p_old[n] = p.data.clone() # Find [active_classes] (=classes in current task) active_classes = [ list(range(classes_per_task * i, classes_per_task * (i + 1))) for i in range(task) ] # Reinitialize the model's parameters and the optimizer (if requested) if reinit: from define_models import init_params init_params(model, args) model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999)) # Define a tqdm progress-bar progress = tqdm.tqdm(range(1, iters + 1)) # Loop over all iterations iters_to_use = iters # -if only the final task should be trained on: if only_last and not task == len(train_datasets): iters_to_use = 0 for batch_index in range(1, iters_to_use + 1): # Update # iters left on current data-loader(s) and, if needed, create new one(s) if not Offline_TaskIL: iters_left -= 1 if iters_left == 0: data_loader = iter( utils.get_data_loader(train_dataset, batch_size, cuda=cuda, drop_last=True)) # NOTE: [train_dataset] is the training-set of the current task (with stored exemplars added, if requested) iters_left = len(data_loader) else: # -with "offline replay", there is a separate data-loader for each task batch_size_to_use = batch_size for task_id in range(task): iters_left[task_id] -= 1 if iters_left[task_id] == 0: data_loader[task_id] = iter( utils.get_data_loader(train_datasets[task_id], batch_size_to_use, cuda=cuda, drop_last=True)) iters_left[task_id] = len(data_loader[task_id]) # Update # iters left on data-loader(s) of the previous task(s) and, if needed, create new one(s) if Exact: up_to_task = task - 1 batch_size_replay_pt = int( np.floor( batch_size_replay / up_to_task)) if (up_to_task > 1) else batch_size_replay # -need separate replay for each task for task_id in range(up_to_task): batch_size_to_use = min(batch_size_replay_pt, len(previous_datasets[task_id])) iters_left_previous[task_id] -= 1 if iters_left_previous[task_id] == 0: data_loader_previous[task_id] = iter( utils.get_data_loader(previous_datasets[task_id], batch_size_to_use, cuda=cuda, drop_last=True)) iters_left_previous[task_id] = len( data_loader_previous[task_id]) #-----------------Collect data------------------# #####-----CURRENT BATCH-----##### if not Offline_TaskIL: x, y = next( data_loader) #--> sample training data of current task y = y - classes_per_task * ( task - 1) #--> ITL: adjust y-targets to 'active range' x, y = x.to(device), y.to( device) #--> transfer them to correct device #y = y.expand(1) if len(y.size())==1 else y #--> hack for if batch-size is 1 else: x = y = task_used = None #--> all tasks are "treated as replay" # -sample training data for all tasks so far, move to correct device and store in lists x_, y_ = list(), list() for task_id in range(task): x_temp, y_temp = next(data_loader[task_id]) 
x_.append(x_temp.to(device)) y_temp = y_temp - ( classes_per_task * task_id ) #--> adjust y-targets to 'active range' if batch_size_to_use == 1: y_temp = torch.tensor([ y_temp ]) #--> correct dimensions if batch-size is 1 y_.append(y_temp.to(device)) #####-----REPLAYED BATCH-----##### if not Offline_TaskIL and not Exact and not Current: x_ = y_ = scores_ = task_used = None #-> if no replay #--------------------------------------------INPUTS----------------------------------------------------# ##-->> Exact Replay <<--## if Exact: # Sample replayed training data, move to correct device and store in lists x_ = list() y_ = list() up_to_task = task - 1 for task_id in range(up_to_task): x_temp, y_temp = next(data_loader_previous[task_id]) x_.append(x_temp.to(device)) # -only keep [y_] if required (as otherwise unnecessary computations will be done) if model.replay_targets == "hard": y_temp = y_temp - ( classes_per_task * task_id ) #-> adjust y-targets to 'active range' y_.append(y_temp.to(device)) else: y_.append(None) # If required, get target scores (i.e, [scores_]) -- using previous model, with no_grad() if (model.replay_targets == "soft") and (previous_model is not None): scores_ = list() for task_id in range(up_to_task): with torch.no_grad(): scores_temp = previous_model(x_[task_id]) scores_temp = scores_temp[:, (classes_per_task * task_id):(classes_per_task * (task_id + 1))] scores_.append(scores_temp) else: scores_ = None ##-->> Current Replay <<--## if Current: x_ = x[:batch_size_replay] #--> use current task inputs task_used = None #--------------------------------------------OUTPUTS----------------------------------------------------# if Current: # Get target scores & possibly labels (i.e., [scores_] / [y_]) -- use previous model, with no_grad() # -[x_] needs to be evaluated according to each previous task, so make list with entry per task scores_ = list() y_ = list() # -if no task-mask and no conditional generator, all scores can be calculated in one go if previous_model.mask_dict is None and not type(x_) == list: with torch.no_grad(): all_scores_ = previous_model.classify(x_) for task_id in range(task - 1): # -if there is a task-mask (i.e., XdG is used), obtain predicted scores for each task separately if previous_model.mask_dict is not None: previous_model.apply_XdGmask(task=task_id + 1) if previous_model.mask_dict is not None or type( x_) == list: with torch.no_grad(): all_scores_ = previous_model.classify( x_[task_id] if type(x_) == list else x_) temp_scores_ = all_scores_[:, (classes_per_task * task_id):(classes_per_task * (task_id + 1))] scores_.append(temp_scores_) # - also get hard target _, temp_y_ = torch.max(temp_scores_, dim=1) y_.append(temp_y_) # -only keep predicted y_/scores_ if required (as otherwise unnecessary computations will be done) y_ = y_ if (model.replay_targets == "hard") else None scores_ = scores_ if (model.replay_targets == "soft") else None #-----------------Train model------------------# # Train the main model with this batch loss_dict = model.train_a_batch(x, y=y, x_=x_, y_=y_, scores_=scores_, tasks_=task_used, active_classes=active_classes, task=task, rnt=(1. if task == 1 else 1. 
/ task) if rnt is None else rnt, freeze_convE=freeze_convE) # Update running parameter importance estimates in W (see the SI sketch after this function) if isinstance(model, ContinualLearner) and model.si_c > 0: for n, p in model.named_parameters(): if p.requires_grad: n = n.replace('.', '__') if p.grad is not None: W[n].add_(-p.grad * (p.detach() - p_old[n])) p_old[n] = p.detach().clone() # Fire callbacks (for visualization of training-progress / evaluating performance after each task) for loss_cb in loss_cbs: if loss_cb is not None: loss_cb(progress, batch_index, loss_dict, task=task) for eval_cb in eval_cbs: if eval_cb is not None: eval_cb(model, batch_index, task=task) # Close progress-bar progress.close() ##----------> UPON FINISHING EACH TASK... # EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty if isinstance(model, ContinualLearner) and model.ewc_lambda > 0: # -find allowed classes allowed_classes = list( range(classes_per_task * (task - 1), classes_per_task * task)) # -if needed, apply correct task-specific mask if model.mask_dict is not None: model.apply_XdGmask(task=task) # -estimate FI-matrix model.estimate_fisher(train_dataset, allowed_classes=allowed_classes) # SI: calculate and update the normalized path integral if isinstance(model, ContinualLearner) and model.si_c > 0: model.update_omega(W, model.epsilon) # EXEMPLARS: update exemplar sets if use_exemplars or replay_mode == "exemplars": exemplars_per_class = int( np.floor(model.memory_budget / (classes_per_task * task))) # reduce exemplar-sets model.reduce_exemplar_sets(exemplars_per_class) # for each new class trained on, construct exemplar-set new_classes = list( range(classes_per_task * (task - 1), classes_per_task * task)) for class_id in new_classes: # create new dataset containing only the examples of this class class_dataset = SubDataset(original_dataset=train_dataset, sub_labels=[class_id]) # based on this dataset, construct new exemplar-set for this class model.construct_exemplar_set(dataset=class_dataset, n=exemplars_per_class) model.compute_means = True # Calculate statistics required for metrics for metric_cb in metric_cbs: if metric_cb is not None: metric_cb(model, iters, task=task) # REPLAY: update source for replay previous_model = copy.deepcopy(model).eval() if replay_mode == 'current': Current = True elif replay_mode in ('exemplars', 'exact'): Exact = True if replay_mode == "exact": previous_datasets = train_datasets[:task] else: previous_datasets = [] for task_id in range(task): previous_datasets.append( ExemplarDataset( model.exemplar_sets[(classes_per_task * task_id):(classes_per_task * (task_id + 1))], target_transform=lambda y, x=classes_per_task * task_id: y + x))
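# NOTE: the running estimates W accumulated above (W[n] -= grad * param-change, per training-step) are turned
# into per-parameter importances once model.update_omega(W, epsilon) is called at the end of the task. Below is
# a minimal sketch of that normalization step, following Zenke et al. (2017); the '_SI_prev_task' buffer-name
# appears above, while '_SI_omega' is an assumption for illustration. The actual implementation lives in
# ContinualLearner.update_omega(), not here.
import torch

def update_omega_sketch(model, W, epsilon):
    for n, p in model.named_parameters():
        if p.requires_grad:
            n = n.replace('.', '__')
            p_prev = getattr(model, '{}_SI_prev_task'.format(n))  # reference point stored before this task
            p_current = p.detach().clone()
            p_change = p_current - p_prev                         # total drift of this parameter over the task
            omega_add = W[n] / (p_change**2 + epsilon)            # path integral, normalized by squared change
            omega = getattr(model, '{}_SI_omega'.format(n), torch.zeros_like(p))
            model.register_buffer('{}_SI_omega'.format(n), omega + omega_add)
            model.register_buffer('{}_SI_prev_task'.format(n), p_current)  # new reference point for next task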
def train_cl(model, train_datasets, replay_mode="none", scenario="task", rnt=None, classes_per_task=None, iters=2000, batch_size=32, batch_size_replay=None, loss_cbs=list(), eval_cbs=list(), sample_cbs=list(), generator=None, gen_iters=0, gen_loss_cbs=list(), feedback=False, reinit=False, args=None, only_last=False): '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode]. [model] <nn.Module> main model to optimize across all tasks [train_datasets] <list> with for each task the training <DataSet> [replay_mode] <str>, choice from "generative", "current", "offline" and "none" [scenario] <str>, choice from "task", "domain", "class" and "all" [classes_per_task] <int>, # classes per task; only 1st task has [classes_per_task]*[first_task_class_boost] classes [rnt] <float>, indicating relative importance of new task (if None, relative to # old tasks) [iters] <int>, # optimization-steps (=batches) per task; 1st task has [first_task_iter_boost] steps more [batch_size_replay] <int>, number of samples to replay per batch [generator] None or <nn.Module>, if a separate generative model should be trained (for [gen_iters] per task) [feedback] <bool>, if True and [replay_mode]="generative", the main model is used for generating replay [only_last] <bool>, only train on final task / episode [*_cbs] <list> of call-back functions to evaluate training-progress''' # Should convolutional layers be frozen? freeze_convE = (utils.checkattr(args, "freeze_convE") and hasattr(args, "depth") and args.depth > 0) # Use cuda? device = model._device() cuda = model._is_on_cuda() # Set default-values if not specified batch_size_replay = batch_size if batch_size_replay is None else batch_size_replay # Initiate indicators for replay (no replay for 1st task) Generative = Current = Offline_TaskIL = False previous_model = None # Register starting param-values (needed for "intelligent synapses"). if isinstance(model, ContinualLearner) and model.si_c > 0: for n, p in model.named_parameters(): if p.requires_grad: n = n.replace('.', '__') model.register_buffer('{}_SI_prev_task'.format(n), p.detach().clone()) # Loop over all tasks. 
for task, train_dataset in enumerate(train_datasets, 1): # If offline replay-setting, create large database of all tasks so far if replay_mode == "offline" and (not scenario == "task"): train_dataset = ConcatDataset(train_datasets[:task]) # -but if "offline"+"task": all tasks so far should be visited separately (i.e., separate data-loader per task) if replay_mode == "offline" and scenario == "task": Offline_TaskIL = True data_loader = [None] * task # Initialize # iters left on data-loader(s) iters_left = 1 if (not Offline_TaskIL) else [1] * task # Prepare <dicts> to store running importance estimates and parameter-values before update if isinstance(model, ContinualLearner) and model.si_c > 0: W = {} p_old = {} for n, p in model.named_parameters(): if p.requires_grad: n = n.replace('.', '__') W[n] = p.data.clone().zero_() p_old[n] = p.data.clone() # Find [active_classes] (=classes in current task) active_classes = None #-> for "domain"- or "all"-scenarios, always all classes are active if scenario == "task": # -for "task"-scenario, create <list> with for all tasks so far a <list> with the active classes active_classes = [ list(range(classes_per_task * i, classes_per_task * (i + 1))) for i in range(task) ] elif scenario == "class": # -for "class"-scenario, create one <list> with active classes of all tasks so far active_classes = list(range(classes_per_task * task)) # Reinitialize the model's parameters (if requested) if reinit: from define_models import init_params init_params(model, args) if generator is not None: init_params(generator, args) # Define a tqdm progress bar(s) iters_main = iters progress = tqdm.tqdm(range(1, iters_main + 1)) if generator is not None: iters_gen = gen_iters progress_gen = tqdm.tqdm(range(1, iters_gen + 1)) # Loop over all iterations iters_to_use = (iters_main if (generator is None) else max(iters_main, iters_gen)) # -if only the final task should be trained on: if only_last and not task == len(train_datasets): iters_to_use = 0 for batch_index in range(1, iters_to_use + 1): # Update # iters left on current data-loader(s) and, if needed, create new one(s) if not Offline_TaskIL: iters_left -= 1 if iters_left == 0: data_loader = iter( utils.get_data_loader(train_dataset, batch_size, cuda=cuda, drop_last=True)) iters_left = len(data_loader) else: # -with "offline replay" in Task-IL scenario, there is a separate data-loader for each task batch_size_to_use = int(np.ceil(batch_size / task)) for task_id in range(task): iters_left[task_id] -= 1 if iters_left[task_id] == 0: data_loader[task_id] = iter( utils.get_data_loader(train_datasets[task_id], batch_size_to_use, cuda=cuda, drop_last=True)) iters_left[task_id] = len(data_loader[task_id]) #-----------------Collect data------------------# #####-----CURRENT BATCH-----##### if not Offline_TaskIL: x, y = next( data_loader) #--> sample training data of current task y = y - classes_per_task * ( task - 1 ) if scenario == "task" else y #--> ITL: adjust y-targets to 'active range' x, y = x.to(device), y.to( device) #--> transfer them to correct device #y = y.expand(1) if len(y.size())==1 else y #--> hack for if batch-size is 1 else: x = y = task_used = None #--> all tasks are "treated as replay" # -sample training data for all tasks so far, move to correct device and store in lists x_, y_ = list(), list() for task_id in range(task): x_temp, y_temp = next(data_loader[task_id]) x_.append(x_temp.to(device)) y_temp = y_temp - ( classes_per_task * task_id ) #--> adjust y-targets to 'active range' if batch_size_to_use == 1: y_temp = 
torch.tensor([ y_temp ]) #--> correct dimensions if batch-size is 1 y_.append(y_temp.to(device)) #####-----REPLAYED BATCH-----##### if not Offline_TaskIL and not Generative and not Current: x_ = y_ = scores_ = task_used = None #-> if no replay #--------------------------------------------INPUTS----------------------------------------------------# ##-->> Current Replay <<--## if Current: x_ = x[:batch_size_replay] #--> use current task inputs task_used = None ##-->> Generative Replay <<--## if Generative: #---> Only with generative replay, the resulting [x_] will be at the "hidden"-level conditional_gen = True if ( (previous_generator.per_class and previous_generator.prior == "GMM") or utils.checkattr( previous_generator, 'dg_gates')) else False # Sample [x_] if conditional_gen and scenario == "task": # -if a conditional generator is used with task-IL scenario, generate data per previous task x_ = list() task_used = list() for task_id in range(task - 1): allowed_classes = list( range(classes_per_task * task_id, classes_per_task * (task_id + 1))) batch_size_replay_to_use = int( np.ceil(batch_size_replay / (task - 1))) x_temp_ = previous_generator.sample( batch_size_replay_to_use, allowed_classes=allowed_classes, only_x=False) x_.append(x_temp_[0]) task_used.append(x_temp_[2]) else: # -which classes are allowed to be generated? (relevant if conditional generator / decoder-gates) allowed_classes = None if scenario == "domain" else list( range(classes_per_task * (task - 1))) # -which tasks/domains are allowed to be generated? (only relevant if "Domain-IL" with task-gates) allowed_domains = list(range(task - 1)) # -generate inputs representative of previous tasks x_temp_ = previous_generator.sample( batch_size_replay, allowed_classes=allowed_classes, allowed_domains=allowed_domains, only_x=False, ) x_ = x_temp_[0] task_used = x_temp_[2] #--------------------------------------------OUTPUTS----------------------------------------------------# if Generative or Current: # Get target scores & possibly labels (i.e., [scores_] / [y_]) -- use previous model, with no_grad() if scenario in ("domain", "class") and previous_model.mask_dict is None: # -if replay does not need to be evaluated for each task (ie, not Task-IL and no task-specific mask) with torch.no_grad(): all_scores_ = previous_model.classify( x_, not_hidden=False if Generative else True) scores_ = all_scores_[:, :( classes_per_task * (task - 1) )] if ( scenario == "class" ) else all_scores_ # -> when scenario=="class", zero probs will be added in [loss_fn_kd]-function # -also get the 'hard target' _, y_ = torch.max(scores_, dim=1) else: # -[x_] needs to be evaluated according to each previous task, so make list with entry per task scores_ = list() y_ = list() # -if no task-mask and no conditional generator, all scores can be calculated in one go if previous_model.mask_dict is None and not type( x_) == list: with torch.no_grad(): all_scores_ = previous_model.classify( x_, not_hidden=False if Generative else True) for task_id in range(task - 1): # -if there is a task-mask (i.e., XdG is used), obtain predicted scores for each task separately if previous_model.mask_dict is not None: previous_model.apply_XdGmask(task=task_id + 1) if previous_model.mask_dict is not None or type( x_) == list: with torch.no_grad(): all_scores_ = previous_model.classify( x_[task_id] if type(x_) == list else x_, not_hidden=False if Generative else True) if scenario == "domain": # NOTE: if scenario=domain with task-mask, it's of course actually the Task-IL scenario! 
# this can be used as a trick to run the Task-IL scenario with singlehead output layer temp_scores_ = all_scores_ else: temp_scores_ = all_scores_[:, ( classes_per_task * task_id):(classes_per_task * (task_id + 1))] scores_.append(temp_scores_) # - also get hard target _, temp_y_ = torch.max(temp_scores_, dim=1) y_.append(temp_y_) # -only keep predicted y_/scores_ if required (as otherwise unnecessary computations will be done) y_ = y_ if (model.replay_targets == "hard") else None scores_ = scores_ if (model.replay_targets == "soft") else None #-----------------Train model(s)------------------# #---> Train MAIN MODEL if batch_index <= iters_main: # Train the main model with this batch loss_dict = model.train_a_batch( x, y=y, x_=x_, y_=y_, scores_=scores_, tasks_=task_used, active_classes=active_classes, task=task, rnt=(1. if task == 1 else 1. / task) if rnt is None else rnt, freeze_convE=freeze_convE, replay_not_hidden=False if Generative else True) # Update running parameter importance estimates in W if isinstance(model, ContinualLearner) and model.si_c > 0: for n, p in model.convE.named_parameters(): if p.requires_grad: n = "convE." + n n = n.replace('.', '__') if p.grad is not None: W[n].add_(-p.grad * (p.detach() - p_old[n])) p_old[n] = p.detach().clone() for n, p in model.fcE.named_parameters(): if p.requires_grad: n = "fcE." + n n = n.replace('.', '__') if p.grad is not None: W[n].add_(-p.grad * (p.detach() - p_old[n])) p_old[n] = p.detach().clone() for n, p in model.classifier.named_parameters(): if p.requires_grad: n = "classifier." + n n = n.replace('.', '__') if p.grad is not None: W[n].add_(-p.grad * (p.detach() - p_old[n])) p_old[n] = p.detach().clone() # Fire callbacks (for visualization of training-progress / evaluating performance after each task) for loss_cb in loss_cbs: if loss_cb is not None: loss_cb(progress, batch_index, loss_dict, task=task) for eval_cb in eval_cbs: if eval_cb is not None: eval_cb(model, batch_index, task=task) if model.label == "VAE": for sample_cb in sample_cbs: if sample_cb is not None: sample_cb(model, batch_index, task=task, allowed_classes=None if (scenario == "domain") else list( range(classes_per_task * task))) #---> Train GENERATOR if generator is not None and batch_index <= iters_gen: loss_dict = generator.train_a_batch( x, y=y, x_=x_, y_=y_, scores_=scores_, tasks_=task_used, active_classes=active_classes, rnt=(1. if task == 1 else 1. / task) if rnt is None else rnt, task=task, freeze_convE=freeze_convE, replay_not_hidden=False if Generative else True) # Fire callbacks on each iteration for loss_cb in gen_loss_cbs: if loss_cb is not None: loss_cb(progress_gen, batch_index, loss_dict, task=task) for sample_cb in sample_cbs: if sample_cb is not None: sample_cb(generator, batch_index, task=task, allowed_classes=None if (scenario == "domain") else list( range(classes_per_task * task))) # Close progress-bar(s) progress.close() if generator is not None: progress_gen.close() ##----------> UPON FINISHING EACH TASK... 
# EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty if isinstance(model, ContinualLearner) and model.ewc_lambda > 0: # -find allowed classes allowed_classes = list( range(classes_per_task * (task - 1), classes_per_task * task)) if scenario == "task" else ( list(range(classes_per_task * task)) if scenario == "class" else None) # -if needed, apply correct task-specific mask if model.mask_dict is not None: model.apply_XdGmask(task=task) # -estimate FI-matrix model.estimate_fisher(train_dataset, allowed_classes=allowed_classes) # SI: calculate and update the normalized path integral if isinstance(model, ContinualLearner) and model.si_c > 0: model.update_omega(W, model.epsilon) # REPLAY: update source for replay previous_model = copy.deepcopy(model).eval() if replay_mode == "generative": Generative = True previous_generator = previous_model if feedback else copy.deepcopy( generator).eval() elif replay_mode == 'current': Current = True
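# NOTE: the [rnt] value passed to train_a_batch() in the training-loop above weighs the current-task loss
# against the replay loss; with the default rnt = 1/task, each task seen so far contributes equally on average.
# The actual combination happens inside train_a_batch(); schematically (illustrative only):
#   loss_total = rnt * loss_current + (1. - rnt) * loss_replay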
def run(args): # Use cuda? cuda = torch.cuda.is_available() and args.cuda device = torch.device("cuda" if cuda else "cpu") # Set random seeds np.random.seed(args.seed) torch.manual_seed(args.seed) if cuda: torch.cuda.manual_seed(args.seed) # Report whether cuda is used print("CUDA is {}used".format("" if cuda else "NOT(!!) ")) # Create plots-directory if needed if args.pdf and not os.path.isdir(args.p_dir): os.mkdir(args.p_dir) #-------------------------------------------------------------------------------------------------# #----------------# #----- DATA -----# #----------------# # Prepare data for chosen experiment print("\nPreparing the data...") (trainset, testset), config = get_singletask_experiment( name=args.experiment, data_dir=args.d_dir, verbose=True, normalize = True if utils.checkattr(args, "normalize") else False, augment = True if utils.checkattr(args, "augment") else False, ) # Specify "data-loader" (among others for easy random shuffling and 'batchifying') train_loader = utils.get_data_loader(trainset, batch_size=args.batch, cuda=cuda, drop_last=True) # Determine number of iterations / epochs: iters = args.iters if args.iters else args.epochs*len(train_loader) epochs = ((args.iters-1) // len(train_loader)) + 1 if args.iters else args.epochs #-------------------------------------------------------------------------------------------------# #-----------------# #----- MODEL -----# #-----------------# # Specify model if (utils.checkattr(args, "pre_convE") or utils.checkattr(args, "pre_convD")) and \ (hasattr(args, "depth") and args.depth>0): print("\nDefining the model...") cnn = define.define_classifier(args=args, config=config, device=device) # Initialize (pre-trained) parameters cnn = define.init_params(cnn, args) # - freeze weights of conv-layers? if utils.checkattr(args, "freeze_convE"): for param in cnn.convE.parameters(): param.requires_grad = False cnn.convE.eval() #--> needed to ensure batchnorm-layers also do not change # - freeze weights of representation-learning layers? 
if utils.checkattr(args, "freeze_full"): for param in cnn.parameters(): param.requires_grad = False for param in cnn.classifier.parameters(): param.requires_grad = True # Set optimizer optim_list = [{'params': filter(lambda p: p.requires_grad, cnn.parameters()), 'lr': args.lr}] cnn.optimizer = torch.optim.Adam(optim_list, betas=(0.9, 0.999)) #-------------------------------------------------------------------------------------------------# #---------------------# #----- REPORTING -----# #---------------------# # Get parameter-stamp print("\nParameter-stamp...") param_stamp = get_param_stamp(args, cnn.name, verbose=True) # Print some model-characteristics on the screen utils.print_model_info(cnn, title="CLASSIFIER") # Define [precision_dict] to keep track of performance during training, for storing and for later plotting in pdf precision_dict = evaluate.initiate_precision_dict(n_tasks=1) # Prepare for plotting in visdom graph_name = cnn.name visdom = None if (not args.visdom) else {'env': args.experiment, 'graph': graph_name} #-------------------------------------------------------------------------------------------------# #---------------------# #----- CALLBACKS -----# #---------------------# # Determine after how many iterations to evaluate the model eval_log = args.prec_log if (args.prec_log is not None) else len(train_loader) # Define callback-functions to evaluate during training # -loss loss_cbs = [cb._solver_loss_cb(log=args.loss_log, visdom=visdom, epochs=epochs)] # -precision eval_cb = cb._eval_cb(log=eval_log, test_datasets=[testset], visdom=visdom, precision_dict=precision_dict) # -visualize extracted representation latent_space_cb = cb._latent_space_cb(log=min(5*eval_log, iters), datasets=[testset], visdom=visdom, sample_size=400) #-------------------------------------------------------------------------------------------------# #--------------------------# #----- (PRE-)TRAINING -----# #--------------------------# # (Pre)train model print("\nTraining...") train.train(cnn, train_loader, iters, loss_cbs=loss_cbs, eval_cbs=[eval_cb, latent_space_cb], save_every=1000 if args.save else None, m_dir=args.m_dir, args=args) # Save (pre)trained model if args.save: # -conv-layers save_name = cnn.convE.name if ( not hasattr(args, 'convE_stag') or args.convE_stag=="none" ) else "{}-{}".format(cnn.convE.name, args.convE_stag) utils.save_checkpoint(cnn.convE, args.m_dir, name=save_name) # -full model save_name = cnn.name if ( not hasattr(args, 'full_stag') or args.full_stag=="none" ) else "{}-{}".format(cnn.name, args.full_stag) utils.save_checkpoint(cnn, args.m_dir, name=save_name) #-------------------------------------------------------------------------------------------------# #--------------------# #----- PLOTTING -----# #--------------------# # If requested, generate pdf if args.pdf: # -open pdf plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp) pp = plt.open_pdf(plot_name) # -Fig1: show some images images, _ = next(iter(train_loader)) #--> get a mini-batch of random training images plt.plot_images_from_tensor(images, pp, title="example input images", config=config) # -Fig2: precision figure = plt.plot_lines(precision_dict["all_tasks"], x_axes=precision_dict["x_iteration"], line_names=['ave precision'], xlabel="Iterations", ylabel="Test accuracy") pp.savefig(figure) # -close pdf pp.close() # -print name of generated plot on screen print("\nGenerated plot: {}\n".format(plot_name))
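# NOTE: as the comment in the function above points out, setting requires_grad=False alone does not fully
# freeze conv-layers, because BatchNorm running statistics are still updated during forward passes in training
# mode. A minimal sketch (not part of this code-base) that freezes both; [encoder] is an illustrative name for
# a module such as cnn.convE:
def freeze_encoder(encoder):
    for param in encoder.parameters():
        param.requires_grad = False  # stop gradient-based updates
    encoder.eval()                   # stop BatchNorm running-mean/-variance updates as well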
def get_param_stamp(args, model_name, verbose=True, replay=False):
    '''Based on the input-arguments, produce a "parameter-stamp".'''

    # -for task
    multi_n_stamp = "{n}{of}".format(
        n=args.tasks, of="OL" if checkattr(args, 'only_last') else ""
    ) if hasattr(args, "tasks") else ""
    task_stamp = "{exp}{norm}{aug}{multi_n}{max_n}".format(
        exp=args.experiment,
        norm="-N" if hasattr(args, 'normalize') and args.normalize else "",
        aug="+" if hasattr(args, "augment") and args.augment else "",
        multi_n=multi_n_stamp,
        max_n="" if (not args.experiment == "CIFAR100") or args.max_samples is None else "-max{}".format(
            args.max_samples))
    if verbose:
        print(" --> task: " + task_stamp)

    # -for model
    model_stamp = model_name
    if verbose:
        print(" --> model: " + model_stamp)

    # -for hyper-parameters
    pre_conv = ""
    if checkattr(args, "pre_convE") and (hasattr(args, 'depth') and args.depth > 0):
        ltag = "" if not hasattr(args, "convE_ltag") or args.convE_ltag == "none" else "-{}".format(args.convE_ltag)
        pre_conv = "-pCvE{}".format(ltag)
    freeze_conv = "-fCvE" if (checkattr(args, "freeze_convE") and hasattr(args, 'depth') and args.depth > 0) else ""
    hyper_stamp = "{i_e}{num}-lr{lr}-b{bsz}{pretr}{freeze}{reinit}".format(
        i_e="e" if args.iters is None else "i",
        num=args.epochs if args.iters is None else args.iters,
        lr=args.lr, bsz=args.batch, pretr=pre_conv, freeze=freeze_conv,
        reinit="-R" if checkattr(args, 'reinit') else "")
    if verbose:
        print(" --> hyper-params: " + hyper_stamp)

    # -for EWC / SI
    if (checkattr(args, 'ewc') and args.ewc_lambda > 0) or (checkattr(args, 'si') and args.si_c > 0):
        ewc_stamp = "EWC{l}-{fi}{o}".format(
            l=args.ewc_lambda,
            fi="{}".format("N" if args.fisher_n is None else args.fisher_n),
            o="-O{}".format(args.gamma) if checkattr(args, 'online') else "",
        ) if (checkattr(args, 'ewc') and args.ewc_lambda > 0) else ""
        si_stamp = "SI{c}-{eps}".format(c=args.si_c, eps=args.epsilon) if (
            checkattr(args, 'si') and args.si_c > 0) else ""
        both = "--" if (checkattr(args, 'ewc') and args.ewc_lambda > 0) and (
            checkattr(args, 'si') and args.si_c > 0) else ""
        if verbose and checkattr(args, 'ewc') and args.ewc_lambda > 0:
            print(" --> EWC: " + ewc_stamp)
        if verbose and checkattr(args, 'si') and args.si_c > 0:
            print(" --> SI: " + si_stamp)
    ewc_stamp = "--{}{}{}".format(ewc_stamp, both, si_stamp) if (
        (checkattr(args, 'ewc') and args.ewc_lambda > 0) or (checkattr(args, 'si') and args.si_c > 0)
    ) else ""

    # -for XdG
    xdg_stamp = ""
    if checkattr(args, "xdg") and args.xdg_prop > 0:
        xdg_stamp = "--XdG{}".format(args.xdg_prop)
        if verbose:
            print(" --> XdG: " + "gating = {}".format(args.xdg_prop))

    # -for replay
    if replay:
        replay_stamp = "{rep}{bat}{agem}{distil}".format(
            rep=args.replay,
            bat="" if ((not hasattr(args, 'batch_replay')) or (args.batch_replay is None)
                       or args.batch_replay == args.batch) else "-br{}".format(args.batch_replay),
            agem="-aGEM" if args.agem else "",
            distil="-Di{}".format(args.temp) if args.distill else "",
        )
        if verbose:
            print(" --> replay: " + replay_stamp)
    replay_stamp = "--{}".format(replay_stamp) if replay else ""

    # -for exemplars
    exemplar_stamp = ""
    if checkattr(args, 'use_exemplars') or (hasattr(args, 'replay') and args.replay == "exemplars"):
        exemplar_opts = "b{}{}".format(args.budget, "H" if args.herding else "")
        use = "useEx-" if args.use_exemplars else ""
        exemplar_stamp = "--{}{}".format(use, exemplar_opts)
        if verbose:
            print(" --> exemplars: " + "{}{}".format(use, exemplar_opts))

    # --> combine
    param_stamp = "{}--{}--{}{}{}{}{}{}".format(
        task_stamp, model_stamp, hyper_stamp, ewc_stamp, xdg_stamp, replay_stamp, exemplar_stamp,
        "-s{}".format(args.seed) if not args.seed == 0 else "",
    )
    # -also construct the stamp of the corresponding "reinit"-experiment (i.e., with each task trained from scratch)
    reinit_param_stamp = "{}--{}--{}{}{}".format(
        task_stamp, model_stamp, hyper_stamp,
        "-R" if not checkattr(args, 'reinit') else "",
        "-s{}".format(args.seed) if not args.seed == 0 else "",
    )

    # -print param-stamp on screen and return
    if verbose:
        print(param_stamp)
    return param_stamp, reinit_param_stamp
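##-------------------------------------------------------------------------------------------------------------------##

# Minimal usage sketch for [get_param_stamp]. The stamp is a deterministic string built from the run's arguments, so
# identical settings always map to identical file-names (e.g., "prec-{stamp}.txt", "dict-{stamp}", "mM-{stamp}").
# The attribute-values below are illustrative only; the full argument-list is defined by the argparse-options
# elsewhere in the repository.

from argparse import Namespace

example_args = Namespace(
    experiment="splitMNIST", tasks=5, max_samples=None,  # -task
    iters=2000, epochs=None, lr=0.001, batch=128,        # -hyper-parameters
    ewc=False, si=False, seed=0,                         # -no EWC / SI, default seed
)
# (with verbose=True, each component of the stamp is also printed to screen)
# param_stamp, reinit_param_stamp = get_param_stamp(example_args, model_name="MLP", verbose=True)

##-------------------------------------------------------------------------------------------------------------------##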
def run(args, verbose=False):

    # Create plots- and results-directories, if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    # If only the param-stamp is wanted, print it to screen and exit
    if utils.checkattr(args, "get_stamp"):
        print(get_param_stamp_from_args(args=args))
        exit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)


    #-------------------------------------------------------------------------------------------------#

    #-----------------------#
    #----- DATA-STREAM -----#
    #-----------------------#

    # Find number of classes per task
    if args.experiment == "splitMNIST" and args.tasks > 10:
        raise ValueError("Experiment 'splitMNIST' cannot have more than 10 tasks!")
    classes_per_task = 10 if args.experiment == "permMNIST" else int(np.floor(10 / args.tasks))

    # Print information on data-stream to screen
    if verbose:
        print("\nPreparing the data-stream...")
        print(" --> {}-incremental learning".format(args.scenario))
        ti = "{} classes".format(args.tasks * classes_per_task) if (
            args.stream == "random" and args.scenario == "class"
        ) else "{} tasks, with {} classes each".format(args.tasks, classes_per_task)
        print(" --> {} data stream: {}\n".format(args.stream, ti))

    # Set up the stream of labels (i.e., classes, domains or tasks) to use
    if args.stream == "task-based":
        labels_per_batch = True if ((not args.scenario == "class") or classes_per_task == 1) else False
        # -in the Task- & Domain-IL scenarios, each label is always for the entire batch
        # -in the Class-IL scenario, each label is always just for a single observation
        #  (but if there is just 1 class per task, setting [labels_per_batch] to ``True`` is more efficient)
        label_stream = TaskBasedStream(
            n_tasks=args.tasks,
            iters_per_task=args.iters if labels_per_batch else args.iters * args.batch,
            labels_per_task=classes_per_task if args.scenario == "class" else 1,
        )
    elif args.stream == "random":
        label_stream = RandomStream(labels=args.tasks * classes_per_task if args.scenario == "class" else args.tasks)
    else:
        raise NotImplementedError("Stream type '{}' not currently implemented.".format(args.stream))

    # Load the data-sets
    (train_datasets, test_datasets), config, labels_per_task = prepare_datasets(
        name=args.experiment, n_labels=label_stream.n_labels, classes=(args.scenario == "class"),
        classes_per_task=classes_per_task, dir=args.d_dir, exception=(args.seed < 10),
    )

    # Set up the data-stream to be presented to the network
    data_stream = DataStream(
        train_datasets, label_stream, batch=args.batch, return_task=(args.scenario == "task"),
        per_batch=labels_per_batch if (args.stream == "task-based") else args.labels_per_batch,
    )


    #-------------------------------------------------------------------------------------------------#

    #-----------------#
    #----- MODEL -----#
    #-----------------#

    # Define model
    # -how many units should be in the softmax output layer? (e.g., multi-headed or not?)
    softmax_classes = label_stream.n_labels if args.scenario == "class" else (
        classes_per_task if (args.scenario == "domain" or args.singlehead) else
        classes_per_task * label_stream.n_labels
    )
    # -set up model and move to correct device
    model = Classifier(
        image_size=config['size'], image_channels=config['channels'], classes=softmax_classes,
        fc_layers=args.fc_lay, fc_units=args.fc_units,
    ).to(device)
    # -if using a multi-headed output layer, set the "labels-per-task"-list as attribute of the model
    model.multi_head = labels_per_task if (args.scenario == "task" and not args.singlehead) else None


    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- OPTIMIZER -----#
    #---------------------#

    # Define optimizer (only include parameters that require grad)
    optim_list = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': args.lr}]
    if not args.cs:
        # Use the chosen 'standard' optimizer
        if args.optimizer == "sgd":
            model.optimizer = optim.SGD(optim_list, weight_decay=args.decay)
        elif args.optimizer == "adam":
            model.optimizer = optim.Adam(optim_list, betas=(0.9, 0.999), weight_decay=args.decay)
    else:
        # Use the "complex synapse"-version of the chosen optimizer
        if args.optimizer == "sgd":
            model.optimizer = cs.ComplexSynapse(optim_list, n_beakers=args.beakers, alpha=args.alpha,
                                                beta=args.beta, verbose=verbose)
        elif args.optimizer == "adam":
            model.optimizer = cs.AdamComplexSynapse(optim_list, betas=(0.9, 0.999), n_beakers=args.beakers,
                                                    alpha=args.alpha, beta=args.beta, verbose=verbose)


    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp (and print on screen)
    if verbose:
        print("\nParameter-stamp...")
    param_stamp = get_param_stamp(args, model.name, verbose=verbose)

    # Print some model-characteristics on the screen
    if verbose:
        utils.print_model_info(model, title="MAIN MODEL")

    # Prepare for keeping track of performance during training, for storing & for later plotting in pdf
    metrics_dict = evaluate.initiate_metrics_dict(n_labels=label_stream.n_labels,
                                                  classes=(args.scenario == "class"))

    # Prepare for plotting in visdom
    if args.visdom:
        env_name = "{exp}-{scenario}".format(exp=args.experiment, scenario=args.scenario)
        graph_name = "CS" if args.cs else "Normal"
        visdom = {'env': env_name, 'graph': graph_name}
    else:
        visdom = None


    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    # Callbacks for reporting on and visualizing loss
    loss_cbs = [
        cb.def_loss_cb(log=args.loss_log, visdom=visdom, tasks=label_stream.n_tasks,
                       iters_per_task=args.iters if args.stream == "task-based" else None,
                       task_name="Episode" if args.scenario == "class" else (
                           "Task" if args.scenario == "task" else "Domain"))
    ]

    # Callbacks for reporting and visualizing accuracy
    eval_cbs = [
        cb.def_eval_cb(log=args.eval_log, test_datasets=test_datasets, scenario=args.scenario,
                       iters_per_task=args.iters if args.stream == "task-based" else None,
                       classes_per_task=classes_per_task, metrics_dict=metrics_dict, test_size=args.eval_n,
                       visdom=visdom, provide_task_info=(args.scenario == "task"))
    ]
    # -evaluate accuracy before any training
    for eval_cb in eval_cbs:
        if eval_cb is not None:
            eval_cb(model, 0)


    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- TRAINING -----#
    #--------------------#

    # Keep track of training-time
    if args.time:
        start = time.time()

    # Train model
    if verbose:
        print("\nTraining...")
    train_stream(model, data_stream, iters=args.iters * args.tasks if args.stream == "task-based" else args.iters,
                 eval_cbs=eval_cbs, loss_cbs=loss_cbs)

    # Get total training-time in seconds, and write to file and screen
    if args.time:
        training_time = time.time() - start
        time_file = open("{}/time-{}.txt".format(args.r_dir, param_stamp), 'w')
        time_file.write('{}\n'.format(training_time))
        time_file.close()
        if verbose:
            print("=> Total training time = {:.1f} seconds\n".format(training_time))


    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- EVALUATION -----#
    #----------------------#

    if verbose:
        print("\n\nEVALUATION RESULTS:")

    # Evaluate precision of final model on full test-set
    precs = [evaluate.validate(
        model, test_datasets[i], verbose=False, test_size=None,
        task=i + 1 if args.scenario == "task" else None,
    ) for i in range(len(test_datasets))]
    average_precs = sum(precs) / len(test_datasets)
    # -print to screen
    if verbose:
        print("\n Precision on test-set:")
        for i in range(len(test_datasets)):
            print(" - {} {}: {:.4f}".format(args.scenario, i + 1, precs[i]))
        print('=> Average precision over all {} {}{}s: {:.4f}\n'.format(
            len(test_datasets), args.scenario, "e" if args.scenario == "class" else "", average_precs))


    #-------------------------------------------------------------------------------------------------#

    #------------------#
    #----- OUTPUT -----#
    #------------------#

    # Average precision on full test set
    output_file = open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w')
    output_file.write('{}\n'.format(average_precs))
    output_file.close()
    # -metrics-dict
    file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
    utils.save_object(metrics_dict, file_name)


    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # If requested, generate pdf
    if args.pdf:
        # -open pdf
        plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp)
        pp = visual_plt.open_pdf(plot_name)
        # -show metrics reflecting progression during training
        figure_list = []  #-> create list to store all figures to be plotted
        # -generate all figures (and store them in [figure_list])
        key = "class" if args.scenario == 'class' else "task"
        plot_list = []
        for i in range(label_stream.n_labels):
            plot_list.append(metrics_dict["acc_per_{}".format(key)]["{}_{}".format(key, i + 1)])
        figure = visual_plt.plot_lines(
            plot_list, x_axes=metrics_dict["iters"], xlabel="Iterations", ylabel="Accuracy",
            line_names=['{} {}'.format(args.scenario, i + 1) for i in range(label_stream.n_labels)]
        )
        figure_list.append(figure)
        figure = visual_plt.plot_lines(
            [metrics_dict["ave_acc"]], x_axes=metrics_dict["iters"], xlabel="Iterations", ylabel="Accuracy",
            line_names=['average (over all {}{}s)'.format(args.scenario, "e" if args.scenario == "class" else "")],
            ylim=(0, 1)
        )
        figure_list.append(figure)
        figure = visual_plt.plot_lines(
            [metrics_dict["ave_acc_so_far"]], x_axes=metrics_dict["iters"], xlabel="Iterations", ylabel="Accuracy",
            line_names=['average (over all {}{}s so far)'.format(args.scenario,
                                                                 "e" if args.scenario == "class" else "")],
            ylim=(0, 1)
        )
        figure_list.append(figure)
        # -add figures to pdf (and close this pdf)
        for figure in figure_list:
            pp.savefig(figure)
        # -close pdf
        pp.close()
        # -print name of generated plot on screen
        if verbose:
            print("\nGenerated plot: {}\n".format(plot_name))
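##-------------------------------------------------------------------------------------------------------------------##

# A sketch of how this [run]-function might be driven from the command line. Only options actually read by [run]
# above are listed, and the defaults shown are illustrative assumptions; the real parser (with the full option-set,
# help-texts and the "complex synapse"-options) is defined elsewhere in the repository.

import argparse

def handle_inputs():
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment', type=str, default='splitMNIST')
    parser.add_argument('--scenario', type=str, default='task', choices=['task', 'domain', 'class'])
    parser.add_argument('--stream', type=str, default='task-based', choices=['task-based', 'random'])
    parser.add_argument('--tasks', type=int, default=5)
    parser.add_argument('--iters', type=int, default=2000)
    parser.add_argument('--batch', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'])
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--time', action='store_true')
    parser.add_argument('--visdom', action='store_true')
    parser.add_argument('--pdf', action='store_true')
    return parser.parse_args()

# if __name__ == '__main__':
#     args = handle_inputs()
#     run(args, verbose=True)

##-------------------------------------------------------------------------------------------------------------------##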
def run(args, verbose=False):

    # Create plots- and results-directories if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    # If only the param-stamp is wanted, get it and exit
    if args.get_stamp:
        from param_stamp import get_param_stamp_from_args
        print(get_param_stamp_from_args(args=args))
        exit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")
    # -report whether cuda is used
    if verbose:
        print("CUDA is {}used".format("" if cuda else "NOT(!!) "))

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)


    #-------------------------------------------------------------------------------------------------#

    #----------------#
    #----- DATA -----#
    #----------------#

    # Prepare data for chosen experiment
    if verbose:
        print("\nPreparing the data...")
    (train_datasets, test_datasets), config, classes_per_task = get_multitask_experiment(
        name=args.experiment, tasks=args.tasks, data_dir=args.d_dir,
        normalize=True if utils.checkattr(args, "normalize") else False,
        augment=True if utils.checkattr(args, "augment") else False,
        verbose=verbose, exception=True if args.seed < 10 else False,
        only_test=(not args.train), max_samples=args.max_samples)


    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- MAIN MODEL -----#
    #----------------------#

    # Define main model (i.e., classifier, if requested with feedback connections)
    if verbose and utils.checkattr(args, "pre_convE") and (hasattr(args, "depth") and args.depth > 0):
        print("\nDefining the model...")
    model = define.define_classifier(args=args, config=config, device=device)

    # Initialize / use pre-trained / freeze model-parameters
    # - initialize (pre-trained) parameters
    model = define.init_params(model, args)
    # - freeze weights of conv-layers?
if utils.checkattr(args, "freeze_convE"): for param in model.convE.parameters(): param.requires_grad = False # Define optimizer (only optimize parameters that "requires_grad") model.optim_list = [ { 'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': args.lr }, ] model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999)) #-------------------------------------------------------------------------------------------------# #----------------------------------# #----- CL-STRATEGY: EXEMPLARS -----# #----------------------------------# # Store in model whether, how many and in what way to store exemplars if isinstance(model, ExemplarHandler) and (args.use_exemplars or args.replay == "exemplars"): model.memory_budget = args.budget model.herding = args.herding model.norm_exemplars = args.herding #-------------------------------------------------------------------------------------------------# #----------------------------------------------------# #----- CL-STRATEGY: REGULARIZATION / ALLOCATION -----# #----------------------------------------------------# # Elastic Weight Consolidation (EWC) if isinstance(model, ContinualLearner) and utils.checkattr(args, 'ewc'): model.ewc_lambda = args.ewc_lambda if args.ewc else 0 model.fisher_n = args.fisher_n model.online = utils.checkattr(args, 'online') if model.online: model.gamma = args.gamma # Synpatic Intelligence (SI) if isinstance(model, ContinualLearner) and utils.checkattr(args, 'si'): model.si_c = args.si_c if args.si else 0 model.epsilon = args.epsilon # XdG: create for every task a "mask" for each hidden fully connected layer if isinstance(model, ContinualLearner) and utils.checkattr( args, 'xdg') and args.xdg_prop > 0: model.define_XdGmask(gating_prop=args.xdg_prop, n_tasks=args.tasks) #-------------------------------------------------------------------------------------------------# #-------------------------------# #----- CL-STRATEGY: REPLAY -----# #-------------------------------# # Use distillation loss (i.e., soft targets) for replayed data? 
(and set temperature) if isinstance(model, ContinualLearner) and hasattr( args, 'replay') and not args.replay == "none": model.replay_targets = "soft" if args.distill else "hard" model.KD_temp = args.temp #-------------------------------------------------------------------------------------------------# #---------------------# #----- REPORTING -----# #---------------------# # Get parameter-stamp (and print on screen) if verbose: print("\nParameter-stamp...") param_stamp, reinit_param_stamp = get_param_stamp( args, model.name, verbose=verbose, replay=True if (hasattr(args, 'replay') and not args.replay == "none") else False, ) # Print some model-characteristics on the screen if verbose: # -main model utils.print_model_info(model, title="MAIN MODEL") # Prepare for keeping track of statistics required for metrics (also used for plotting in pdf) if args.pdf or args.metrics: # -define [metrics_dict] to keep track of performance during training for storing & for later plotting in pdf metrics_dict = evaluate.initiate_metrics_dict(n_tasks=args.tasks) # -evaluate randomly initiated model on all tasks & store accuracies in [metrics_dict] (for calculating metrics) if not args.use_exemplars: metrics_dict = evaluate.intial_accuracy( model, test_datasets, metrics_dict, no_task_mask=False, classes_per_task=classes_per_task, test_size=None) else: metrics_dict = None # Prepare for plotting in visdom visdom = None if args.visdom: env_name = "{exp}-{tasks}".format(exp=args.experiment, tasks=args.tasks) replay_statement = "{mode}{b}".format( mode=args.replay, b="" if (args.batch_replay is None or args.batch_replay == args.batch) else "-br{}".format(args.batch_replay), ) if (hasattr(args, "replay") and not args.replay == "none") else "NR" graph_name = "{replay}{syn}{ewc}{xdg}".format( replay=replay_statement, syn="-si{}".format(args.si_c) if utils.checkattr(args, 'si') else "", ewc="-ewc{}{}".format( args.ewc_lambda, "-O{}".format(args.gamma) if utils.checkattr(args, "online") else "") if utils.checkattr( args, 'ewc') else "", xdg="" if (not utils.checkattr(args, 'xdg')) or args.xdg_prop == 0 else "-XdG{}".format(args.xdg_prop), ) visdom = {'env': env_name, 'graph': graph_name} #-------------------------------------------------------------------------------------------------# #---------------------# #----- CALLBACKS -----# #---------------------# # Callbacks for reporting on and visualizing loss solver_loss_cbs = [ cb._solver_loss_cb(log=args.loss_log, visdom=visdom, model=model, iters_per_task=args.iters, tasks=args.tasks, replay=(hasattr(args, "replay") and not args.replay == "none")) ] # Callbacks for reporting and visualizing accuracy # -visdom (i.e., after each [prec_log] eval_cbs = [ cb._eval_cb(log=args.prec_log, test_datasets=test_datasets, visdom=visdom, iters_per_task=args.iters, test_size=args.prec_n, classes_per_task=classes_per_task, with_exemplars=False) ] if (not args.use_exemplars) else [None] #--> during training on a task, evaluation cannot be with exemplars as those are only selected after training # (instead, evaluation for visdom is only done after each task, by including callback-function into [metric_cbs]) # Callbacks for calculating statists required for metrics # -pdf / reporting: summary plots (i.e, only after each task) (when using exemplars, also for visdom) metric_cbs = [ cb._metric_cb(log=args.iters, test_datasets=test_datasets, classes_per_task=classes_per_task, metrics_dict=metrics_dict, iters_per_task=args.iters, with_exemplars=args.use_exemplars), cb._eval_cb(log=args.iters, 
test_datasets=test_datasets, visdom=visdom, iters_per_task=args.iters, test_size=args.prec_n, classes_per_task=classes_per_task, with_exemplars=True) if args.use_exemplars else None ] #-------------------------------------------------------------------------------------------------# #--------------------# #----- TRAINING -----# #--------------------# if args.train: if verbose: print("\nTraining...") # Train model train_cl( model, train_datasets, replay_mode=args.replay if hasattr(args, 'replay') else "none", classes_per_task=classes_per_task, iters=args.iters, args=args, batch_size=args.batch, batch_size_replay=args.batch_replay if hasattr( args, 'batch_replay') else None, eval_cbs=eval_cbs, loss_cbs=solver_loss_cbs, reinit=utils.checkattr(args, 'reinit'), only_last=utils.checkattr(args, 'only_last'), metric_cbs=metric_cbs, use_exemplars=args.use_exemplars, ) # Save trained model(s), if requested if args.save: save_name = "mM-{}".format(param_stamp) if ( not hasattr(args, 'full_stag') or args.full_stag == "none") else "{}-{}".format( model.name, args.full_stag) utils.save_checkpoint(model, args.m_dir, name=save_name, verbose=verbose) else: # Load previously trained model(s) (if goal is to only evaluate previously trained model) if verbose: print("\nLoading parameters of the previously trained models...") load_name = "mM-{}".format(param_stamp) if ( not hasattr(args, 'full_ltag') or args.full_ltag == "none") else "{}-{}".format( model.name, args.full_ltag) utils.load_checkpoint( model, args.m_dir, name=load_name, verbose=verbose, add_si_buffers=(isinstance(model, ContinualLearner) and utils.checkattr(args, 'si'))) # Load previously created metrics-dict file_name = "{}/dict-{}".format(args.r_dir, param_stamp) metrics_dict = utils.load_object(file_name) #-------------------------------------------------------------------------------------------------# #-----------------------------------# #----- EVALUATION of CLASSIFIER-----# #-----------------------------------# if verbose: print("\n\nEVALUATION RESULTS:") # Evaluate precision of final model on full test-set precs = [ evaluate.validate(model, test_datasets[i], verbose=False, test_size=None, task=i + 1, with_exemplars=False, allowed_classes=list( range(classes_per_task * i, classes_per_task * (i + 1)))) for i in range(args.tasks) ] average_precs = sum(precs) / args.tasks # -print on screen if verbose: print("\n Precision on test-set{}:".format( " (softmax classification)" if args.use_exemplars else "")) for i in range(args.tasks): print(" - Task {}: {:.4f}".format(i + 1, precs[i])) print('=> Average precision over all {} tasks: {:.4f}\n'.format( args.tasks, average_precs)) # -with exemplars if args.use_exemplars: precs = [ evaluate.validate(model, test_datasets[i], verbose=False, test_size=None, task=i + 1, with_exemplars=True, allowed_classes=list( range(classes_per_task * i, classes_per_task * (i + 1)))) for i in range(args.tasks) ] average_precs_ex = sum(precs) / args.tasks # -print on screen if verbose: print(" Precision on test-set (classification using exemplars):") for i in range(args.tasks): print(" - Task {}: {:.4f}".format(i + 1, precs[i])) print('=> Average precision over all {} tasks: {:.4f}\n'.format( args.tasks, average_precs_ex)) # If requested, compute metrics if args.metrics: # Load accuracy matrix of "reinit"-experiment (i.e., each task's accuracy when only trained on that task) if not utils.checkattr(args, 'reinit'): file_name = "{}/dict-{}".format(args.r_dir, reinit_param_stamp) if not 
os.path.isfile("{}.pkl".format(file_name)): raise FileNotFoundError( "Need to run the correct 'reinit' experiment (with --metrics) first!!" ) reinit_metrics_dict = utils.load_object(file_name) # Accuracy matrix R = pd.DataFrame( data=metrics_dict['acc per task'], index=['after task {}'.format(i + 1) for i in range(args.tasks)]) R = R[["task {}".format(task_id + 1) for task_id in range(args.tasks)]] R.loc['at start'] = metrics_dict['initial acc per task'] if ( not args.use_exemplars) else ['NA' for _ in range(args.tasks)] if not utils.checkattr(args, 'reinit'): R.loc['only trained on itself'] = [ reinit_metrics_dict['acc per task']['task {}'.format( task_id + 1)][task_id] for task_id in range(args.tasks) ] R = R.reindex( ['at start'] + ['after task {}'.format(i + 1) for i in range(args.tasks)] + ['only trained on itself']) BWTs = [(R.loc['after task {}'.format(args.tasks), 'task {}'.format(i + 1)] - \ R.loc['after task {}'.format(i + 1), 'task {}'.format(i + 1)]) for i in range(args.tasks - 1)] FWTs = [ 0. if args.use_exemplars else (R.loc['after task {}'.format(i + 1), 'task {}'.format(i + 2)] - R.loc['at start', 'task {}'.format(i + 2)]) for i in range(args.tasks - 1) ] forgetting = [] for i in range(args.tasks - 1): forgetting.append( max(R.iloc[1:args.tasks, i]) - R.iloc[args.tasks, i]) R.loc['FWT (per task)'] = ['NA'] + FWTs R.loc['BWT (per task)'] = BWTs + ['NA'] R.loc['F (per task)'] = forgetting + ['NA'] BWT = sum(BWTs) / (args.tasks - 1) F = sum(forgetting) / (args.tasks - 1) FWT = sum(FWTs) / (args.tasks - 1) metrics_dict['BWT'] = BWT metrics_dict['F'] = F metrics_dict['FWT'] = FWT # -Vogelstein et al's measures of transfer efficiency if not utils.checkattr(args, 'reinit'): TEs = [((1 - R.loc['only trained on itself', 'task {}'.format(task_id + 1)]) / (1 - R.loc['after task {}'.format(args.tasks), 'task {}'.format(task_id + 1)])) for task_id in range(args.tasks)] BTEs = [((1 - R.loc['after task {}'.format(task_id + 1), 'task {}'.format(task_id + 1)]) / (1 - R.loc['after task {}'.format(args.tasks), 'task {}'.format(task_id + 1)])) for task_id in range(args.tasks)] FTEs = [((1 - R.loc['only trained on itself', 'task {}'.format(task_id + 1)]) / (1 - R.loc['after task {}'.format(task_id + 1), 'task {}'.format(task_id + 1)])) for task_id in range(args.tasks)] # -TEs and BTEs after each task TEs_all = [] BTEs_all = [] for after_task_id in range(args.tasks): TEs_all.append([ ((1 - R.loc['only trained on itself', 'task {}'.format(task_id + 1)]) / (1 - R.loc['after task {}'.format(after_task_id + 1), 'task {}'.format(task_id + 1)])) for task_id in range(after_task_id + 1) ]) BTEs_all.append([ ((1 - R.loc['after task {}'.format(task_id + 1), 'task {}'.format(task_id + 1)]) / (1 - R.loc['after task {}'.format(after_task_id + 1), 'task {}'.format(task_id + 1)])) for task_id in range(after_task_id + 1) ]) R.loc['TEs (per task, after all 10 tasks)'] = TEs for after_task_id in range(args.tasks): R.loc['TEs (per task, after {} tasks)'.format( after_task_id + 1)] = TEs_all[after_task_id] + ['NA'] * (args.tasks - after_task_id - 1) R.loc['BTEs (per task, after all 10 tasks)'] = BTEs for after_task_id in range(args.tasks): R.loc['BTEs (per task, after {} tasks)'.format( after_task_id + 1)] = BTEs_all[after_task_id] + ['NA'] * ( args.tasks - after_task_id - 1) R.loc['FTEs (per task)'] = FTEs metrics_dict['R'] = R # -print on screen if verbose: print("Accuracy matrix") print(R) print("\nFWT = {:.4f}".format(FWT)) print("BWT = {:.4f}".format(BWT)) print(" F = {:.4f}\n\n".format(F)) 
    #-------------------------------------------------------------------------------------------------#

    #------------------#
    #----- OUTPUT -----#
    #------------------#

    # Average precision on full test set
    output_file = open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w')
    output_file.write('{}\n'.format(average_precs_ex if args.use_exemplars else average_precs))
    output_file.close()
    # -metrics-dict
    if args.metrics:
        file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
        utils.save_object(metrics_dict, file_name)


    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # If requested, generate pdf
    if args.pdf:
        # -open pdf
        plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp)
        pp = visual_plt.open_pdf(plot_name)
        # -plot TEs
        if not utils.checkattr(args, 'reinit'):
            BTEs = []
            for task_id in range(args.tasks):
                BTEs.append([
                    R.loc['BTEs (per task, after {} tasks)'.format(after_task_id + 1),
                          'task {}'.format(task_id + 1)]
                    for after_task_id in range(task_id, args.tasks)
                ])
            figure = visual_plt.plot_TEs([FTEs], [BTEs], [TEs], ["test"])
            pp.savefig(figure)
        # -show metrics reflecting progression during training
        if args.train and (not utils.checkattr(args, 'only_last')):
            # -create list to store all figures to be plotted
            figure_list = []
            # -generate all figures (and store them in [figure_list])
            key = "acc per task"
            plot_list = []
            for i in range(args.tasks):
                plot_list.append(metrics_dict[key]["task {}".format(i + 1)])
            figure = visual_plt.plot_lines(
                plot_list, x_axes=metrics_dict["x_task"],
                line_names=['task {}'.format(i + 1) for i in range(args.tasks)]
            )
            figure_list.append(figure)
            figure = visual_plt.plot_lines(
                [metrics_dict["average"]], x_axes=metrics_dict["x_task"],
                line_names=['average all tasks so far']
            )
            figure_list.append(figure)
            # -add figures to pdf
            for figure in figure_list:
                pp.savefig(figure)
        # -close pdf
        pp.close()
        # -print name of generated plot on screen
        if verbose:
            print("\nGenerated plot: {}\n".format(plot_name))
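##-------------------------------------------------------------------------------------------------------------------##

# Worked example of the continual-learning metrics computed above, on a hand-made 3-task accuracy matrix. Entry
# R[r][c] holds the accuracy on task c+1 after training on task r (row 0 being the accuracy "at start", before any
# training). All numbers are invented for illustration; the formulas are exactly the ones used in [run] above.

# -accuracy matrix for 3 tasks: rows = ['at start', 'after task 1', 'after task 2', 'after task 3']
R = [
    [0.10, 0.10, 0.10],  # at start (~chance level)
    [0.95, 0.12, 0.11],  # after task 1
    [0.80, 0.94, 0.10],  # after task 2
    [0.70, 0.85, 0.93],  # after task 3
]
n_tasks = 3

# -BWT: change in accuracy on task i+1 between 'just after learning it' and 'after learning all tasks'
BWTs = [R[n_tasks][i] - R[i + 1][i] for i in range(n_tasks - 1)]    # [-0.25, -0.09]
# -FWT: how much learning tasks up to i+1 helped on the not-yet-seen task i+2, relative to the untrained model
FWTs = [R[i + 1][i + 1] - R[0][i + 1] for i in range(n_tasks - 1)]  # [0.02, 0.00]
# -F: per task, the drop from its best accuracy earlier in training to its final accuracy
forgetting = [max(R[r][i] for r in range(1, n_tasks)) - R[n_tasks][i] for i in range(n_tasks - 1)]  # [0.25, 0.09]

print("BWT = {:.4f}".format(sum(BWTs) / (n_tasks - 1)))        # BWT = -0.1700
print("FWT = {:.4f}".format(sum(FWTs) / (n_tasks - 1)))        # FWT = 0.0100
print(" F = {:.4f}".format(sum(forgetting) / (n_tasks - 1)))   #  F = 0.1700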