def run(config, train_system=True, only_train_replay=False, train_tandem=True):
    """Start training the MNIST replay model.

    Depending on the configuration, this controls the creation and training
    of the different replay modules together with their corresponding
    hypernetworks.

    Args:
        config: The command line arguments.
        train_system: (optional) Set to ``False`` if this function should
            only create the config, networks and data handlers for future
            training. See :func:`mnist.train_splitMNIST.run` for a use case.
        only_train_replay: (optional) Whether this script is used to train
            only a replay model. Normally, this script is used in tandem with
            an additional classifier that uses this replay model to replay
            data of old tasks.
        train_tandem: (optional) Whether this script is used to train in
            tandem, i.e. in an alternating fashion with a classifier. If
            ``False``, the replay model is trained and tested right here.

    Returns:
        (tuple): ``(dec, d_hnet, enc, dhandlers, device, writer, config)``.
        ``enc``, ``dec`` and ``d_hnet`` are ``None`` when ``train_system`` is
        ``False``; ``dhandlers`` is ``None`` when ``train_system`` and
        ``config.upper_bound`` are both ``False``. The returned ``config``
        carries the new attributes ``test_z``, ``priors`` and ``vae_conds``.
        See the docstring of function :func:`train` for further details.
    """
    # Training a classifier on single classes requires a single-class replay
    # method. The converse does not hold: a single-class replay model can
    # still feed a classifier that is trained on replayed data (built from
    # multiple replayed conditions) together with the current data.
    # Single-class replay is only implemented for splitMNIST.
    if config.single_class_replay:
        assert (config.experiment == "splitMNIST")

    if config.num_tasks > 100 and config.cl_scenario != 1:
        print("Attention: Replay model not tested for num tasks > 100")

    ### Set up the environment (must precede any data/network creation).
    device, writer = train_utils._setup_environment(config)

    # Value handed to the task generator; presumably the number of classes
    # per task (1 for single-class replay) -- TODO confirm in train_utils.
    steps = 1 if config.single_class_replay else 2

    ### Create the tasks (data handlers), unless neither training nor the
    ### upper-bound setting needs them.
    needs_data = not (train_system == False and config.upper_bound == False)
    dhandlers = train_utils._generate_tasks(config, steps) if needs_data \
        else None

    ### Generate networks (only when this call is meant to train).
    enc = dec = d_hnet = None
    if train_system != False:
        # A replay hypernetwork is only needed when its regularizer is on.
        create_rp_hnet = bool(config.rp_beta > 0)
        enc, dec, d_hnet = train_utils_replay.generate_replay_networks(
            config, dhandlers, device, create_rp_hnet,
            only_train_replay=only_train_replay)

    ### Generate task priors for the latent space and fixed test noise.
    priors, test_z, vae_conds = [], [], []
    for emb_idx in range(config.num_embeddings):
        if config.conditional_prior:
            # One randomly drawn prior mean per embedding, repeated across
            # the batch dimension.
            prior_mean = torch.zeros((config.latent_dim)).to(device)
            nn.init.normal_(prior_mean, mean=0, std=1.)
            prior_mean = torch.stack([prior_mean] * config.batch_size)
            prior_mean.requires_grad = False
            priors.append(prior_mean)
        else:
            # Zero mean, so the test sample below is plain N(0, 1) noise;
            # ``None`` is stored as the prior placeholder.
            prior_mean = torch.zeros(
                (config.batch_size, config.latent_dim)).to(device)
            priors.append(None)

        # Fixed latent sample (prior mean + standard normal noise) that is
        # kept around for testing.
        fixed_z = prior_mean + torch.randn_like(prior_mean)
        fixed_z.requires_grad = False
        test_z.append(fixed_z)

        if not config.conditional_replay:
            continue

        # VAE condition: either a one-hot code of the embedding index
        # (assumes conditional_dim > emb_idx, otherwise indexing fails) or a
        # random Gaussian vector.
        cond = torch.zeros((config.conditional_dim)).to(device)
        if config.not_conditional_hot_enc:
            nn.init.normal_(cond, mean=0, std=1.)
        else:
            cond[emb_idx] = 1
        cond = torch.stack([cond] * config.batch_size)
        cond.requires_grad = False
        vae_conds.append(cond)

    config.test_z = test_z
    config.priors = priors
    config.vae_conds = vae_conds

    if not train_tandem:
        ### Train and then test the replay network right away.
        train(dhandlers, enc, dec, d_hnet, device, config, writer)
        test(enc, dec, d_hnet, device, config, writer)

    return dec, d_hnet, enc, dhandlers, device, writer, config
