def train(task_id, data, mnet, device, config, shared, logger, writer):
    r"""Train a network continually using EWC.

    In general, we consider networks with task-shared weights :math:`\theta`
    and task-specific weights (usually the output head weights)
    :math:`\psi_t`. The EWC loss function then arises from the following
    identity.

    .. math::

        \log p(\theta, \psi_A, \cdots, \psi_T \mid \mathcal{D}_A, \cdots,
        \mathcal{D}_T) &= \log p(\mathcal{D}_T \mid \theta, \psi_T)
        + \log p(\psi_T) + \sum_{t < T} \bigg[ \log p(\mathcal{D}_t \mid
        \theta, \psi_t) + \log p(\psi_t) \bigg] + \log p(\theta) + const \\
        &= \log p(\mathcal{D}_T \mid \theta, \psi_T) + \log p(\psi_T)
        + \log p(\theta, \psi_A, \cdots, \psi_S \mid \mathcal{D}_A, \cdots,
        \mathcal{D}_S) + const

    If there is a single head (or a combined head/softmax) such that there
    are no task-specific weights, the :math:`\psi_t`'s can simply be dropped
    from the equation. The (online) EWC loss function can then be derived as

    .. math::

        \log p(\theta, \psi_A, \cdots, \psi_T \mid \mathcal{D}_A, \cdots,
        \mathcal{D}_T) &\approx const + \log p(\mathcal{D}_T \mid \theta,
        \psi_T) + \log p(\psi_T) \\
        & \hspace{1cm} - \frac{1}{2} \sum_{i=1}^{\vert \phi \vert} \bigg(
        \frac{1}{\sigma_{prior}^2} + \sum_{t \in \{A, \cdots, S\}} N_t
        \mathcal{F}_{emp \: t, i} \bigg) (\phi_i - \phi_{S, i}^*)^2

    where :math:`\phi` refers to all task-shared weights as well as all
    task-specific weights of previously seen tasks. Hence, each weight has
    its own regularization factor, computed as the sum of a constant offset
    (assuming an isotropic prior) and a weighted accumulation of Fisher
    values from all previous tasks. Note, Fisher values of task-specific
    weights are only non-zero when computed on the corresponding task. As
    only the task-shared weights and the current output head are being
    learned, the regularizer is trivially zero for all other task-specific
    weights. A minimal, self-contained sketch of this per-weight penalty is
    provided in the illustrative helper ``_ewc_penalty_sketch`` below this
    function.

    When learning the first task, we need to find a MAP solution by
    maximizing

    .. math::

        \log p(\theta, \psi_A \mid \mathcal{D}_A) = const +
        \log p(\mathcal{D}_A \mid \theta, \psi_A) + \log p(\theta) +
        \log p(\psi_A)

    We assume isotropic Gaussian priors and can therefore transform the
    prior terms into simple L2 regularization (or weight decay) expressions:

    .. math::

        \log p(\theta) = -\frac{1}{2 \sigma_{prior}^2}
        \lVert \theta \rVert_2^2 + const

    Args:
        task_id: The index of the task on which we train.
        data: The dataset handler.
        mnet: The model of the main network.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        shared: Miscellaneous data shared among training functions.
        logger: Command-line logger.
        writer: The tensorboard summary writer.
    """
    logger.info('Training network on task %d ...' % (task_id+1))

    mnet.train()

    # Are we training a classification or a regression task?
    is_regression = 'regression' in shared.experiment_type

    # If we have a multihead setting, then we need to distinguish between
    # task-specific and task-shared weights.
    is_multihead = None
    if is_regression:
        assert config.ll_dist_std > 0
        eval_func = reg_bbb.evaluate
        ll_scale = 1. / config.ll_dist_std**2
        is_multihead = config.multi_head
    else:
        assert shared.softmax_temp[task_id] == 1.
        eval_func = class_bbb.evaluate
        is_multihead = config.cl_scenario == 1 or \
            (config.cl_scenario == 3 and config.split_head_cl3)

    # Which outputs of the main network should we consider for the current
    # task?
    allowed_outputs = pmutils.out_units_of_task(config, data, task_id,
                                                task_id+1)

    #############################################################
    ### Figure out which are task-specific and shared weights ###
    #############################################################
    if is_multihead:
        # Note, the output weights of all output heads always share the same
        # parameter tensors, which is the case for all mnets at the time of
        # implementation.
        out_masks = mnet.get_output_weight_mask(out_inds=allowed_outputs,
                                                device=device)

        shared_params = []
        specific_params = []
        # Within an output weight tensor, we only want to apply the L2 reg to
        # the weights of the current output head.
        specific_mask = []

        for ii, mask in enumerate(out_masks):
            pind = mnet.param_shapes_meta[ii]['index']
            assert pind != -1
            if mask is None: # Shared parameter.
                shared_params.append(mnet.internal_params[pind])
            else: # Output weight tensor.
                specific_params.append(mnet.internal_params[pind])
                specific_mask.append(mask)
    else:
        # All weights are task-shared.
        shared_params = mnet.internal_params
        specific_params = None

    ###########################
    ### Create optimizer(s) ###
    ###########################
    # For the non-multihead case, we could invoke the L2 reg via the
    # weight-decay parameter here. But for the multihead case, we need to
    # apply an extra mask to the parameter tensor.
    optimizer = tutils.get_optimizer(mnet.internal_params, config.lr,
        momentum=config.momentum, weight_decay=config.weight_decay,
        use_adam=config.use_adam, adam_beta1=config.adam_beta1,
        use_rmsprop=config.use_rmsprop, use_adadelta=config.use_adadelta,
        use_adagrad=config.use_adagrad)

    ################################
    ### Learning rate schedulers ###
    ################################
    plateau_scheduler = None
    lambda_scheduler = None
    if config.plateau_lr_scheduler:
        assert config.epochs != -1
        plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
            'min' if is_regression else 'max', factor=np.sqrt(0.1),
            patience=5, min_lr=0.5e-6, cooldown=0)
    if config.lambda_lr_scheduler:
        assert config.epochs != -1
        lambda_scheduler = optim.lr_scheduler.LambdaLR(optimizer,
            tutils.lambda_lr_schedule)

    ######################
    ### Start training ###
    ######################
    mnet_kwargs = pmutils.mnet_kwargs(config, task_id, mnet)

    num_train_iter, iter_per_epoch = sutils.calc_train_iter(
        data.num_train_samples, config.batch_size, num_iter=config.n_iter,
        epochs=config.epochs)

    for i in range(num_train_iter):
        #########################
        ### Evaluate networks ###
        #########################
        # We test the network before we run the training iteration, so that
        # we can see the initial performance of the untrained network.
        if i % config.val_iter == 0:
            eval_func(task_id, data, mnet, None, device, config, shared,
                      logger, writer, i)
            mnet.train()

        if i % 100 == 0:
            logger.debug('Training iteration: %d.' % i)

        ##########################
        ### Train Current Task ###
        ##########################
        optimizer.zero_grad()

        ### Compute negative log-likelihood (NLL).
        batch = data.next_train_batch(config.batch_size)
        X = data.input_to_torch_tensor(batch[0], device, mode='train')
        T = data.output_to_torch_tensor(batch[1], device, mode='train')

        if not is_regression:
            # Modify 1-hot encodings according to CL scenario.
            assert T.shape[1] == data.num_classes
            # Modify the targets, if the softmax spans multiple heads.
            T = pmutils.fit_targets_to_softmax(config, shared, device, data,
                                               task_id, T)

            _, labels = torch.max(T, 1) # Integer labels.
            labels = labels.detach()

        Y = mnet.forward(X, **mnet_kwargs)
        if allowed_outputs is not None:
            Y = Y[:, allowed_outputs]

        # Task-specific loss.
        # We use the reduction method 'mean' on purpose and scale with the
        # number of training samples below.
        if is_regression:
            loss_nll = 0.5 * ll_scale * F.mse_loss(Y, T, reduction='mean')
        else:
            # Note, `cross_entropy` also computes the softmax for us.
            loss_nll = F.cross_entropy(Y, labels, reduction='mean')

            # Compute accuracy on the batch.
            # Note, a softmax would not change the argmax.
            _, pred_labels = torch.max(Y, 1)
            mean_train_acc = 100. * torch.sum(pred_labels == labels) / \
                config.batch_size

        loss_nll *= data.num_train_samples

        ### Compute L2 reg.
        loss_l2 = 0
        if task_id == 0 or config.train_from_scratch:
            for pp in shared_params:
                loss_l2 += pp.pow(2).sum()
        if specific_params is not None:
            for ii, pp in enumerate(specific_params):
                loss_l2 += (pp * specific_mask[ii]).pow(2).sum()
        loss_l2 *= 1. / (2. * config.prior_variance)

        ### Compute EWC reg.
        loss_ewc = 0
        if task_id > 0 and config.ewc_lambda > 0:
            assert not config.train_from_scratch
            loss_ewc += ewc.ewc_regularizer(task_id, mnet.internal_params,
                mnet, online=True, gamma=config.ewc_gamma)

        loss = loss_nll + loss_l2 + config.ewc_lambda * loss_ewc

        loss.backward()
        if config.clip_grad_value != -1:
            torch.nn.utils.clip_grad_value_(
                optimizer.param_groups[0]['params'], config.clip_grad_value)
        elif config.clip_grad_norm != -1:
            torch.nn.utils.clip_grad_norm_(
                optimizer.param_groups[0]['params'], config.clip_grad_norm)
        optimizer.step()

        ###############################
        ### Learning rate scheduler ###
        ###############################
        # We can invoke the same function to compute the test accuracy as we
        # do for BbB.
        pmutils.apply_lr_schedulers(config, shared, logger, task_id, data,
            mnet, None, device, i, iter_per_epoch, plateau_scheduler,
            lambda_scheduler, hhnet=None, method='bbb')

        ###########################
        ### Tensorboard summary ###
        ###########################
        if i % 50 == 0:
            writer.add_scalar('train/task_%d/loss_nll' % task_id, loss_nll, i)
            writer.add_scalar('train/task_%d/loss_l2' % task_id, loss_l2, i)
            writer.add_scalar('train/task_%d/loss_ewc' % task_id, loss_ewc, i)
            writer.add_scalar('train/task_%d/loss' % task_id, loss, i)
            if not is_regression:
                writer.add_scalar('train/task_%d/accuracy' % task_id,
                                  mean_train_acc, i)

    pmutils.checkpoint_bn_stats(config, task_id, mnet)

    #############################
    ### Compute Fisher matrix ###
    #############################
    # Note, we compute the Fisher after all tasks (even the last task) if we
    # have a multihead setup, since we use those Fisher values to build
    # approximate posterior distributions.
    if is_multihead or task_id < config.num_tasks - 1:
        logger.debug('Computing diagonal Fisher elements ...')

        fisher_params = mnet.internal_params

        # When training from scratch, a new network is generated every round,
        # such that the old Fisher matrices expected by EWC do not exist yet.
        # On the other hand, if a hypernetwork is used, we learn task-specific
        # models and have to explicitly avoid that Fisher matrices are
        # accumulated.
        if task_id > 0 and config.train_from_scratch:
            for i, p in enumerate(fisher_params):
                buff_w_name, buff_f_name = ewc._ewc_buffer_names(task_id, i,
                                                                 True)
                mnet.register_buffer(buff_w_name, torch.zeros_like(p))
                mnet.register_buffer(buff_f_name, torch.zeros_like(p))

        # Compute prior-offset of Fisher values.
        if is_multihead:
            out_masks = mnet.get_output_weight_mask(out_inds=allowed_outputs,
                                                    device=device)

            prior_offset = [torch.zeros_like(p) for p in mnet.internal_params]

            for ii, mask in enumerate(out_masks):
                pind = mnet.param_shapes_meta[ii]['index']
                if mask is None: # Shared parameter.
                    if task_id == 0 or config.train_from_scratch:
                        prior_offset[pind][:] = 1. / config.prior_variance
                else: # Current output head.
                    # Note, why don't I apply the offset to all heads from
                    # the beginning?
                    # -> If I did, the Fisher values of the output heads of
                    # the current and future tasks would be non-zero and the
                    # corresponding weights would therefore be regularized by
                    # the EWC regularizer. For future tasks this doesn't
                    # matter, as the weights don't change during training and
                    # the reg is still 0. But for the current task it does
                    # matter, and the reg would pull the weights towards their
                    # random initialization.
                    prior_offset[pind][mask] = 1. / config.prior_variance
        else:
            prior_offset = 0
            if task_id == 0 or config.train_from_scratch:
                prior_offset = 1. / config.prior_variance

        target_manipulator = None
        if not is_regression:
            target_manipulator = lambda T: pmutils.fit_targets_to_softmax(
                config, shared, device, data, task_id, T)

        ewc.compute_fisher(task_id, data, fisher_params, device, mnet,
            empirical_fisher=True, online=True, gamma=config.ewc_gamma,
            n_max=config.n_fisher, regression=is_regression,
            allowed_outputs=allowed_outputs, custom_forward=None,
            time_series=False, custom_nll=None, pass_ids=False,
            proper_scaling=True, prior_strength=prior_offset,
            regression_lvar=config.ll_dist_std**2 if is_regression else 1.,
            target_manipulator=target_manipulator)

        ### Log histogram of diagonal Fisher elements.
        diag_fisher = []
        out_masks = mnet.get_output_weight_mask(out_inds=allowed_outputs,
                                                device=device)
        for ii, mask in enumerate(out_masks):
            pind = mnet.param_shapes_meta[ii]['index']
            _, buff_f_name = ewc._ewc_buffer_names(None, pind, True)
            curr_F = getattr(mnet, buff_f_name)
            if mask is not None:
                curr_F = curr_F[mask]
            diag_fisher.append(curr_F)
        diag_fisher = torch.cat([p.detach().flatten().cpu()
                                 for p in diag_fisher])

        writer.add_scalar('ewc/min_fisher', torch.min(diag_fisher), task_id)
        writer.add_scalar('ewc/max_fisher', torch.max(diag_fisher), task_id)
        writer.add_histogram('ewc/fisher', diag_fisher, task_id)
        try:
            writer.add_histogram('ewc/log_fisher', torch.log(diag_fisher),
                                 task_id)
        except:
            # Should not happen, since diagonal elements should be positive.
            logger.warning('Could not write histogram of diagonal Fisher ' +
                           'elements.')

    logger.info('Training network on task %d ... Done' % (task_id+1))
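

# -----------------------------------------------------------------------------
# Illustrative helper (NOT part of the original training pipeline). The name
# `_ewc_penalty_sketch` and its interface are hypothetical; the sketch merely
# spells out the per-weight penalty described in the docstring of `train`
# above: a constant prior offset 1 / sigma_prior^2 plus the accumulated,
# sample-weighted Fisher values, multiplied by the squared deviation from the
# previous task's solution. The regularizer actually used during training is
# `ewc.ewc_regularizer`, which reads these quantities from registered buffers.
# -----------------------------------------------------------------------------
def _ewc_penalty_sketch(params, prev_params, fisher_accum, prior_variance):
    r"""Minimal sketch of the online-EWC penalty (illustration only).

    Computes

    .. math::

        \frac{1}{2} \sum_i \bigg( \frac{1}{\sigma_{prior}^2}
        + \mathcal{F}_i \bigg) (\phi_i - \phi_i^*)^2

    where ``fisher_accum`` is assumed to already contain the
    :math:`\gamma`-discounted, :math:`N_t`-weighted sum of empirical Fisher
    values of all previous tasks.

    Args:
        params: List of current parameter tensors :math:`\phi`.
        prev_params: List of checkpointed parameters :math:`\phi^*` from the
            previous task.
        fisher_accum: List of accumulated diagonal Fisher values, with the
            same shapes as ``params``.
        prior_variance: Variance :math:`\sigma_{prior}^2` of the isotropic
            Gaussian prior.

    Returns:
        A scalar ``torch.Tensor``.
    """
    penalty = 0.
    for p, p_star, f in zip(params, prev_params, fisher_accum):
        penalty = penalty + \
            (((1. / prior_variance) + f) * (p - p_star.detach())**2).sum()
    # A training loss would then look like:
    #     loss = loss_nll + ewc_lambda * _ewc_penalty_sketch(...)
    return 0.5 * penalty
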
def train_reg(task_id, data, mnet, hnet, device, config, writer):
    r"""Train the network using the task-specific loss plus a regularizer
    that should weaken catastrophic forgetting.

    .. math::

        \text{loss} = \text{task\_loss} + \beta \cdot \text{regularizer}

    A minimal sketch of the ``config.reg == 0`` regularizer is given in the
    illustrative helper ``_fix_target_reg_sketch`` at the end of this module.

    Args:
        task_id: The index of the task on which we train.
        data: The dataset handler.
        mnet: The model of the main network.
        hnet: The model of the hypernetwork.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        writer: The tensorboard summary writer.
    """
    print('Training network ...')

    mnet.train()
    hnet.train()

    regged_outputs = None
    if config.multi_head:
        n_y = data.out_shape[0]
        out_head_inds = [list(range(i * n_y, (i + 1) * n_y)) for i in
                         range(task_id + 1)]
        # Outputs to be regularized.
        regged_outputs = out_head_inds[:-1] if config.masked_reg else None
    allowed_outputs = out_head_inds[task_id] if config.multi_head else None

    # Collect Fisher estimates for the reg computation.
    fisher_ests = None
    if config.ewc_weight_importance and task_id > 0:
        fisher_ests = []
        n_W = len(hnet.target_shapes)
        for t in range(task_id):
            ff = []
            for i in range(n_W):
                _, buff_f_name = ewc._ewc_buffer_names(t, i, False)
                ff.append(getattr(mnet, buff_f_name))
            fisher_ests.append(ff)

    # Regularizer targets.
    if config.reg == 0 and config.beta > 0:
        targets = hreg.get_current_targets(task_id, hnet)

    regularized_params = list(hnet.theta)
    if task_id > 0 and config.plastic_prev_tembs:
        assert config.reg == 0
        for i in range(task_id): # For all previous task embeddings.
            regularized_params.append(hnet.get_task_emb(i))
    theta_optimizer = optim.Adam(regularized_params, lr=config.lr_hyper)
    # We only optimize the task embedding corresponding to the current task;
    # the remaining ones stay constant.
    emb_optimizer = optim.Adam([hnet.get_task_emb(task_id)],
                               lr=config.lr_hyper)

    # Will the regularizer be computed during training?
    calc_reg = task_id > 0 and config.beta > 0

    for i in range(config.n_iter):
        ### Evaluate network.
        # We test the network before we run the training iteration, so that
        # we can see the initial performance of the untrained network.
        if i % config.val_iter == 0:
            evaluate(task_id, data, mnet, hnet, device, config, writer, i)
            mnet.train()
            hnet.train()

        if i % 100 == 0:
            print('Training iteration: %d.' % i)

        ### Train theta and task embedding.
        theta_optimizer.zero_grad()
        emb_optimizer.zero_grad()

        batch = data.next_train_batch(config.batch_size)
        X = data.input_to_torch_tensor(batch[0], device, mode='train')
        T = data.output_to_torch_tensor(batch[1], device, mode='train')

        weights = hnet.forward(task_id)
        Y = mnet.forward(X, weights)
        if config.multi_head:
            Y = Y[:, allowed_outputs]

        # Task-specific loss.
        loss_task = F.mse_loss(Y, T)
        # We already compute the gradients here, to then be able to compute
        # delta theta below.
        loss_task.backward(retain_graph=calc_reg,
                           create_graph=config.backprop_dt and calc_reg)

        # The task embedding is only trained on the task-specific loss.
        # Note, the gradients accumulated so far are from "loss_task".
        emb_optimizer.step()

        # DELETEME check correctness of opstep.calc_delta_theta.
        #dPrev = torch.cat([d.data.clone().view(-1) for d in hnet.theta])
        #dT_estimate = torch.cat([d.view(-1).clone() for d in
        #    opstep.calc_delta_theta(theta_optimizer, config.use_sgd_change,
        #        lr=config.lr_hyper, detach_dt=not config.backprop_dt)])

        loss_reg = 0
        dTheta = None
        grad_tloss = None
        if calc_reg:
            if i % 100 == 0: # Just for debugging: display grad magnitudes.
                grad_tloss = torch.cat([d.grad.clone().view(-1) for d in
                                        hnet.theta])

            dTheta = opstep.calc_delta_theta(theta_optimizer,
                config.use_sgd_change, lr=config.lr_hyper,
                detach_dt=not config.backprop_dt)

            if config.plastic_prev_tembs:
                dTembs = dTheta[-task_id:]
                dTheta = dTheta[:-task_id]
            else:
                dTembs = None

            if config.reg == 0:
                loss_reg = hreg.calc_fix_target_reg(hnet, task_id,
                    targets=targets, dTheta=dTheta, dTembs=dTembs, mnet=mnet,
                    inds_of_out_heads=regged_outputs,
                    fisher_estimates=fisher_ests)
            elif config.reg == 1:
                loss_reg = hreg.calc_value_preserving_reg(hnet, task_id,
                                                          dTheta)
            elif config.reg == 2:
                loss_reg = hreg.calc_jac_reguarizer(hnet, task_id, dTheta,
                                                    device)
            elif config.reg == 3: # EWC.
                loss_reg = ewc.ewc_regularizer(task_id, hnet.theta, None,
                    hnet=hnet, online=config.online_ewc, gamma=config.gamma)
            loss_reg *= config.beta

            loss_reg.backward()

            if grad_tloss is not None:
                grad_full = torch.cat([d.grad.view(-1) for d in hnet.theta])
                # Gradient of the regularizer.
                grad_diff = grad_full - grad_tloss
                grad_diff_norm = torch.norm(grad_diff, 2)

                # Cosine between the regularizer gradient and the
                # task-specific gradient.
                dT_vec = torch.cat([d.view(-1).clone() for d in dTheta])
                grad_cos = F.cosine_similarity(grad_diff.view(1, -1),
                                               dT_vec.view(1, -1))

        theta_optimizer.step()

        # DELETEME
        #dCurr = torch.cat([d.data.view(-1) for d in hnet.theta])
        #dT_actual = dCurr - dPrev
        #print(torch.norm(dT_estimate - dT_actual, 2))

        if i % 10 == 0:
            writer.add_scalar('train/task_%d/mse_loss' % task_id, loss_task,
                              i)
            writer.add_scalar('train/task_%d/regularizer' % task_id, loss_reg,
                              i)
            writer.add_scalar('train/task_%d/full_loss' % task_id,
                              loss_task + loss_reg, i)
            if dTheta is not None:
                dT_norm = torch.norm(torch.cat([d.view(-1) for d in dTheta]),
                                     2)
                writer.add_scalar('train/task_%d/dTheta_norm' % task_id,
                                  dT_norm, i)
            if grad_tloss is not None:
                writer.add_scalar('train/task_%d/full_grad_norm' % task_id,
                                  torch.norm(grad_full, 2), i)
                writer.add_scalar('train/task_%d/reg_grad_norm' % task_id,
                                  grad_diff_norm, i)
                writer.add_scalar('train/task_%d/cosine_task_reg' % task_id,
                                  grad_cos, i)

    if config.reg == 3:
        ## Estimate diagonal Fisher elements.
        ewc.compute_fisher(task_id, data, hnet.theta, device, mnet, hnet=hnet,
            empirical_fisher=True, online=config.online_ewc,
            gamma=config.gamma, n_max=config.n_fisher, regression=True,
            allowed_outputs=allowed_outputs)

    if config.ewc_weight_importance:
        ## Estimate Fisher values for the outputs of the hypernetwork.
        weights = hnet.forward(task_id)

        # Note, there are actually no parameters in the main network.
        fake_main_params = nn.ParameterList()
        for i, W in enumerate(weights):
            fake_main_params.append(nn.Parameter(torch.Tensor(*W.shape),
                                                 requires_grad=True))
            fake_main_params[i].data = weights[i]

        ewc.compute_fisher(task_id, data, fake_main_params, device, mnet,
            empirical_fisher=True, online=False, n_max=config.n_fisher,
            regression=True, allowed_outputs=allowed_outputs)

    print('Training network ... Done')
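

# -----------------------------------------------------------------------------
# Illustrative helper (NOT part of the original code). The name
# `_fix_target_reg_sketch` and its interface are hypothetical; it only sketches
# the idea behind `hreg.calc_fix_target_reg` as used by `train_reg` when
# ``config.reg == 0``: before task `task_id` is learned, the hypernetwork
# outputs for all previous task embeddings are stored as fixed targets (cf.
# `hreg.get_current_targets`), and during training the hypernetwork outputs are
# pulled back towards those targets. The real regularizer additionally supports
# a lookahead via `dTheta` (cf. `opstep.calc_delta_theta`), masked output heads
# and Fisher-weighted distances, all of which are omitted here for brevity.
# -----------------------------------------------------------------------------
def _fix_target_reg_sketch(hnet, task_id, targets):
    r"""Minimal sketch of the output-preserving hypernetwork regularizer.

    Args:
        hnet: Hypernetwork whose ``forward(task_id)`` returns a list of weight
            tensors (same interface as assumed in ``train_reg``).
        task_id: Index of the task currently being trained.
        targets: ``targets[t]`` is the list of hypernetwork outputs for task
            ``t``, recorded before training on ``task_id`` started.

    Returns:
        The scalar regularizer (``0.`` if there are no previous tasks), to be
        added to the task loss as ``loss_task + config.beta * reg``.
    """
    reg = 0.
    for t in range(task_id):
        W_t = hnet.forward(t)
        reg = reg + sum(((w - w_star.detach())**2).sum()
                        for w, w_star in zip(W_t, targets[t]))
    return reg / max(1, task_id)
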