def train(task_id, data, mnet, device, config, shared, logger, writer):
    r"""Train a network continually using EWC.

    In general, we consider networks with task shared weights :math:`\theta` and
    task-specific weights (usually the output head weights) :math:`\psi_t`. The
    EWC loss function then arises from the following identity.

    .. math::

        \log p(\theta, \psi_A, \cdots, \psi_T \mid \mathcal{D}_A, \cdots \
             \mathcal{D}_T) &= \log p(\mathcal{D}_T \mid \theta, \psi_T) + \
            \log p(\psi_T) + \sum_{t < T} \bigg[ \log p(\mathcal{D}_t \mid \
            \theta, \psi_t)  + \log p(\psi_t) \bigg] + \log p(\theta) + const \
            \\  &= \log p(\mathcal{D}_T \mid \theta, \psi_T) + \log p(\psi_T) \
            + \log p(\theta, \psi_A \cdots \psi_S \mid \mathcal{D}_A \cdots \
            \mathcal{D}_S) + const

    If there is a single head (or combined head/softmax) such that there are no
    task-specific weights, the :math:`\psi_t`'s can be dropped from the
    equation.

    The (online) EWC loss function can then be derived to be

    .. math::

        \log p(\theta, \psi_A, \cdots, \psi_T \mid \mathcal{D}_A, \cdots \
            \mathcal{D}_T) &\approx  const + \log p(\mathcal{D}_T \mid \theta, \
            \psi_T) + \log p(\psi_T) \\ \
            & \hspace{1cm} - \frac{1}{2} \sum_{i=1}^{\mid \phi \mid} \bigg( \
            \frac{1}{\sigma_{prior}^2} + \sum_{t \in \{A, \cdots, S\}}  N_t \
            \mathcal{F}_{emp \: t, i}  \bigg) (\phi_i - \phi_{S, i}^*)^2

    where :math:`\phi` refers to all task-shared weights as well as all
    task-specific weights of previously seen tasks.

    Hence, each weight has its own regularization factor, computed as the sum
    of a constant offset (assuming an isotropic Gaussian prior) and a weighted
    accumulation of Fisher values from all previous tasks. Note, Fisher values
    of task-specific weights are only non-zero when computed on the
    corresponding task.

    As only the task-shared weights and the current output head are being
    learned, the regularizer is trivially zero for all other task-specific
    weights.
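
    To summarize, if we denote the per-weight importance accumulated up to
    task :math:`S` by

    .. math::

        \Omega_{S, i} = \frac{1}{\sigma_{prior}^2} + \
            \sum_{t \in \{A, \cdots, S\}}  N_t \mathcal{F}_{emp \: t, i}

    then the regularizer added to the NLL of the current task :math:`T` is the
    quadratic penalty
    :math:`\frac{1}{2} \sum_i \Omega_{S, i} (\phi_i - \phi_{S, i}^*)^2`
    (cf. ``ewc.ewc_regularizer`` and ``config.ewc_lambda`` below).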

    When learning the first task, we need to find a MAP solution, i.e., the
    argmax of:

    .. math::

        \log p(\theta, \psi_A \mid \mathcal{D}_A) =  const + \
            \log p(\mathcal{D}_A \mid \theta, \psi_A) + \log p(\theta) +\
            \log p(\psi_A)

    We assume isotropic Gaussian priors and can therefore transform the prior
    terms into simple L2 regularization (or weight decay) expressions:

    .. math::

        \log p(\theta) = -\frac{1}{2 \sigma_{prior}^2} \
            \lVert \theta \rVert_2^2 + const
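
    For the first task, the objective minimized in the code below therefore
    corresponds (up to additive constants) to the negative of the MAP
    objective above:

    .. math::

        - \log p(\mathcal{D}_A \mid \theta, \psi_A) + \
            \frac{1}{2 \sigma_{prior}^2} \big( \lVert \theta \rVert_2^2 + \
            \lVert \psi_A \rVert_2^2 \big)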

    Args:
        task_id: The index of the task on which we train.
        data: The dataset handler.
        mnet: The model of the main network.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        shared: Miscellaneous data shared among training functions.
        logger: Command-line logger.
        writer: The tensorboard summary writer.
    """
    logger.info('Training network on task %d ...' % (task_id+1))

    mnet.train()

    # Are we training a classification or a regression task?
    is_regression = 'regression' in shared.experiment_type
    # If we have a multihead setting, then we need to distinguish between
    # task-specific and task-shared weights.
    is_multihead = None
    if is_regression:
        assert config.ll_dist_std > 0
        eval_func = reg_bbb.evaluate
        ll_scale = 1. / config.ll_dist_std**2
        is_multihead = config.multi_head
    else:
        assert shared.softmax_temp[task_id] == 1.
        eval_func = class_bbb.evaluate
        is_multihead = config.cl_scenario == 1 or \
            (config.cl_scenario == 3 and config.split_head_cl3)
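        # I.e., a separate output head per task is used in CL1 and, if
        # `split_head_cl3` is set, in CL3. In these cases the output weights
        # are the task-specific weights (the psi_t from the docstring).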

    # Which outputs of the main network should we consider for the current
    # task?
    allowed_outputs = pmutils.out_units_of_task(config, data, task_id,
                                                task_id+1)

    #############################################################
    ### Figure out which are task-specific and shared weights ###
    #############################################################

    if is_multihead:
        # Note that the output weights of all output heads always share the
        # same parameter tensors, which is the case for all mnets at the time
        # of implementation.
        out_masks = mnet.get_output_weight_mask(out_inds=allowed_outputs,
                                                device=device)

        shared_params = []
        specific_params = []
        # Within an output weight tensor, we only want to apply the L2 reg to
        # the corresponding output weights.
        specific_mask = []

        for ii, mask in enumerate(out_masks):
            pind = mnet.param_shapes_meta[ii]['index']
            assert pind != -1
            if mask is None: # Shared parameter.
                shared_params.append(mnet.internal_params[pind])
            else: # Output weight tensor.
                specific_params.append(mnet.internal_params[pind])
                specific_mask.append(mask)
    else: # All weights are task-shared.
        shared_params = mnet.internal_params
        specific_params = None

    ###########################
    ### Create optimizer(s) ###
    ###########################
    # For the non-multihead case, we could invoke the L2 reg via the
    # weight-decay parameter here. But for the multihead case, we need to apply
    # an extra mask to the parameter tensor.
    optimizer = tutils.get_optimizer(mnet.internal_params, config.lr,
        momentum=config.momentum, weight_decay=config.weight_decay,
        use_adam=config.use_adam, adam_beta1=config.adam_beta1,
        use_rmsprop=config.use_rmsprop, use_adadelta=config.use_adadelta,
        use_adagrad=config.use_adagrad)

    ################################
    ### Learning rate schedulers ###
    ################################
    plateau_scheduler = None
    lambda_scheduler = None
    if config.plateau_lr_scheduler:
        assert config.epochs != -1
        plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau( \
            optimizer, 'min' if is_regression else 'max', factor=np.sqrt(0.1),
            patience=5, min_lr=0.5e-6, cooldown=0)

    if config.lambda_lr_scheduler:
        assert config.epochs != -1

        lambda_scheduler = optim.lr_scheduler.LambdaLR(optimizer,
            tutils.lambda_lr_schedule)

    ######################
    ### Start training ###
    ######################
    mnet_kwargs = pmutils.mnet_kwargs(config, task_id, mnet)

    num_train_iter, iter_per_epoch = sutils.calc_train_iter( \
        data.num_train_samples, config.batch_size, num_iter=config.n_iter,
        epochs=config.epochs)

    for i in range(num_train_iter):
        #########################
        ### Evaluate networks ###
        #########################
        # We test the network before we run the training iteration.
        # That way, we can see the initial performance of the untrained network.
        if i % config.val_iter == 0:
            eval_func(task_id, data, mnet, None, device, config, shared, logger,
                      writer, i)
            mnet.train()

        if i % 100 == 0:
            logger.debug('Training iteration: %d.' % i)

        ##########################
        ### Train Current Task ###
        ##########################
        optimizer.zero_grad()

        ### Compute negative log-likelihood (NLL).
        batch = data.next_train_batch(config.batch_size)
        X = data.input_to_torch_tensor(batch[0], device, mode='train')
        T = data.output_to_torch_tensor(batch[1], device, mode='train')

        if not is_regression:
            # Modify 1-hot encodings according to CL scenario.
            assert T.shape[1] == data.num_classes
            # Modify the targets, if softmax spans multiple heads.
            T = pmutils.fit_targets_to_softmax(config, shared, device, data,
                                               task_id, T)

            _, labels = torch.max(T, 1) # Integer labels.
            labels = labels.detach()

        Y = mnet.forward(X, **mnet_kwargs)
        if allowed_outputs is not None:
            Y = Y[:, allowed_outputs]

        # Task-specific loss.
        # We use the reduction method 'mean' on purpose and scale with
        # the number of training samples below.
        if is_regression:
            loss_nll = 0.5 * ll_scale * F.mse_loss(Y, T, reduction='mean')
        else:
            # Note that `cross_entropy` also computes the softmax for us.
            loss_nll = F.cross_entropy(Y, labels, reduction='mean')

            # Compute accuracy on batch.
            # Note, softmax wouldn't change the argmax.
            _, pred_labels = torch.max(Y, 1)
            mean_train_acc = 100. * torch.sum(pred_labels == labels) / \
                config.batch_size

        loss_nll *= data.num_train_samples

        ### Compute L2 reg.
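        # This implements the negative log-prior terms from the docstring,
        # i.e., ||w||^2 / (2 * sigma_prior^2). Task-shared weights are only
        # regularized this way on the first task (afterwards the prior is
        # part of the accumulated Fisher importance, cf. `prior_offset`
        # below), whereas the current output head is always pulled towards
        # zero.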
        loss_l2 = 0
        if task_id == 0 or config.train_from_scratch:
            for pp in shared_params:
                loss_l2 += pp.pow(2).sum()
        if specific_params is not None:
            for ii, pp in enumerate(specific_params):
                loss_l2 += (pp * specific_mask[ii]).pow(2).sum()
        loss_l2 *= 1. / (2. * config.prior_variance)

        ### Compute EWC reg.
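        # Essentially the quadratic penalty sum_i Omega_i * (phi_i - phi*_i)^2
        # from the docstring, using the online-accumulated (and gamma-decayed)
        # Fisher values as importances. The `ewc_lambda` scaling is applied
        # when summing up the total loss below.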
        loss_ewc = 0
        if task_id > 0 and config.ewc_lambda > 0:
            assert not config.train_from_scratch
            loss_ewc += ewc.ewc_regularizer(task_id, mnet.internal_params,
                mnet, online=True, gamma=config.ewc_gamma)

        loss = loss_nll + loss_l2 + config.ewc_lambda * loss_ewc

        loss.backward()
        if config.clip_grad_value != -1:
            torch.nn.utils.clip_grad_value_( \
                optimizer.param_groups[0]['params'], config.clip_grad_value)
        elif config.clip_grad_norm != -1:
            torch.nn.utils.clip_grad_norm_(optimizer.param_groups[0]['params'],
                                           config.clip_grad_norm)
        optimizer.step()

        ###############################
        ### Learning rate scheduler ###
        ###############################
        # We can invoke the same function to compute test accuracy as we do for
        # BbB.
        pmutils.apply_lr_schedulers(config, shared, logger, task_id, data, mnet,
            None, device, i, iter_per_epoch, plateau_scheduler,
            lambda_scheduler, hhnet=None, method='bbb')

        ###########################
        ### Tensorboard summary ###
        ###########################
        if i % 50 == 0:
            writer.add_scalar('train/task_%d/loss_nll' % task_id, loss_nll, i)
            writer.add_scalar('train/task_%d/loss_l2' % task_id, loss_l2, i)
            writer.add_scalar('train/task_%d/loss_ewc' % task_id, loss_ewc, i)
            writer.add_scalar('train/task_%d/loss' % task_id, loss, i)
            if not is_regression:
                writer.add_scalar('train/task_%d/accuracy' % task_id,
                                  mean_train_acc, i)

    pmutils.checkpoint_bn_stats(config, task_id, mnet)

    #############################
    ### Compute Fisher matrix ###
    #############################
    # Note, we compute the Fisher after all tasks (even the last task) if we
    # have a multihead setup, since we use those Fisher values to build
    # approximate posterior distributions.
    if is_multihead or task_id < config.num_tasks - 1:
        logger.debug('Computing diagonal Fisher elements ...')

        fisher_params = mnet.internal_params

        # When training from scratch, a new network is generated for every
        # task, such that the old Fisher matrices expected by EWC do not exist
        # yet.
        # On the other hand, if a hypernetwork is used, then we learn task-
        # specific models and have to explicitly avoid that Fisher matrices
        # are accumulated.
        if task_id > 0 and config.train_from_scratch:
            for i, p in enumerate(fisher_params):
                buff_w_name, buff_f_name = ewc._ewc_buffer_names(task_id, i,
                                                                 True)
                mnet.register_buffer(buff_w_name, torch.zeros_like(p))
                mnet.register_buffer(buff_f_name, torch.zeros_like(p))

        # Compute prior-offset of Fisher values.
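        # The prior offset adds the constant 1 / sigma_prior^2 from the
        # docstring formula to the accumulated Fisher values. It only needs to
        # be added once per weight, since the online accumulation carries it
        # over to subsequent tasks (up to the decay factor `ewc_gamma`).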
        if is_multihead:
            out_masks = mnet.get_output_weight_mask(out_inds=allowed_outputs,
                                                    device=device)

            prior_offset = [torch.zeros_like(p) for p in mnet.internal_params]

            for ii, mask in enumerate(out_masks):
                pind = mnet.param_shapes_meta[ii]['index']

                if mask is None: # Shared parameter.
                    if task_id == 0 or config.train_from_scratch:
                        prior_offset[pind][:] = 1. / config.prior_variance
                else: # Current output head.
                    # Note, why don't we apply the offset to all heads from
                    # the beginning?
                    # -> If we did, the Fisher values of the output heads of
                    # the current and future tasks would be non-zero and the
                    # corresponding weights would therefore be regularized by
                    # the EWC regularizer. For future tasks this doesn't
                    # matter, as those weights don't change during training
                    # and the reg is still 0. But for the current task it does
                    # matter, and the reg would pull the weights towards their
                    # random initialization.
                    prior_offset[pind][mask] = 1. / config.prior_variance

        else:
            prior_offset = 0
            if task_id == 0 or config.train_from_scratch:
                prior_offset = 1. / config.prior_variance

        target_manipulator = None
        if not is_regression:
            target_manipulator = lambda T: pmutils.fit_targets_to_softmax( \
                config, shared, device, data, task_id, T)

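        # Note, `proper_scaling` is supposed to weigh each task's Fisher by
        # its number of training samples (the N_t in the docstring), while
        # `prior_strength` adds the prior offset computed above.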
        ewc.compute_fisher(task_id, data, fisher_params, device, mnet,
            empirical_fisher=True, online=True, gamma=config.ewc_gamma,
            n_max=config.n_fisher, regression=is_regression,
            allowed_outputs=allowed_outputs, custom_forward=None,
            time_series=False, custom_nll=None, pass_ids=False,
            proper_scaling=True, prior_strength=prior_offset,
            regression_lvar=config.ll_dist_std**2 if is_regression else 1.,
            target_manipulator=target_manipulator)

        ### Log histogram of diagonal Fisher elements.
        diag_fisher = []

        out_masks = mnet.get_output_weight_mask(out_inds=allowed_outputs,
                                                device=device)
        for ii, mask in enumerate(out_masks):
            pind = mnet.param_shapes_meta[ii]['index']
            _, buff_f_name = ewc._ewc_buffer_names(None, pind, True)
            curr_F = getattr(mnet, buff_f_name)
            if mask is not None:
                curr_F = curr_F[mask]
            diag_fisher.append(curr_F)

        diag_fisher = torch.cat([p.detach().flatten().cpu() for p in \
                                 diag_fisher])

        writer.add_scalar('ewc/min_fisher', torch.min(diag_fisher), task_id)
        writer.add_scalar('ewc/max_fisher', torch.max(diag_fisher), task_id)
        writer.add_histogram('ewc/fisher', diag_fisher, task_id)
        try:
            writer.add_histogram('ewc/log_fisher', torch.log(diag_fisher),
                                 task_id)
        except Exception:
            # Should not happen, since diagonal Fisher elements should be
            # positive.
            logger.warning('Could not write histogram of diagonal Fisher '
                           'elements.')

    logger.info('Training network on task %d ... Done' % (task_id+1))
def train_reg(task_id, data, mnet, hnet, device, config, writer):
    r"""Train the network using the task-specific loss plus a regularizer that
    should weaken catastrophic forgetting.

    .. math::
        \text{loss} = \text{task\_loss} + \beta * \text{regularizer}

    Args:
        task_id: The index of the task on which we train.
        data: The dataset handler.
        mnet: The model of the main network.
        hnet: The model of the hyper network.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        writer: The tensorboard summary writer.
    """
    print('Training network ...')

    mnet.train()
    hnet.train()

    regged_outputs = None
    if config.multi_head:
        n_y = data.out_shape[0]
        out_head_inds = [list(range(i * n_y, (i + 1) * n_y)) for i in
                         range(task_id + 1)]
        # Outputs to be regularized.
        regged_outputs = out_head_inds[:-1] if config.masked_reg else None
    allowed_outputs = out_head_inds[task_id] if config.multi_head else None
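    # E.g., in the multi-head case with a scalar output (`n_y == 1`) and
    # `task_id == 2`, the heads are [[0], [1], [2]] and only head [2] receives
    # the task-specific loss.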

    # Collect Fisher estimates for the reg computation.
    fisher_ests = None
    if config.ewc_weight_importance and task_id > 0:
        fisher_ests = []
        n_W = len(hnet.target_shapes)
        for t in range(task_id):
            ff = []
            for i in range(n_W):
                _, buff_f_name = ewc._ewc_buffer_names(t, i, False)
                ff.append(getattr(mnet, buff_f_name))
            fisher_ests.append(ff)

    # Regularizer targets.
    if config.reg == 0 and config.beta > 0:
        targets = hreg.get_current_targets(task_id, hnet)
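        # I.e., snapshot the hypernetwork outputs for all previous task
        # embeddings before training starts, so they can later serve as fixed
        # regularization targets.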

    regularized_params = list(hnet.theta)
    if task_id > 0 and config.plastic_prev_tembs:
        assert (config.reg == 0)
        for i in range(task_id):  # for all previous task embeddings
            regularized_params.append(hnet.get_task_emb(i))
    theta_optimizer = optim.Adam(regularized_params, lr=config.lr_hyper)
    # We only optimize the task embedding corresponding to the current task,
    # the remaining ones stay constant.
    emb_optimizer = optim.Adam([hnet.get_task_emb(task_id)],
                               lr=config.lr_hyper)

    # Whether the regularizer will be computed during training.
    calc_reg = task_id > 0 and config.beta > 0

    for i in range(config.n_iter):
        ### Evaluate network.
        # We test the network before we run the training iteration.
        # That way, we can see the initial performance of the untrained network.
        if i % config.val_iter == 0:
            evaluate(task_id, data, mnet, hnet, device, config, writer, i)
            mnet.train()
            hnet.train()

        if i % 100 == 0:
            print('Training iteration: %d.' % i)

        ### Train theta and task embedding.
        theta_optimizer.zero_grad()
        emb_optimizer.zero_grad()

        batch = data.next_train_batch(config.batch_size)
        X = data.input_to_torch_tensor(batch[0], device, mode='train')
        T = data.output_to_torch_tensor(batch[1], device, mode='train')

        weights = hnet.forward(task_id)
        Y = mnet.forward(X, weights)
        if config.multi_head:
            Y = Y[:, allowed_outputs]

        # Task-specific loss.
        loss_task = F.mse_loss(Y, T)
        # We already compute the gradients, to then be able to compute delta
        # theta.
        loss_task.backward(retain_graph=calc_reg,
                           create_graph=config.backprop_dt and calc_reg)

        # The task embedding is only trained on the task-specific loss.
        # Note, the gradients accumulated so far are from "loss_task".
        emb_optimizer.step()

        # DELETEME check correctness of opstep.calc_delta_theta.
        # dPrev = torch.cat([d.data.clone().view(-1) for d in hnet.theta])
        # dT_estimate = torch.cat([d.view(-1).clone() for d in
        #    opstep.calc_delta_theta(theta_optimizer,
        #                            config.use_sgd_change, lr=config.lr_hyper,
        #                            detach_dt=not config.backprop_dt)])

        loss_reg = 0
        dTheta = None
        grad_tloss = None
        if calc_reg:
            if i % 100 == 0:  # Just for debugging: displaying grad magnitude.
                grad_tloss = torch.cat([d.grad.clone().view(-1) for d in
                                        hnet.theta])

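            # `dTheta` estimates the parameter change the optimizer is about
            # to apply based on the accumulated task-loss gradients, such that
            # the regularizer below can be evaluated at the lookahead
            # parameters theta + dTheta.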
            dTheta = opstep.calc_delta_theta(theta_optimizer,
                config.use_sgd_change, lr=config.lr_hyper,
                detach_dt=not config.backprop_dt)
            if config.plastic_prev_tembs:
                dTembs = dTheta[-task_id:]
                dTheta = dTheta[:-task_id]
            else:
                dTembs = None

            if config.reg == 0:
                loss_reg = hreg.calc_fix_target_reg(hnet, task_id,
                    targets=targets, dTheta=dTheta, dTembs=dTembs, mnet=mnet,
                    inds_of_out_heads=regged_outputs,
                    fisher_estimates=fisher_ests)
            elif config.reg == 1:
                loss_reg = hreg.calc_value_preserving_reg(hnet, task_id, dTheta)
            elif config.reg == 2:
                loss_reg = hreg.calc_jac_reguarizer(hnet, task_id, dTheta,
                                                    device)
            elif config.reg == 3:  # EWC
                loss_reg = ewc.ewc_regularizer(task_id, hnet.theta, None,
                    hnet=hnet, online=config.online_ewc, gamma=config.gamma)
            loss_reg *= config.beta

            loss_reg.backward()

            if grad_tloss is not None:
                grad_full = torch.cat([d.grad.view(-1) for d in hnet.theta])
                # Grad of regularizer.
                grad_diff = grad_full - grad_tloss
                grad_diff_norm = torch.norm(grad_diff, 2)

                # Cosine between regularizer gradient and task-specific
                # gradient.
                dT_vec = torch.cat([d.view(-1).clone() for d in dTheta])
                grad_cos = F.cosine_similarity(grad_diff.view(1, -1),
                                               dT_vec.view(1, -1))

        theta_optimizer.step()

        # DELETEME
        # dCurr = torch.cat([d.data.view(-1) for d in hnet.theta])
        # dT_actual = dCurr - dPrev
        # print(torch.norm(dT_estimate - dT_actual, 2))

        if i % 10 == 0:
            writer.add_scalar('train/task_%d/mse_loss' % task_id, loss_task, i)
            writer.add_scalar('train/task_%d/regularizer' % task_id, loss_reg,
                              i)
            writer.add_scalar('train/task_%d/full_loss' % task_id, loss_task +
                              loss_reg, i)
            if dTheta is not None:
                dT_norm = torch.norm(torch.cat([d.view(-1) for d in dTheta]), 2)
                writer.add_scalar('train/task_%d/dTheta_norm' % task_id,
                                  dT_norm, i)
            if grad_tloss is not None:
                writer.add_scalar('train/task_%d/full_grad_norm' % task_id,
                                  torch.norm(grad_full, 2), i)
                writer.add_scalar('train/task_%d/reg_grad_norm' % task_id,
                                  grad_diff_norm, i)
                writer.add_scalar('train/task_%d/cosine_task_reg' % task_id,
                                  grad_cos, i)

    if config.reg == 3:
        ## Estimate diagonal Fisher elements.
        ewc.compute_fisher(task_id, data, hnet.theta, device, mnet, hnet=hnet,
            empirical_fisher=True, online=config.online_ewc,
            gamma=config.gamma, n_max=config.n_fisher, regression=True,
            allowed_outputs=allowed_outputs)

    if config.ewc_weight_importance:
        ## Estimate Fisher for outputs of the hypernetwork.
        weights = hnet.forward(task_id)

        # Note, there are actually no parameters in the main network.
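        # We therefore wrap the hypernetwork outputs into a ParameterList,
        # such that `compute_fisher` can treat them like ordinary main-network
        # weights (the resulting Fisher buffers are registered in `mnet`, cf.
        # their retrieval via `getattr(mnet, buff_f_name)` above).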
        fake_main_params = nn.ParameterList()
        for i, W in enumerate(weights):
            fake_main_params.append(nn.Parameter(torch.Tensor(*W.shape),
                                                 requires_grad=True))
            fake_main_params[i].data = weights[i]

        ewc.compute_fisher(task_id, data, fake_main_params, device, mnet,
            empirical_fisher=True, online=False, n_max=config.n_fisher,
            regression=True, allowed_outputs=allowed_outputs)

    print('Training network ... Done')