def run(mode='split'):
    """Start MNIST experiments.

    Depending on the configuration, this controls the creation and training
    of the different (replay) modules for classification or task inference,
    built out of standard neural networks and their corresponding
    hypernetworks.

    Args:
        mode (str): Training mode; defines which experiments and default
            values are loaded. Options are splitMNIST or permutedMNIST:

            - ``split``
            - ``perm``

    Raises:
        AssertionError: If the parsed configuration contains an unsupported
            combination of options.
    """
    ### Get command line arguments.
    config = train_args.parse_cmd_arguments(mode=mode)

    assert (config.experiment == "splitMNIST" or
            config.experiment == "permutedMNIST")
    if not config.dont_set_default:
        config = _set_default(config)

    if config.infer_output_head:
        assert (config.infer_task_id == True)

    if config.cl_scenario == 1:
        assert (config.class_incremental == False)
        assert (config.single_class_replay == False)

    if config.infer_with_entropy:
        assert (config.infer_task_id == True)

    # Single-class replay and class-incremental learning are only
    # implemented for splitMNIST.
    if config.single_class_replay or config.class_incremental:
        assert (config.experiment == "splitMNIST")

    # Check the range of the number of tasks.
    assert (config.num_tasks > 0)
    if config.experiment == "splitMNIST":
        if config.class_incremental:
            assert (config.num_tasks <= 10)
        else:
            assert (config.num_tasks <= 5)

    # The following combination is not supported.
    if config.infer_task_id:
        assert (config.class_incremental == False)

    # Enforce the correct CL scenario.
    if config.class_incremental:
        # Boolean flag (was previously set to the int 1).
        config.single_class_replay = True
        config.cl_scenario = 3
        print("Attention: Cl scenario 3 is enforced!")
        steps = 1
    else:
        steps = 2

    #### Get data handlers.
    dhandlers_class = train_utils._generate_tasks(config, steps)

    # Decide whether to train a replay model. If only a classifier is wanted
    # and the task id is known, only a classifier + hnet is trained. The
    # upper bound considers the replay case, but replays real data as if the
    # replay model were "perfect".
    if config.upper_bound or (config.infer_task_id and config.cl_scenario == 1):
        train_rp = False
    else:
        train_rp = True

    ### Get replay model trained continually with hnet.
    dec, d_hnet, enc, dhandlers_rp, device, writer, config = \
        replay_model(config, train_rp)

    # Everything below logs to ``writer``; close it even when training or
    # testing raises so that buffered summaries are not lost.
    try:
        # With a trained replay model, we now train a classifier that either
        # solves a task directly (HNET + replay) or a model that infers the
        # task from the input.

        ###############################
        # Train task inference network
        ###############################
        if config.infer_task_id and config.cl_scenario != 1 and \
                not config.infer_with_entropy:
            print("Training task inference model ...")
            config.trained_replay_model = False
            config.training_task_infer = True
            config.training_with_hnet = False
            ### Generate task inference network.
            infer_net = train_utils.generate_classifier(config,
                                                        dhandlers_class, device)

            ### Train the task inference network.
            config.during_accs_inference = train_tasks(
                dhandlers_class, dhandlers_rp, enc, dec, d_hnet, None, device,
                config, writer, infer_net=infer_net)

            ### Test network.
            print("Testing task inference model ...")
            test(dhandlers_class, None, infer_net, device, config, writer)
            config.training_with_hnet = True
            config.trained_replay_model = True
        else:
            # No separate task-inference network is trained, so ``infer_net``
            # stays ``None``; the flags below configure the final classifier.
            infer_net = None
            if config.infer_with_entropy:
                config.trained_replay_model = True
            else:
                config.trained_replay_model = False

            if config.infer_task_id:
                config.training_with_hnet = True
            else:
                config.training_with_hnet = False

        ###################
        # Train classifier
        ###################
        config.training_task_infer = False

        print("Training final classifier ...")
        ### Generate another classifier network.
        class_nets = train_utils.generate_classifier(config,
                                                     dhandlers_class, device)
        ### Train the network.
        config.during_accs_final = train_tasks(
            dhandlers_class, dhandlers_rp, enc, dec, d_hnet, class_nets,
            device, config, writer, infer_net)

        print("Testing final classifier ...")
        ### Test network.
        test(dhandlers_class, class_nets, infer_net, device, config, writer)

        _save_performance_summary(config)
    finally:
        writer.close()

    print('Program finished successfully.')