def test(enc, dec, d_hnet, device, config, writer, train_iter=None,
         condition=None):
    """Test the MNIST VAE - here we only sample from a fixed noise to compare
    images qualitatively. One should also keep track of the reconstruction
    error of e.g. a test set.

    Args:
        (....): See docstring of function :func:`train`.
        train_iter: The current training iteration.
        condition: Condition (class/task) we are currently training.
    """
    if train_iter is None:
        print('### Final test run ...')
        train_iter = config.n_iter
    else:
        print('# Testing network before running training step %d ...' % \
              train_iter)

    # If no condition is given, we iterate over all (trained) embeddings.
    if condition is None:
        condition = config.num_embeddings - 1

    # Evaluate all nets.
    enc.eval()
    dec.eval()
    if d_hnet is not None:
        d_hnet.eval()

    with torch.no_grad():
        # Iterate over all conditions.
        for m in range(condition + 1):
            # Get the noise saved before training.
            z = config.test_z[m]
            reconstructions = sample(dec, d_hnet, config, m, device, z=z)
            if config.show_plots:
                fig_real = _plotImages(reconstructions, config)
                writer.add_figure('test_cond_' + str(m) +
                                  '_sampled_after_' + str(condition),
                                  fig_real, global_step=train_iter)
                if train_iter == config.n_iter:
                    writer.add_figure('test_cond_final_' + str(m) +
                                      '_sampled_after_' + str(condition),
                                      fig_real, global_step=train_iter)
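# A minimal sketch of the test-set reconstruction tracking suggested in the
# docstring of :func:`test` above. This is an editorial illustration, not part
# of the original pipeline: the helper name ``_log_test_recon_error`` is
# hypothetical, and it assumes that ``enc`` returns the posterior mean and
# log-variance, that ``dec`` outputs pixel logits, and that test images lie in
# [0, 1]. ``torch`` and ``F`` (``torch.nn.functional``) are the module-level
# imports already used elsewhere in this file.
def _log_test_recon_error(enc, dec, X_test, writer, train_iter, condition):
    """Log the mean test-set reconstruction error (illustrative sketch)."""
    enc.eval()
    dec.eval()
    with torch.no_grad():
        mu, logvar = enc.forward(X_test)
        # Decode from the posterior mean for a deterministic reconstruction.
        rec_logits = dec.forward(mu)
        rec_err = F.binary_cross_entropy_with_logits(rec_logits, X_test,
                                                     reduction='mean')
    writer.add_scalar('test/cond_%d/recon_error' % condition,
                      rec_err.item(), train_iter)
    return rec_err.item()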
def train_class_one_t(dhandler_class, dhandlers_rp, dec, d_hnet, net, device,
                      config, writer, t):
    """Train continual learning experiments on the MNIST dataset for one task.

    This function implements the main training logic. After setting up the
    optimizers for the network and (if applicable) the hypernetwork, training
    is structured as follows: First, we get a training batch of the current
    task. Depending on the learning scenario, we choose output heads and
    build targets accordingly. Second, from the second task onwards
    (``t >= 1``), we add a loss term concerning predictions of replayed data.
    See :func:`get_fake_data_loss` for details. Third, to protect the
    hypernetwork from forgetting, we add an additional L2 loss term, namely
    the difference between its current output given an embedding and the
    checkpointed targets. Finally, we track some training statistics.

    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
    """
    # If we do CL with task inference, the classifier is empowered with a
    # hypernetwork.
    if config.training_with_hnet:
        net_hnet = net[1]
        net = net[0]
        net.train()
        net_hnet.train()
        params_to_regularize = list(net_hnet.theta)
        optimizer = optim.Adam(params_to_regularize,
                               lr=config.class_lr, betas=(0.9, 0.999))

        c_emb_optimizer = optim.Adam([net_hnet.get_task_emb(t)],
                                     lr=config.class_lr_emb,
                                     betas=(0.9, 0.999))
    else:
        net.train()
        net_hnet = None
        optimizer = optim.Adam(net.parameters(),
                               lr=config.class_lr, betas=(0.9, 0.999))

    # Do not train the replay model, if one is available.
    if dec is not None:
        dec.eval()
    if d_hnet is not None:
        d_hnet.eval()

    # Compute targets if the classifier is trained with a hypernetwork.
    if t > 0 and config.training_with_hnet:
        if config.online_target_computation:
            # Compute targets for the regularizer whenever they are needed.
            # -> Computationally expensive.
            targets_C = None
            prev_theta = [p.detach().clone() for p in net_hnet.theta]
            prev_task_embs = [p.detach().clone() for p in
                              net_hnet.get_task_embs()]
        else:
            # Compute targets for the regularizer once and keep them all in
            # memory -> Memory expensive.
            targets_C = hreg.get_current_targets(t, net_hnet)
            prev_theta = None
            prev_task_embs = None

    dhandler_class.reset_batch_generator()

    # Make a copy of the network that serves as a checkpoint for the replay
    # loss.
    if t >= 1:
        net_copy = copy.deepcopy(net)

    # Set the number of training iterations, if epochs are given.
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert config.epochs > 0
        training_iterations = config.epochs * \
            int(np.ceil(dhandler_class.num_train_samples / config.batch_size))

    if config.class_incremental:
        training_iterations = int(training_iterations / config.out_dim)

    # Whether we will calculate the regularizer.
    calc_reg = t > 0 and config.class_beta > 0 and config.training_with_hnet

    # Set this if we want the regularizer to be computed only on a subset of
    # the previous tasks.
    if config.hnet_reg_batch_size != -1:
        hnet_reg_batch_size = config.hnet_reg_batch_size
    else:
        hnet_reg_batch_size = None

    for i in range(training_iterations):

        # Zero out the gradients.
        optimizer.zero_grad()
        if net_hnet is not None:
            c_emb_optimizer.zero_grad()

        # Get real data.
        real_batch = dhandler_class.next_train_batch(config.batch_size)
        X_real = dhandler_class.input_to_torch_tensor(real_batch[0], device,
                                                      mode='train')
        T_real = dhandler_class.output_to_torch_tensor(real_batch[1], device,
                                                       mode='train')

        if i % 100 == 0 and config.show_plots:
            fig_real = _plotImages(X_real, config)
            writer.add_figure('train_class_' + str(t) + '_real', fig_real,
                              global_step=i)

        #################################################
        # Choosing output heads and constructing targets
        #################################################

        # If we train a task inference network or do class-incremental
        # learning, we construct a target for every single class/task.
        if config.class_incremental or config.training_task_infer:
            # At the beginning of training, we look at two output neurons.
            task_out = [0, t + 1]
            T_real = torch.zeros((config.batch_size, task_out[1])).to(device)
            T_real[:, task_out[1] - 1] = 1

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # Take the task-specific output neurons.
                task_out = [t * config.out_dim,
                            t * config.out_dim + config.out_dim]
            else:
                # Always all output neurons, only one head is used.
                task_out = [0, config.out_dim]
        else:
            # The number of output neurons is generic and can grow, i.e. we
            # do not have to know the number of tasks before we start
            # learning.
            if not config.infer_output_head:
                task_out = [0, (t + 1) * config.out_dim]
                T_real = torch.cat((torch.zeros((config.batch_size,
                    t * config.out_dim)).to(device), T_real), dim=1)
            # This is a special case where the task id will be inferred by
            # another neural network, so we can train on the correct output
            # head directly and use the inferred output head to compute the
            # prediction.
            else:
                task_out = [t * config.out_dim,
                            t * config.out_dim + config.out_dim]

        # Compute the loss on the current (real) data.
        if config.training_with_hnet:
            weights_c = net_hnet.forward(t)
        else:
            weights_c = None

        Y_hat_logits = net.forward(X_real, weights_c)
        Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]

        if config.soft_targets:
            soft_label = 0.95
            num_classes = T_real.shape[1]
            soft_targets = torch.where(T_real == 1,
                torch.Tensor([soft_label]).to(device),
                torch.Tensor([(1 - soft_label) / (num_classes - 1)]).to(
                    device))
            soft_targets = soft_targets.to(device)
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits,
                                                             soft_targets)
        else:
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits,
                                                             T_real)

        ############################
        # Compute loss for fake data
        ############################

        # Get fake data of all tasks up until now and merge it into the
        # current loss.
        if t >= 1 and not config.training_with_hnet:
            fake_loss = get_fake_data_loss(dhandlers_rp, net, dec, d_hnet,
                                           device, config, writer, t, i,
                                           net_copy)
            loss_task = (1 - config.l_rew) * loss_task + \
                config.l_rew * fake_loss

        loss_task.backward(retain_graph=calc_reg,
                           create_graph=calc_reg and config.backprop_dt)

        # Compute the hypernetwork regularizer. The embeddings of previous
        # tasks stay fixed; only the current embedding is trained.
        if calc_reg:
            if config.no_lookahead:
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(optimizer,
                    config.use_sgd_change, lr=config.class_lr,
                    detach_dt=not config.backprop_dt)

            loss_reg = config.class_beta * hreg.calc_fix_target_reg(net_hnet,
                t, targets=targets_C, mnet=net, dTheta=dTheta, dTembs=None,
                prev_theta=prev_theta, prev_task_embs=prev_task_embs,
                batch_size=hnet_reg_batch_size)

            loss_reg.backward()

        # Apply the optimization step (the backward pass of ``loss_task`` was
        # already computed above).
        if not config.dont_train_main_model:
            optimizer.step()

        if net_hnet is not None and config.train_class_embeddings:
            c_emb_optimizer.step()

        # Save some training statistics.
        if i % 50 == 0:
            # Compute accuracies for tracking.
            Y_hat_logits = net.forward(X_real, weights_c)
            Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
            Y_hat = F.softmax(Y_hat_logits, dim=1)
            classifier_accuracy = Classifier.accuracy(Y_hat, T_real) * 100.0
            writer.add_scalar('train/task_%d/class_accuracy' % t,
                              classifier_accuracy, i)
            writer.add_scalar('train/task_%d/loss_task' % t, loss_task, i)
            if t >= 1 and not config.training_with_hnet:
                writer.add_scalar('train/task_%d/fake_loss' % t, fake_loss, i)

        # Plot some gradient statistics.
        if i % 200 == 0:
            if not config.dont_train_main_model:
                total_norm = 0
                if config.training_with_hnet:
                    params = net_hnet.theta
                else:
                    params = net.parameters()

                for p in params:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                # TODO: write gradient histograms?
                writer.add_scalar('train/task_%d/main_params_grad_norms' % t,
                                  total_norm, i)

            if net_hnet is not None and config.train_class_embeddings:
                total_norm = 0
                for p in [net_hnet.get_task_emb(t)]:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                writer.add_scalar('train/task_%d/hnet_emb_grad_norms' % t,
                                  total_norm, i)

        if i % 200 == 0:
            msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                  '(on current training batch).'
            print(msg.format(i, classifier_accuracy))
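# For illustration only: the hypernetwork regularizer applied above through
# ``hreg.calc_fix_target_reg`` is, at its core, an L2 penalty between the
# hypernetwork's current output for an old task embedding and the output that
# was checkpointed before training on the current task started. The function
# below is a simplified, hypothetical sketch of that idea (no lookahead
# ``dTheta``, no subsampling of previous tasks) and is not the library
# implementation. It assumes ``hnet.forward(task_id)`` returns a list of
# weight tensors, as it is used elsewhere in this file, and that ``targets``
# holds the checkpointed outputs per previous task.
def _naive_fix_target_reg(hnet, task_id, targets):
    """L2 distance between current and checkpointed hypernet outputs for all
    tasks ``< task_id`` (illustrative sketch)."""
    reg = 0.
    for prev_t in range(task_id):
        W_cur = hnet.forward(prev_t)
        W_tgt = targets[prev_t]
        # Sum of squared differences over all generated weight tensors.
        reg = reg + sum((w_c - w_t.detach()).pow(2).sum()
                        for w_c, w_t in zip(W_cur, W_tgt))
    return reg / max(task_id, 1)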
def get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device, config, writer,
                       t, i, net_copy):
    """Sample fake data from the generator for all tasks up to ``t`` and
    compute a loss against the predictions of a checkpointed network.

    We must take caution when considering the different learning scenarios,
    methods and training stages; see the detailed comments in the code.
    In general, we build a batch of replayed data from all previous tasks.
    Since we do not know the labels of the replayed data, we consider the
    output of the checkpointed network as ground truth, i.e. we must compute
    a loss between two sets of logits. See
    :class:`mnets.classifier_interface.Classifier` for a detailed description
    of the different loss functions.

    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
        i: Current training iteration.
        net_copy: Copy/checkpoint of the classifier network before learning
            task ``t``.

    Returns:
        The loss between the current network's predictions and the
        checkpointed network's predictions on replayed data.
    """
    all_Y_hat_ls = []
    all_targets = []

    # We have to determine how many embeddings there are to sample from.
    if config.class_incremental or config.single_class_replay:
        # Every class was trained with its own generator (embedding).
        emb_num = t * config.out_dim
    else:
        # Here, samples of a whole task come from a single generator
        # (embedding).
        emb_num = t

    # We have to choose which embeddings to sample from.
    if config.fake_data_full_range:
        ran = range(0, emb_num)
        bs_per_task = int(np.ceil(config.batch_size / emb_num))
    else:
        random_t = np.random.randint(0, emb_num)
        ran = range(random_t, random_t + 1)
        bs_per_task = config.batch_size

    for re in ran:
        # Exchange replay data with real data to compute upper bounds.
        if config.upper_bound:
            real_batch = dhandlers_rp[re].next_train_batch(bs_per_task)
            X_fake = dhandlers_rp[re].input_to_torch_tensor(real_batch[0],
                device, mode='train')
        else:
            # Get fake data.
            if config.replay_method == 'gan':
                X_fake = sample_gan(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)
            else:
                X_fake = sample_vae(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)

        # Save some fake data to the writer.
        if i % 100 == 0:
            if X_fake.shape[0] >= 15:
                fig_fake = _plotImages(X_fake, config, bs_per_task)
                writer.add_figure('train_class_' + str(re) + '_fake',
                                  fig_fake, global_step=i)

        # Compute soft targets with the copied (checkpointed) network.
        target_logits = net_copy.forward(X_fake).detach()
        Y_hat_ls = net.forward(X_fake.detach())

        ###############
        # BUILD TARGETS
        ###############
        od = config.out_dim

        if config.class_incremental or config.training_task_infer:
            # This is a bit complicated: If we train class/task incrementally,
            # we skip training the classifier on the first task.
            # So when starting to train the classifier on task 2, we have to
            # build a hard target for this first output neuron trained by
            # replay data. A soft target (on an untrained output) would not
            # make sense.
            # Output head over all output neurons already available.
            task_out = [0, (t + 1) * od]

            # Create a target that is zero everywhere except for the current
            # ``re``.
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(
                device)
            if config.hard_targets or (t == 1 and re == 0):
                zeros[:, re] = 1
            else:
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]
            targets = zeros
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # Take the task-specific output neurons.
                task_out = [re * od, re * od + od]
            else:
                # Always all output neurons, only one head is used.
                task_out = [0, od]

            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]
            target_logits = target_logits[:, task_out[0]:task_out[1]]

            # Build hard targets, i.e. one-hots, if this option is chosen.
            if config.hard_targets:
                soft_targets = torch.sigmoid(target_logits)
                zeros = torch.zeros(Y_hat_ls.shape).to(device)
                _, argmax = torch.max(soft_targets, 1)
                targets = zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # The loss expects logits.
                targets = target_logits
        else:
            # Take all neurons used up until now.

            # Output head over all output neurons already available.
            task_out = [0, (t + 1) * od]

            # Create a target that is zero everywhere except for the current
            # ``re``.
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(
                device)

            # Sigmoid over the output head(s) of all previous tasks.
            soft_targets = torch.sigmoid(target_logits[:, 0:t * od])

            # Compute one-hots.
            if config.hard_targets:
                _, argmax = torch.max(soft_targets, 1)
                zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # The loss expects logits.
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]
            targets = zeros

            # Choose the correct output size for the actual predictions.
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        # Add to the lists.
        all_targets.append(targets)
        all_Y_hat_ls.append(Y_hat_ls)

    # Concatenate into single tensors.
    all_targets = torch.cat(all_targets)
    Y_hat_ls = torch.cat(all_Y_hat_ls)

    if i % 200 == 0:
        classifier_accuracy = Classifier.accuracy(Y_hat_ls,
                                                  all_targets) * 100.0
        msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
              '(on current FAKE DATA training batch).'
        print(msg.format(i, classifier_accuracy))

    # Depending on the target softness, the loss function is chosen.
    if config.hard_targets or (config.class_incremental and t == 1):
        return Classifier.logit_cross_entropy_loss(Y_hat_ls, all_targets)
    else:
        return Classifier.knowledge_distillation_loss(Y_hat_ls, all_targets)
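# For illustration only: the two return branches of ``get_fake_data_loss``
# differ in how the checkpointed network's logits are turned into targets.
# The sketch below shows that distinction with plain PyTorch. It is a
# hypothetical stand-in, not the ``Classifier.logit_cross_entropy_loss`` or
# ``Classifier.knowledge_distillation_loss`` implementations, and the
# temperature ``T`` is an added assumption. ``F`` is the module-level
# ``torch.nn.functional`` import already used in this file.
def _replay_loss_sketch(student_logits, teacher_logits, hard_targets=False,
                        T=2.):
    """Cross-entropy on argmax labels (hard targets) or a distillation-style
    soft cross-entropy against the teacher's softened distribution."""
    if hard_targets:
        labels = teacher_logits.argmax(dim=1)
        return F.cross_entropy(student_logits, labels)
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    log_student = F.log_softmax(student_logits / T, dim=1)
    return -(soft_teacher * log_student).sum(dim=1).mean() * (T ** 2)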