Beispiel #1
0
def train(task_id, data, mnet, hnet, device, config, shared, writer, logger):
    """Train the hyper network using the task-specific loss plus a regularizer
    that should overcome catastrophic forgetting.

    :code:`loss = task_loss + beta * regularizer`.

    If ``hnet`` is ``None``, the weights of the main network are optimized
    directly and no task embedding optimizer is created.

    The number of training iterations is ``config.n_iter`` if
    ``config.epochs == -1``; otherwise it is derived from ``config.epochs``,
    ``data.num_train_samples`` and ``config.batch_size``.

    Args:
        task_id: The index of the task on which we train.
        data: The dataset handler.
        mnet: The model of the main network.
        hnet: The model of the hyper network. May be ``None``.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        shared (argparse.Namespace): Set of variables shared between functions.
        writer: The tensorboard summary writer.
        logger: The logger that should be used rather than the print method.
    """
    start_time = time()

    # Dataset diagnostics go through the logger (see docstring) rather than
    # through ``print``.
    logger.debug('data: %s', data)
    logger.debug('data.num_classes: %s', data.num_classes)
    logger.debug('data.num_train_samples: %s', data.num_train_samples)

    logger.info('Training network ...')

    mnet.train()
    if hnet is not None:
        hnet.train()

    #################
    ### Optimizer ###
    #################
    # Define the optimizers used to train main network and hypernet.
    if hnet is not None:
        theta_params = list(hnet.theta)
        if config.continue_emb_training:
            for i in range(task_id):  # for all previous task embeddings
                theta_params.append(hnet.get_task_emb(i))

        # Only for the current task embedding.
        # Important that this embedding is in a different optimizer in case
        # we use the lookahead.
        emb_optimizer = get_optimizer([hnet.get_task_emb(task_id)],
                                      config.lr,
                                      momentum=config.momentum,
                                      weight_decay=config.weight_decay,
                                      use_adam=config.use_adam,
                                      adam_beta1=config.adam_beta1,
                                      use_rmsprop=config.use_rmsprop)
    else:
        theta_params = mnet.weights
        emb_optimizer = None

    theta_optimizer = get_optimizer(theta_params,
                                    config.lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay,
                                    use_adam=config.use_adam,
                                    adam_beta1=config.adam_beta1,
                                    use_rmsprop=config.use_rmsprop)

    ################################
    ### Learning rate schedulers ###
    ################################
    if config.plateau_lr_scheduler:
        assert (config.epochs != -1)
        # The scheduler config has been taken from here:
        # https://keras.io/examples/cifar10_resnet/
        # Note, we use 'max' instead of 'min' as we look at accuracy rather
        # than validation loss!
        plateau_scheduler_theta = optim.lr_scheduler.ReduceLROnPlateau( \
            theta_optimizer, 'max', factor=np.sqrt(0.1), patience=5,
            min_lr=0.5e-6, cooldown=0)
        plateau_scheduler_emb = None
        if emb_optimizer is not None:
            plateau_scheduler_emb = optim.lr_scheduler.ReduceLROnPlateau( \
                emb_optimizer, 'max', factor=np.sqrt(0.1), patience=5,
                min_lr=0.5e-6, cooldown=0)

    if config.lambda_lr_scheduler:
        assert (config.epochs != -1)

        def lambda_lr(epoch):
            """Multiplicative Factor for Learning Rate Schedule.

            Computes a multiplicative factor for the initial learning rate based
            on the current epoch. This method can be used as argument
            ``lr_lambda`` of class :class:`torch.optim.lr_scheduler.LambdaLR`.

            The schedule is inspired by the Resnet CIFAR-10 schedule suggested
            here https://keras.io/examples/cifar10_resnet/.

            Args:
                epoch (int): The number of epochs

            Returns:
                lr_scale (float32): learning rate scale
            """
            lr_scale = 1.
            if epoch > 180:
                lr_scale = 0.5e-3
            elif epoch > 160:
                lr_scale = 1e-3
            elif epoch > 120:
                lr_scale = 1e-2
            elif epoch > 80:
                lr_scale = 1e-1
            return lr_scale

        lambda_scheduler_theta = optim.lr_scheduler.LambdaLR(
            theta_optimizer, lambda_lr)
        lambda_scheduler_emb = None
        if emb_optimizer is not None:
            lambda_scheduler_emb = optim.lr_scheduler.LambdaLR(
                emb_optimizer, lambda_lr)

    ##############################
    ### Prepare CL Regularizer ###
    ##############################
    # Whether we will calculate the regularizer.
    calc_reg = task_id > 0 and not config.mnet_only and config.beta > 0 and \
        not config.train_from_scratch

    # Compute targets when the reg is activated and we are not training
    # the first task
    if calc_reg:
        if config.online_target_computation:
            # Compute targets for the regularizer whenever they are needed.
            # -> Computationally expensive.
            targets_hypernet = None
            prev_theta = [p.detach().clone() for p in hnet.theta]
            prev_task_embs = [p.detach().clone() for p in hnet.get_task_embs()]
        else:
            # Compute targets for the regularizer once and keep them all in
            # memory -> Memory expensive.
            targets_hypernet = hreg.get_current_targets(task_id, hnet)
            prev_theta = None
            prev_task_embs = None

        # If we do not want to regularize all outputs (in a multi-head setup).
        # Note, we don't care whether output heads other than the current one
        # change.
        regged_outputs = None
        if config.cl_scenario != 2:
            # FIXME We assume here that all tasks have the same output size.
            n_y = data.num_classes
            regged_outputs = [
                list(range(i * n_y, (i + 1) * n_y)) for i in range(task_id)
            ]

    # We need to tell the main network, which batch statistics to use, in case
    # batchnorm is used and we checkpoint the batchnorm stats.
    mnet_kwargs = {}
    if mnet.batchnorm_layers is not None:
        if config.bn_distill_stats:
            raise NotImplementedError()
        elif not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Specify current task as condition to select correct
            # running stats.
            mnet_kwargs['condition'] = task_id

    ######################
    ### Start training ###
    ######################

    iter_per_epoch = -1
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert (config.epochs > 0)
        iter_per_epoch = int(np.ceil(data.num_train_samples / \
                                     config.batch_size))
        training_iterations = config.epochs * iter_per_epoch

    summed_iter_runtime = 0

    for i in range(training_iterations):
        ### Evaluate network.
        # We test the network before we run the training iteration.
        # That way, we can see the initial performance of the untrained network.
        if i % config.val_iter == 0:
            test(task_id,
                 data,
                 mnet,
                 hnet,
                 device,
                 shared,
                 config,
                 writer,
                 logger,
                 train_iter=i)
            # Testing may have switched the networks to eval mode.
            mnet.train()
            if hnet is not None:
                hnet.train()

        if i % 200 == 0:
            logger.info('Training step: %d ...' % i)

        iter_start_time = time()

        theta_optimizer.zero_grad()
        if emb_optimizer is not None:
            emb_optimizer.zero_grad()

        #######################################
        ### Data for current task and batch ###
        #######################################
        batch = data.next_train_batch(config.batch_size)
        X = data.input_to_torch_tensor(batch[0], device, mode='train')
        T = data.output_to_torch_tensor(batch[1], device, mode='train')

        # Get the output neurons depending on the continual learning scenario.
        n_y = data.num_classes
        if config.cl_scenario == 1:
            # Choose current head.
            task_out = [task_id * n_y, (task_id + 1) * n_y]
        elif config.cl_scenario == 2:
            # Always all output neurons, only one head is used.
            task_out = [0, n_y]
        else:
            # Choose current head, which will be inferred during inference.
            task_out = [task_id * n_y, (task_id + 1) * n_y]

        ########################
        ### Loss computation ###
        ########################
        if config.mnet_only:
            weights = None
        else:
            weights = hnet.forward(task_id=task_id)
        Y_hat_logits = mnet.forward(X, weights, **mnet_kwargs)

        # Restrict output neurons
        Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
        assert (T.shape[1] == Y_hat_logits.shape[1])
        # compute loss on task and compute gradients
        if config.soft_targets:
            soft_label = 0.95
            num_classes = data.num_classes
            # Build both branches on the same device as ``T``; mixing CPU
            # tensors with a condition that lives on the GPU makes
            # ``torch.where`` fail.
            on_value = torch.full_like(T, soft_label, dtype=torch.float32)
            off_value = torch.full_like(
                T, (1 - soft_label) / (num_classes - 1), dtype=torch.float32)
            soft_targets = torch.where(T == 1, on_value, off_value)
            soft_targets = soft_targets.to(device)
            loss_task = Classifier.softmax_and_cross_entropy(
                Y_hat_logits, soft_targets)
        else:
            loss_task = Classifier.logit_cross_entropy_loss(Y_hat_logits, T)

        # Compute gradients based on task loss (those might be used in the CL
        # regularizer).
        loss_task.backward(retain_graph=calc_reg, create_graph=calc_reg and \
                           config.backprop_dt)

        # The current task embedding only depends in the task loss, so we can
        # update it already.
        if emb_optimizer is not None:
            emb_optimizer.step()

        #############################
        ### CL (HNET) Regularizer ###
        #############################
        loss_reg = 0
        dTheta = None

        if calc_reg:
            if config.no_lookahead:
                dTembs = None
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(
                    theta_optimizer,
                    False,
                    lr=config.lr,
                    detach_dt=not config.backprop_dt)

                if config.continue_emb_training:
                    # The previous task embeddings were appended to the end of
                    # ``theta_params``, so their updates are the last
                    # ``task_id`` entries.
                    dTembs = dTheta[-task_id:]
                    dTheta = dTheta[:-task_id]
                else:
                    dTembs = None

            loss_reg = hreg.calc_fix_target_reg(
                hnet,
                task_id,
                targets=targets_hypernet,
                dTheta=dTheta,
                dTembs=dTembs,
                mnet=mnet,
                inds_of_out_heads=regged_outputs,
                prev_theta=prev_theta,
                prev_task_embs=prev_task_embs,
                batch_size=config.cl_reg_batch_size)

            loss_reg *= config.beta

            loss_reg.backward()

        # Now, that we computed the regularizer, we can use the accumulated
        # gradients and update the hnet (or mnet) parameters.
        theta_optimizer.step()

        Y_hat = F.softmax(Y_hat_logits, dim=1)
        classifier_accuracy = Classifier.accuracy(Y_hat, T) * 100.0

        #########################
        # Learning rate scheduler
        #########################
        if config.plateau_lr_scheduler:
            assert (iter_per_epoch != -1)
            if i % iter_per_epoch == 0 and i > 0:
                curr_epoch = i // iter_per_epoch
                logger.info('Computing test accuracy for plateau LR ' +
                            'scheduler (epoch %d).' % curr_epoch)
                # We need a validation quantity for the plateau LR scheduler.
                # FIXME we should use an actual validation set rather than the
                # test set.
                # Note, https://keras.io/examples/cifar10_resnet/ uses the test
                # set to compute the validation loss. We use the "validation"
                # accuracy instead.
                # FIXME We increase `train_iter` as the print messages in the
                # test method suggest that the testing has been executed before
                test_acc, _ = test(task_id,
                                   data,
                                   mnet,
                                   hnet,
                                   device,
                                   shared,
                                   config,
                                   writer,
                                   logger,
                                   train_iter=i + 1)
                mnet.train()
                if hnet is not None:
                    hnet.train()

                plateau_scheduler_theta.step(test_acc)
                if plateau_scheduler_emb is not None:
                    plateau_scheduler_emb.step(test_acc)

        if config.lambda_lr_scheduler:
            assert (iter_per_epoch != -1)
            if i % iter_per_epoch == 0 and i > 0:
                curr_epoch = i // iter_per_epoch
                logger.info('Applying Lambda LR scheduler (epoch %d).' %
                            curr_epoch)

                lambda_scheduler_theta.step()
                if lambda_scheduler_emb is not None:
                    lambda_scheduler_emb.step()

        ###########################
        ### Tensorboard summary ###
        ###########################
        # We don't wanna slow down training by having too much output.
        if i % 50 == 0:
            writer.add_scalar('train/task_%d/class_accuracy' % task_id,
                              classifier_accuracy, i)
            writer.add_scalar('train/task_%d/loss_task' % task_id, loss_task,
                              i)
            writer.add_scalar('train/task_%d/loss_reg' % task_id, loss_reg, i)

        ### Show the current training progress to the user.
        if i % config.val_iter == 0:
            msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                  '(on current training batch).'
            logger.debug(msg.format(i, classifier_accuracy))

        iter_end_time = time()
        summed_iter_runtime += (iter_end_time - iter_start_time)

        if i % 200 == 0:
            logger.info('Training step: %d ... Done -- (runtime: %f sec)' % \
                        (i, iter_end_time - iter_start_time))

    if mnet.batchnorm_layers is not None:
        if not config.bn_distill_stats and \
                not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Checkpoint the current running statistics (that have been
            # estimated while training the current task).
            for bn_layer in mnet.batchnorm_layers:
                assert (bn_layer.num_stats == task_id + 1)
                bn_layer.checkpoint_stats()

    # Average over the iterations that were actually executed. This differs
    # from ``config.n_iter`` whenever the iteration count is derived from
    # ``config.epochs``. The ``max`` guards against a zero-iteration run.
    avg_iter_time = summed_iter_runtime / max(training_iterations, 1)
    logger.info('Average runtime per training iteration: %f sec.' % \
                avg_iter_time)

    logger.info('Elapsed time for training task %d: %f sec.' % \
                (task_id+1, time()-start_time))
Beispiel #2
0
def train_gan_one_t(dhandler, dis, gen, g_hnet, device, config, writer,
                    embd_list, t):
    """ Train the conditional MNIST GAN for one task.

    In this function the main training logic for this replay model is
    implemented. After setting the optimizers for the discriminator, the
    generator and its hypernetwork (if applicable), a standard GAN training
    scheme is implemented: in every iteration the discriminator is updated
    first, followed by the generator. To prevent the generator (its
    hypernetwork) from forgetting, we add our hypernetwork regularisation
    term for all tasks seen before ``t`` to the generator loss.

    Args:
        (....): See docstring of function
            :func:`mnist.replay.train_replay.train`.
        embd_list: Helper list of lists for embedding plotting.
        t: Task id to train.
    """

    print("Training GAN on data handler: ", t)

    # get lists for plotting embeddings
    d_embeddings, g_embeddings, d_embedding_history, g_embedding_history = \
                                                                    embd_list[:]
    # set training_iterations if epochs are set
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert (config.epochs > 0)
        training_iterations = config.epochs * \
        int(np.ceil(dhandler.num_train_samples / config.batch_size))

    # Here we adjust the number of training iterations when we train our replay
    # method to replay every single class in a task given that condition.
    # We need to adjust the training iterations such that we train every
    # class in the task only a portion of the time we are given for the
    # whole task:
    # Training_time_per_class = training_time_per_task / num_class_per_task
    # This is important to compare to related work, as they set the training
    # time per task which we now have to split up.

    if config.single_class_replay:
        training_iterations = int(training_iterations / config.out_dim)

    # if we want to start training the new task with the weights of the previous
    # task we have to set the start embedding for the new task to the embedding
    # of the previous task.
    if config.embedding_reset == "old_embedding" and t > 0:
        if g_hnet is not None:
            last_emb = g_hnet.get_task_embs()[t - 1].detach().clone()
            g_hnet.get_task_embs()[t].data = last_emb

    # Whether the hypernetwork regularizer is computed in every iteration.
    calc_reg = config.rp_beta > 0 and t > 0 and g_hnet is not None

    # Compute targets for the hnet before training. They are only needed if
    # the regularizer is active; otherwise the name stays bound to ``None``.
    targets_G = None
    if calc_reg:
        targets_G = hreg.get_current_targets(t, g_hnet)

    ############
    # OPTIMIZERS
    ############

    # discriminator optimizer
    dis_paras = dis.parameters()
    doptimizer = optim.Adam(dis_paras, lr=config.enc_lr, betas=(0.9, 0.999))

    # generator optimizer (hnet or weights directly)
    g_emb_optimizer = None
    if g_hnet is not None:
        g_paras = list(g_hnet.theta)
        if not config.dont_train_rp_embeddings:
            # Set the embedding optimizer only for the current task embedding.
            # Note that we could here continue training the old embeddings.
            g_emb_optimizer = optim.Adam([g_hnet.get_task_emb(t)],
                                         lr=config.dec_lr_emb,
                                         betas=(0.9, 0.999))
    else:
        g_paras = gen.parameters()

    goptimizer = optim.Adam(g_paras, lr=config.dec_lr, betas=(0.9, 0.999))

    for i in range(training_iterations):
        ### Test network.
        # We test the network before we run the training iteration.
        # That way, we can see the initial performance of the untrained net.
        if i % config.val_iter == 0:
            test(dis, gen, g_hnet, device, config, writer, i, t)
            gen.train()
            dis.train()
            if g_hnet is not None:
                g_hnet.train()

        if i % 100 == 0:
            print('Training iteration: %d.' % i)

        if config.show_plots and g_hnet is not None:
            # ``.cpu()`` is a no-op for tensors that already live on the CPU,
            # so a single code path covers both devices.
            g_embedding_history.append(
                g_hnet.get_task_emb(t).clone().detach().cpu().numpy())

        #######
        # DATA
        #######
        real_batch = dhandler.next_train_batch(config.batch_size)
        X_real = dhandler.input_to_torch_tensor(real_batch[0],
                                                device,
                                                mode='train')
        # Shift data into the range [-1, 1] so we can tanh the output of G
        # (assumes the handler delivers inputs in [0, 1] -- TODO confirm).
        X_real = X_real * 2 - 1.0

        ######################
        # TRAIN DISCRIMINATOR
        ######################

        # set gradients again to zero
        doptimizer.zero_grad()
        goptimizer.zero_grad()
        if g_emb_optimizer is not None:
            g_emb_optimizer.zero_grad()

        # Note that X_fake is not normalized between 0 and 1
        # but like in https://github.com/Zackory/Kera
        # s-MNIST-GAN/blob/master/mnist_gan.py
        # inputs are shifted to [-1, 1] and X_fake is put through tanh
        #X_fake = torch.tanh(X_fake)
        X_fake = sample(gen, g_hnet, config, t, device)

        fake = dis.forward(X_fake)
        real = dis.forward(X_real)

        # compute discriminator loss
        dloss = gan_helpers.dis_loss(real, fake, config.loss_fun)

        # compute gradients for discriminator and take gradient step
        dloss.backward()
        doptimizer.step()

        ######################
        # TRAIN GENERATOR
        ######################

        # Set gradients again to zero. This also discards the generator
        # gradients that were accumulated by the discriminator update above.
        goptimizer.zero_grad()
        doptimizer.zero_grad()
        if g_emb_optimizer is not None:
            g_emb_optimizer.zero_grad()

        X_fake = sample(gen, g_hnet, config, t, device)
        fake = dis.forward(X_fake)

        # compute generator loss
        gloss = gan_helpers.gen_loss(fake, config.loss_fun)

        gloss.backward(retain_graph=calc_reg,create_graph=calc_reg and \
                           config.backprop_dt)

        # compute hypernet reg loss and fix embedding->change current embs
        if calc_reg:
            if config.no_lookahead:
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(
                    goptimizer,
                    config.use_sgd_change,
                    lr=config.dec_lr,
                    detach_dt=not config.backprop_dt)

            gloss_reg = config.rp_beta * hreg.calc_fix_target_reg(
                g_hnet,
                t,
                targets=targets_G,
                mnet=gen,
                dTheta=dTheta,
                dTembs=None)
            gloss_reg.backward()
        else:
            gloss_reg = 0

        # compute gradients for generator and take gradient step
        goptimizer.step()
        if g_emb_optimizer is not None:
            g_emb_optimizer.step()

        # Visualization of current progress in tensorboard
        if i % config.plot_update_steps == 0 and i > 0 and config.show_plots:
            if d_embedding_history is not None:
                d_embedding_cut = np.asarray(d_embedding_history[2:])
            else:
                d_embedding_cut = None
            if g_embedding_history is not None:
                g_embedding_cut = np.asarray(g_embedding_history[2:])
            else:
                g_embedding_cut = None
            _viz_training(X_real,
                          X_fake,
                          g_embeddings,
                          d_embeddings,
                          g_embedding_cut,
                          d_embedding_cut,
                          writer,
                          i,
                          config,
                          title="train_cond_" + str(t))

        # track some training statistics
        writer.add_scalar('train/gen_loss_%d' % (t), gloss + gloss_reg, i)
        writer.add_scalar('train/dloss_all_%d' % (t), dloss, i)
        writer.add_scalar('train/dis_accuracy_%d' % (t),
                          gan_helpers.accuracy(real, fake, config.loss_fun), i)
        if config.rp_beta > 0:
            writer.add_scalar('train/g_hnet_loss_reg_%d' % (t), gloss_reg, i)
            writer.add_scalar('train/g_loss_only_%d' % (t), gloss, i)

    # Final evaluation, tagged with the number of iterations that were
    # actually run. This differs from ``config.n_iter`` when ``config.epochs``
    # or ``config.single_class_replay`` determine the iteration count.
    test(dis, gen, g_hnet, device, config, writer, training_iterations, t)
Beispiel #3
0
def train_proximal(task_id, data, mnet, hnet, device, config, writer):
    r"""Train the hypernetwork with a proximal update scheme.

    The hypernetwork weights :math:`\theta` are not optimized directly (only
    the embedding of the current task is). In every training iteration, a
    weight change :math:`\Delta\theta` is reset to zero and optimized for
    ``config.n_steps`` steps on the loss below, after which it is added to the
    hypernetwork weights. Afterwards, the current task embedding is updated
    with a single MSE step on the same batch.

    .. math::
        \text{loss} = \text{task\_loss}(\theta + \Delta\theta) +
                      \alpha \lVert \Delta\theta \rVert +
                      \beta *  \text{cl\_reg}(\theta, \Delta\theta)

    Note:
        The code penalizes the (unsquared) norm of :math:`\Delta\theta` as
        returned by :func:`torch.norm`. The continual learning term is only
        evaluated if ``task_id > 0`` and ``config.beta > 0``; its form is
        selected via ``config.reg``.

    Args:
        (....): See docstring of method :func:`train_reg`.
    """
    # Reject configurations that this training scheme does not support.
    if config.reg == 3 or config.ewc_weight_importance:
        # TODO Don't just copy all the code, find a more elegant solution.
        raise NotImplementedError('Chosen regularizer not implemented for ' +
                                  'proximal algorithm!')
    if config.plastic_prev_tembs:
        # TODO can be implemented as above.
        raise NotImplementedError('Option "plastic_prev_tembs" not yet ' +
                                  'implemented for proximal algorithm.')

    print('Training network ...')

    mnet.train()
    hnet.train()

    # Output head bookkeeping (only relevant in the multi-head setting).
    regged_outputs = None
    allowed_outputs = None
    if config.multi_head:
        n_y = data.out_shape[0]
        out_head_inds = [list(range(head * n_y, (head + 1) * n_y))
                         for head in range(task_id + 1)]
        # Heads of previous tasks are the ones to be regularized.
        if config.masked_reg:
            regged_outputs = out_head_inds[:-1]
        # Head of the current task.
        allowed_outputs = out_head_inds[task_id]

    # Regularizer targets.
    if config.reg == 0:
        targets = hreg.get_current_targets(task_id, hnet)

    # One trainable tensor per hypernetwork weight tensor, initialized to zero.
    dTheta = nn.ParameterList()
    for theta_shape in hnet.theta_shapes:
        delta = nn.Parameter(torch.Tensor(*theta_shape), requires_grad=True)
        dTheta.append(delta)
    dTheta = dTheta.to(device)

    for delta in dTheta:
        delta.data.zero_()

    dtheta_optimizer = optim.Adam(dTheta, lr=config.lr_hyper)
    # Only the embedding of the current task is trained, all other task
    # embeddings remain untouched.
    emb_optimizer = optim.Adam([hnet.get_task_emb(task_id)],
                               lr=config.lr_hyper)

    # Whether the continual learning regularizer has to be evaluated.
    use_cl_reg = task_id > 0 and config.beta > 0

    # Inner optimization steps whose losses are written to Tensorboard. We
    # look at individual steps, as sequences cannot be easily visualized over
    # time.
    last_step = config.n_steps - 1
    if config.n_steps == 1:
        logged_steps = [0]
    elif config.n_steps == 2:
        logged_steps = [0, last_step]
    else:
        logged_steps = [0, config.n_steps // 2, last_step]

    # Tensorboard tags, in the order in which the loss values of an inner
    # step are recorded.
    step_tag_formats = (
        'train/task_%d/dT_step_%d/mse',
        'train/task_%d/dT_step_%d/dT_l2_reg',
        'train/task_%d/dT_step_%d/dT_cl_reg',
        'train/task_%d/dT_step_%d/dT_full_loss',
    )

    for train_iter in range(config.n_iter):
        ### Evaluate network.
        # Evaluation precedes the training step, such that the performance of
        # the untrained network is reported as well.
        if train_iter % config.val_iter == 0:
            evaluate(task_id, data, mnet, hnet, device, config, writer,
                     train_iter)
            mnet.train()
            hnet.train()

        if train_iter % 100 == 0:
            print('Training iteration: %d.' % train_iter)

        batch = data.next_train_batch(config.batch_size)
        X = data.input_to_torch_tensor(batch[0], device, mode='train')
        T = data.output_to_torch_tensor(batch[1], device, mode='train')

        ### Train theta.
        # Start from a zero weight change in every iteration (otherwise,
        # n_steps would have to be high). The optimizer state is deliberately
        # kept across iterations, as resetting it only seemed to hurt.
        for delta in dTheta:
            delta.data.zero_()

        step_losses = []  # Loss values per inner step (for visualizations).
        for _ in range(config.n_steps):
            dtheta_optimizer.zero_grad()

            weights = hnet.forward(task_id, dTheta=dTheta)
            Y = mnet.forward(X, weights)
            if config.multi_head:
                Y = Y[:, allowed_outputs]

            # Task-specific loss.
            loss_task = F.mse_loss(Y, T)

            # Norm penalty on the weight change.
            l2_reg = torch.norm(torch.cat([d.view(-1) for d in dTheta]))

            # Continual learning regularizer.
            cl_reg = torch.zeros(())  # Scalar 0
            if use_cl_reg:
                if config.reg == 0:
                    cl_reg = hreg.calc_fix_target_reg(
                        hnet, task_id, targets=targets, dTheta=dTheta,
                        mnet=mnet, inds_of_out_heads=regged_outputs)
                elif config.reg == 1:
                    cl_reg = hreg.calc_value_preserving_reg(hnet, task_id,
                                                            dTheta)
                elif config.reg == 2:
                    cl_reg = hreg.calc_jac_reguarizer(hnet, task_id, dTheta,
                                                      device)

            loss = loss_task + config.alpha * l2_reg + config.beta * cl_reg
            loss.backward()
            dtheta_optimizer.step()

            step_losses.append([
                term.data.cpu().numpy()
                for term in (loss_task, l2_reg, cl_reg, loss)
            ])

        # Apply the learned weight change to the hypernetwork weights.
        for theta_ind, theta in enumerate(hnet.theta):
            theta.data = theta.data + dTheta[theta_ind].data

        ### Train class embedding.
        emb_optimizer.zero_grad()

        weights = hnet.forward(task_id)
        Y = mnet.forward(X, weights)
        if config.multi_head:
            Y = Y[:, allowed_outputs]

        loss_mse = F.mse_loss(Y, T)
        loss_mse.backward()
        emb_optimizer.step()

        # Tensorboard summary (only every 10-th iteration).
        if train_iter % 10 != 0:
            continue

        writer.add_scalar('train/task_%d/mse_loss' % task_id, loss_mse,
                          train_iter)
        dT_norm = torch.norm(torch.cat([d.view(-1) for d in dTheta]))
        writer.add_scalar('train/task_%d/dT_norm' % task_id, dT_norm,
                          train_iter)

        for step_ind in logged_steps:
            for tag_format, value in zip(step_tag_formats,
                                         step_losses[step_ind]):
                writer.add_scalar(tag_format % (task_id, step_ind), value,
                                  train_iter)

    print('Training network ... Done')
Beispiel #4
0
def train_reg(task_id, data, mnet, hnet, device, config, writer):
    r"""Train the hypernetwork on one task, optionally with a regularizer
    that should weaken catastrophic forgetting.

    .. math::
        \text{loss} = \text{task\_loss} + \beta * \text{regularizer}

    The task loss is the MSE between the main network's output (computed
    with the weights produced by ``hnet`` for ``task_id``) and the batch
    targets. The embedding of the current task is stepped right after the
    task-loss backward pass, i.e., before the regularizer is backpropagated.
    The regularizer selected by ``config.reg`` is only evaluated if
    ``task_id > 0`` and ``config.beta > 0``.

    After the training loop, Fisher estimates are computed via
    ``ewc.compute_fisher`` if ``config.reg == 3`` and/or if
    ``config.ewc_weight_importance`` is set.

    Args:
        task_id: The index of the task on which we train.
        data: The dataset handler.
        mnet: The model of the main network.
        hnet: The model of the hyper network.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        writer: The tensorboard summary writer.
    """
    print('Training network ...')

    mnet.train()
    hnet.train()

    # Output head bookkeeping (only relevant for multi-head main networks).
    masked_heads = None  # Heads of previous tasks that are regularized.
    allowed_outputs = None  # Head of the current task.
    if config.multi_head:
        head_size = data.out_shape[0]
        heads = [list(range(h * head_size, (h + 1) * head_size))
                 for h in range(task_id + 1)]
        allowed_outputs = heads[task_id]
        if config.masked_reg:
            masked_heads = heads[:-1]

    def _fisher_buffer(prev_task, weight_ind):
        """Read one stored Fisher buffer of a previous task from ``mnet``."""
        _, fisher_name = ewc._ewc_buffer_names(prev_task, weight_ind, False)
        return getattr(mnet, fisher_name)

    # Fisher estimates of all previous tasks (one list per task), that are
    # handed to the regularizer.
    fisher_estimates = None
    if config.ewc_weight_importance and task_id > 0:
        num_targets = len(hnet.target_shapes)
        fisher_estimates = [
            [_fisher_buffer(prev, w) for w in range(num_targets)]
            for prev in range(task_id)
        ]

    # Targets of the fixed-target regularizer are computed once, before any
    # weight of this task is updated.
    targets = None
    if config.reg == 0 and config.beta > 0:
        targets = hreg.get_current_targets(task_id, hnet)

    # Parameters stepped by the "theta" optimizer: the hypernet weights and,
    # optionally, the embeddings of all previous tasks.
    theta_params = list(hnet.theta)
    if task_id > 0 and config.plastic_prev_tembs:
        assert (config.reg == 0)
        theta_params.extend(hnet.get_task_emb(prev)
                            for prev in range(task_id))
    theta_optimizer = optim.Adam(theta_params, lr=config.lr_hyper)
    # Only the embedding of the current task lives in this optimizer.
    emb_optimizer = optim.Adam([hnet.get_task_emb(task_id)],
                               lr=config.lr_hyper)

    # Whether the regularizer is evaluated at all during training.
    use_reg = task_id > 0 and config.beta > 0

    def _raw_regularizer(delta_theta, delta_tembs):
        """Evaluate the (unscaled) regularizer selected by ``config.reg``.

        Returns ``0`` if ``config.reg`` matches none of the known options.
        """
        if config.reg == 0:
            return hreg.calc_fix_target_reg(
                hnet, task_id, targets=targets, dTheta=delta_theta,
                dTembs=delta_tembs, mnet=mnet,
                inds_of_out_heads=masked_heads,
                fisher_estimates=fisher_estimates)
        if config.reg == 1:
            return hreg.calc_value_preserving_reg(hnet, task_id, delta_theta)
        if config.reg == 2:
            return hreg.calc_jac_reguarizer(hnet, task_id, delta_theta,
                                            device)
        if config.reg == 3:  # EWC
            return ewc.ewc_regularizer(task_id, hnet.theta, None, hnet=hnet,
                                       online=config.online_ewc,
                                       gamma=config.gamma)
        return 0

    for step in range(config.n_iter):
        # Evaluate before the training step, such that the performance of
        # the untrained network is visible as well.
        if step % config.val_iter == 0:
            evaluate(task_id, data, mnet, hnet, device, config, writer, step)
            mnet.train()
            hnet.train()

        if step % 100 == 0:
            print('Training iteration: %d.' % step)

        theta_optimizer.zero_grad()
        emb_optimizer.zero_grad()

        batch = data.next_train_batch(config.batch_size)
        net_in = data.input_to_torch_tensor(batch[0], device, mode='train')
        net_target = data.output_to_torch_tensor(batch[1], device,
                                                 mode='train')

        prediction = mnet.forward(net_in, hnet.forward(task_id))
        if config.multi_head:
            prediction = prediction[:, allowed_outputs]

        # Task-specific loss. Its gradients are computed right away, as they
        # are needed to compute delta theta below.
        loss_task = F.mse_loss(prediction, net_target)
        loss_task.backward(retain_graph=use_reg,
                           create_graph=config.backprop_dt and use_reg)

        # At this point, only gradients of "loss_task" have been accumulated,
        # hence the task embedding is trained on the task loss only.
        emb_optimizer.step()

        loss_reg = 0
        delta_theta = None
        task_grad = None
        if use_reg:
            if step % 100 == 0:
                # Debugging only: snapshot of the task-loss gradient.
                task_grad = torch.cat([p.grad.clone().view(-1)
                                       for p in hnet.theta])

            delta_theta = opstep.calc_delta_theta(
                theta_optimizer, config.use_sgd_change, lr=config.lr_hyper,
                detach_dt=not config.backprop_dt)
            delta_tembs = None
            if config.plastic_prev_tembs:
                # The previous task embeddings were appended last to the
                # optimized parameters, so their updates are the trailing
                # "task_id" entries.
                delta_tembs = delta_theta[-task_id:]
                delta_theta = delta_theta[:-task_id]

            loss_reg = _raw_regularizer(delta_theta, delta_tembs)
            loss_reg *= config.beta

            loss_reg.backward()

            if task_grad is not None:
                full_grad = torch.cat([p.grad.view(-1) for p in hnet.theta])
                # Gradient contribution of the regularizer alone.
                reg_grad = full_grad - task_grad
                reg_grad_norm = torch.norm(reg_grad, 2)

                # Cosine between the regularizer gradient and the flattened
                # delta theta.
                delta_vec = torch.cat([d.view(-1).clone()
                                       for d in delta_theta])
                grad_cos = F.cosine_similarity(reg_grad.view(1, -1),
                                               delta_vec.view(1, -1))

        theta_optimizer.step()

        if step % 10 == 0:
            writer.add_scalar('train/task_%d/mse_loss' % task_id, loss_task,
                              step)
            writer.add_scalar('train/task_%d/regularizer' % task_id, loss_reg,
                              step)
            writer.add_scalar('train/task_%d/full_loss' % task_id,
                              loss_task + loss_reg, step)
            if delta_theta is not None:
                delta_norm = torch.norm(
                    torch.cat([d.view(-1) for d in delta_theta]), 2)
                writer.add_scalar('train/task_%d/dTheta_norm' % task_id,
                                  delta_norm, step)
            if task_grad is not None:
                writer.add_scalar('train/task_%d/full_grad_norm' % task_id,
                                  torch.norm(full_grad, 2), step)
                writer.add_scalar('train/task_%d/reg_grad_norm' % task_id,
                                  reg_grad_norm, step)
                writer.add_scalar('train/task_%d/cosine_task_reg' % task_id,
                                  grad_cos, step)

    if config.reg == 3:
        ## Estimate diagonal Fisher elements.
        ewc.compute_fisher(task_id, data, hnet.theta, device, mnet, hnet=hnet,
                           empirical_fisher=True, online=config.online_ewc,
                           gamma=config.gamma, n_max=config.n_fisher,
                           regression=True, allowed_outputs=allowed_outputs)

    if config.ewc_weight_importance:
        ## Estimate Fisher for outputs of the hypernetwork.
        hnet_out = hnet.forward(task_id)

        # Note, there are actually no parameters in the main network. We wrap
        # the hypernet outputs into parameter objects instead.
        fake_main_params = nn.ParameterList()
        for w_tensor in hnet_out:
            placeholder = nn.Parameter(torch.Tensor(*w_tensor.shape),
                                       requires_grad=True)
            placeholder.data = w_tensor
            fake_main_params.append(placeholder)

        ewc.compute_fisher(task_id, data, fake_main_params, device, mnet,
                           empirical_fisher=True, online=False,
                           n_max=config.n_fisher, regression=True,
                           allowed_outputs=allowed_outputs)

    print('Training network ... Done')
Beispiel #5
0
def train_vae_one_t(dhandler, enc, dec, d_hnet, device, config, writer,
                    embd_list, t):
    """Train the conditional MNIST VAE for one task.

    After setting up the optimizers for the encoder, the decoder (or its
    hypernetwork) and, if applicable, the current task embedding, a standard
    variational autoencoder training scheme is run: the loss is the sum of
    the reconstruction error (binary cross-entropy) and the KL term. To
    prevent the decoder hypernetwork from forgetting, the hypernetwork
    regularisation term (scaled by ``config.rp_beta``) is added for ``t > 0``
    whenever a hypernetwork is used and ``config.rp_beta > 0``.

    Args:
        dhandler: Data handler of the task; provides the training batches.
        enc: The encoder. Its output is split column-wise into mean and
            log-variance, each of width ``config.latent_dim``.
        dec: The decoder.
        d_hnet: The hypernetwork of the decoder. May be ``None``, in which
            case the parameters of ``dec`` are optimized directly.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        writer: The tensorboard summary writer.
        embd_list: Helper list of lists for embedding plotting. It is
            unpacked into four entries: encoder embeddings, decoder
            embeddings and the corresponding two histories.
        t: Task id that will be trained.

    """
    use_hnet = d_hnet is not None

    # Training mode for all involved networks.
    enc.train()
    dec.train()
    if use_hnet:
        d_hnet.train()

    print("Training VAE on data handler: ", t)

    # Containers used for plotting embeddings.
    enc_embs, dec_embs, enc_embs_history, dec_embs_history = embd_list[:]

    # Translate epochs into a number of training iterations, if requested.
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert (config.epochs > 0)
        batches_per_epoch = int(np.ceil(dhandler.num_train_samples /
                                        config.batch_size))
        training_iterations = config.epochs * batches_per_epoch

    # When every single class of a task is replayed on its own, the training
    # budget of the task is split among its classes:
    #   training_time_per_class = training_time_per_task / num_class_per_task
    # This matters when comparing to related work, where the training time
    # is specified per task.
    if config.single_class_replay:
        training_iterations = int(training_iterations / config.out_dim)

    # Optionally start the new task from the embedding of the previous task.
    if config.embedding_reset == "old_embedding" and t > 0 and use_hnet:
        d_hnet.get_task_embs()[t].data = \
            d_hnet.get_task_embs()[t - 1].detach().clone()

    # Regularizer targets have to be computed before training starts.
    targets_D = None
    if t > 0 and config.rp_beta > 0 and use_hnet:
        targets_D = hreg.get_current_targets(t, d_hnet)

    ############
    # OPTIMIZERS
    ############
    adam_betas = (0.9, 0.999)

    enc_optimizer = optim.Adam(enc.parameters(), lr=config.enc_lr,
                               betas=adam_betas)

    # The decoder optimizer either acts on the hypernet weights or directly
    # on the decoder weights. The embedding optimizer (if any) only holds the
    # embedding of the current task.
    emb_optimizer = None
    if use_hnet:
        dec_params = list(d_hnet.theta)
        if not config.dont_train_rp_embeddings:
            emb_optimizer = optim.Adam([d_hnet.get_task_emb(t)],
                                       lr=config.dec_lr_emb, betas=adam_betas)
    else:
        dec_params = dec.parameters()

    dec_optimizer = optim.Adam(dec_params, lr=config.dec_lr, betas=adam_betas)

    use_reg = config.rp_beta > 0 and t > 0 and d_hnet is not None

    ###########
    # TRAINING
    ###########
    for step in range(training_iterations):
        # Test before the training step, such that the performance of the
        # untrained network is visible as well.
        if step % config.val_iter == 0:
            test(enc, dec, d_hnet, device, config, writer, step, t)
            enc.train()
            dec.train()
            if use_hnet:
                d_hnet.train()

        if step % 100 == 0:
            print('Training iteration: %d.' % step)

        # Record the trajectory of the current task embedding for plotting.
        if config.show_plots and use_hnet:
            emb_snapshot = d_hnet.get_task_emb(t).clone().detach()
            if not config.no_cuda:
                emb_snapshot = emb_snapshot.cpu()
            dec_embs_history.append(emb_snapshot.numpy())

        #######
        # DATA
        #######
        real_batch = dhandler.next_train_batch(config.batch_size)
        X_real = dhandler.input_to_torch_tensor(real_batch[0], device,
                                                mode='train')

        enc_optimizer.zero_grad()
        dec_optimizer.zero_grad()
        if emb_optimizer is not None:
            emb_optimizer.zero_grad()

        ############################
        # KLD + RECONSTRUCTION
        ############################
        # Encoder output: first the means, then the log-variances.
        latent_dim = config.latent_dim
        enc_out = enc.forward(X_real)
        mu = enc_out[:, :latent_dim]
        logvar = enc_out[:, latent_dim:2 * latent_dim]

        kld = compute_kld(mu, logvar, config, t)

        # Sample from the encoder distribution and decode the sample.
        latent_sample = reparameterize(mu, logvar)
        reconstructions = sample(dec, d_hnet, config, t, device,
                                 z=latent_sample)

        # The reconstruction error is averaged like this to compare to
        # related work, see
        # https://github.com/GMvandeVen/continual-learning/blob/master/train.py
        x_rec_loss = F.binary_cross_entropy(
            reconstructions, X_real, reduction='none').mean(dim=1).mean()

        loss = x_rec_loss + kld

        ######################################################
        # HYPERNET REGULARISATION - CONTINUAL LEARNING METHOD
        ######################################################
        loss.backward(retain_graph=use_reg,
                      create_graph=use_reg and config.backprop_dt)

        dloss_reg = 0
        if use_reg:
            # Without lookahead, the regularizer is evaluated at the current
            # weights (no delta theta).
            delta_theta = None
            if not config.no_lookahead:
                delta_theta = opstep.calc_delta_theta(
                    dec_optimizer, config.use_sgd_change, lr=config.dec_lr,
                    detach_dt=not config.backprop_dt)
            dloss_reg = config.rp_beta * hreg.calc_fix_target_reg(
                d_hnet, t, targets=targets_D, mnet=dec, dTheta=delta_theta,
                dTembs=None)
            dloss_reg.backward()

        # Apply the accumulated gradients.
        dec_optimizer.step()
        enc_optimizer.step()
        if not (d_hnet is None or config.dont_train_rp_embeddings):
            emb_optimizer.step()

        # Visualization of the current progress in tensorboard.
        if (step % config.plot_update_steps == 0 and step > 0
                and config.show_plots):
            dec_embedding_cut = None if dec_embs_history is None \
                else np.asarray(dec_embs_history[2:])
            enc_embedding_cut = None if enc_embs_history is None \
                else np.asarray(enc_embs_history[2:])
            _viz_training(X_real, reconstructions, enc_embs,
                          dec_embs, enc_embedding_cut, dec_embedding_cut,
                          writer, step, config, title="train_cond_" + str(t))

        # Track some training statistics.
        writer.add_scalar('train/kld_%d' % (t), kld, step)
        writer.add_scalar('train/reconstruction_%d' % (t), x_rec_loss, step)
        writer.add_scalar('train/all_loss_%d' % (t), loss + dloss_reg, step)
        if config.rp_beta > 0:
            writer.add_scalar('train/d_hnet_loss_reg_%d' % (t), dloss_reg,
                              step)

    # NOTE(review): the step passed to the final test is "config.n_iter", not
    # "training_iterations"; these differ if epochs or single-class replay
    # are used. Confirm against "test" whether this is intended.
    test(enc, dec, d_hnet, device, config, writer, config.n_iter, t)
Beispiel #6
0
def train_class_one_t(dhandler_class, dhandlers_rp, dec, d_hnet, net,
                      device, config, writer, t):
    """Train the classifier of a continual learning experiment on one task.

    After setting up the optimizers for the classifier (or its hypernetwork
    and the current task embedding), every training iteration is structured
    as follows:

    First, we get a training batch of the current task. Depending on the
    learning scenario, we choose the range of output neurons (``task_out``)
    and build the targets accordingly.

    Second, if ``t >= 1`` and the classifier is not trained via a
    hypernetwork, the loss on replayed data is mixed into the task loss
    (weighted by ``config.l_rew``). See :func:`get_fake_data_loss` for
    details.

    Third, if the classifier is trained via a hypernetwork, ``t > 0`` and
    ``config.class_beta > 0``, the hypernetwork regularizer
    (``hreg.calc_fix_target_reg``, scaled by ``config.class_beta``) is
    backpropagated in addition, in order to protect the hypernetwork from
    forgetting.

    Finally, we track some training statistics.

    Args:
        dhandler_class: Data handler of the current task; provides the
            training batches.
        dhandlers_rp: Passed on to :func:`get_fake_data_loss`.
        dec: The decoder of the replay model (set to eval mode). May be
            ``None``.
        d_hnet: The hypernetwork of the decoder (set to eval mode). May be
            ``None``.
        net: The classifier. If ``config.training_with_hnet`` is set, this is
            expected to be a sequence holding the classifier at index 0 and
            its hypernetwork at index 1.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        writer: The tensorboard summary writer.
        t: Task id.
    """

    # If the classifier is trained via a hypernetwork, "net" holds both the
    # classifier and its hypernetwork. In that case, the hypernet weights and
    # the current task embedding get separate optimizers.
    if config.training_with_hnet:
        net_hnet = net[1]
        net = net[0]
        net.train()
        net_hnet.train()
        params_to_regularize = list(net_hnet.theta)
        optimizer = optim.Adam(params_to_regularize,
                               lr=config.class_lr, betas=(0.9, 0.999))

        c_emb_optimizer = optim.Adam([net_hnet.get_task_emb(t)],
                                     lr=config.class_lr_emb, betas=(0.9, 0.999))
    else:
        net.train()
        net_hnet = None
        optimizer = optim.Adam(net.parameters(),
                               lr=config.class_lr, betas=(0.9, 0.999))

    # Don't train the replay model (if available).
    if dec is not None:
        dec.eval()
    if d_hnet is not None:
        d_hnet.eval()

    # Prepare the regularizer targets if the classifier is trained via a
    # hypernetwork.
    if t > 0 and config.training_with_hnet:
        if config.online_target_computation:
            # Compute targets for the regularizer whenever they are needed.
            # -> Computationally expensive.
            # Only copies of the current hypernet weights and task embeddings
            # are kept.
            targets_C = None
            prev_theta = [p.detach().clone() for p in net_hnet.theta]
            prev_task_embs = [p.detach().clone() for p in \
                              net_hnet.get_task_embs()]
        else:
            # Compute targets for the regularizer once and keep them all in
            # memory -> Memory expensive.
            targets_C = hreg.get_current_targets(t, net_hnet)
            prev_theta = None
            prev_task_embs = None

    dhandler_class.reset_batch_generator()

    # Keep a copy of the classifier as it was before training on this task
    # (handed to "get_fake_data_loss" below).
    if t >= 1:
        net_copy = copy.deepcopy(net)

    # Translate epochs into a number of training iterations, if requested.
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert (config.epochs > 0)
        training_iterations = config.epochs * \
                              int(np.ceil(dhandler_class.num_train_samples / config.batch_size))

    # In the class-incremental setting, the iteration budget is divided by
    # the number of outputs per task.
    if config.class_incremental:
        training_iterations = int(training_iterations / config.out_dim)

    # Whether we will calculate the regularizer.
    calc_reg = t > 0 and config.class_beta > 0 and config.training_with_hnet

    # Optionally, the regularizer is only computed for a subset of the
    # previous tasks (-1 disables this).
    if config.hnet_reg_batch_size != -1:
        hnet_reg_batch_size = config.hnet_reg_batch_size
    else:
        hnet_reg_batch_size = None

    for i in range(training_iterations):

        # Reset all gradients.
        optimizer.zero_grad()
        if net_hnet is not None:
            c_emb_optimizer.zero_grad()

        # Get real data.
        real_batch = dhandler_class.next_train_batch(config.batch_size)
        X_real = dhandler_class.input_to_torch_tensor(real_batch[0], device,
                                                      mode='train')
        T_real = dhandler_class.output_to_torch_tensor(real_batch[1], device,
                                                       mode='train')

        if i % 100 == 0 and config.show_plots:
            fig_real = _plotImages(X_real, config)
            writer.add_figure('train_class_' + str(t) + '_real',
                              fig_real, global_step=i)

        #################################################
        # Choosing output heads and constructing targets
        #################################################

        # If we train a task inference net or do class-incremental learning,
        # the targets of the data handler are replaced: there is one output
        # neuron per task/class seen so far and the target is the last one.
        if config.class_incremental or config.training_task_infer:
            # Outputs 0 .. t are considered (a single neuron for t == 0).
            # NOTE(review): the targets are built with "config.batch_size"
            # rows; this assumes "next_train_batch" always returns exactly
            # that many samples -- TODO confirm.
            task_out = [0, t + 1]
            T_real = torch.zeros((config.batch_size, task_out[1])).to(device)
            T_real[:, task_out[1] - 1] = 1

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # Take the task-specific output neurons.
                task_out = [t * config.out_dim, t * config.out_dim + config.out_dim]
            else:
                # Always all output neurons, only one head is used.
                task_out = [0, config.out_dim]
        else:
            # The number of output neurons is generic and can grow, i.e., we
            # do not have to know the number of tasks before we start
            # learning.
            if not config.infer_output_head:
                # All heads up to the current one are used; the targets are
                # zero-padded for the heads of previous tasks.
                task_out = [0, (t + 1) * config.out_dim]
                T_real = torch.cat((torch.zeros((config.batch_size,
                                                 t * config.out_dim)).to(device),
                                    T_real), dim=1)
            # This is a special case where the task id will be inferred by
            # another neural network, so we can train on the correct output
            # head directly and use the inferred output head to compute the
            # prediction.
            else:
                task_out = [t * config.out_dim, t * config.out_dim + config.out_dim]

        # Compute the loss on the current data. Without a hypernetwork, the
        # classifier uses its own weights ("weights_c" is None).
        if config.training_with_hnet:
            weights_c = net_hnet.forward(t)
        else:
            weights_c = None

        Y_hat_logits = net.forward(X_real, weights_c)
        Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]

        if config.soft_targets:
            # Entries equal to 1 become "soft_label", the remaining
            # probability mass is spread evenly over all other entries.
            # NOTE(review): if "T_real" has a single column (num_classes == 1,
            # e.g., the first branch above with t == 0), the division below
            # is a division by zero -- confirm this combination cannot occur.
            soft_label = 0.95
            num_classes = T_real.shape[1]
            soft_targets = torch.where(T_real == 1,
                                       torch.Tensor([soft_label]).to(device),
                                       torch.Tensor([(1 - soft_label) / (num_classes - 1)]).to(device))
            soft_targets = soft_targets.to(device)
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits,
                                                             soft_targets)
        else:
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits, T_real)

        ############################
        # compute loss for fake data
        ############################

        # Replay is only used if the classifier is not trained via a
        # hypernetwork. The task loss becomes a convex combination of the
        # loss on real data and the loss on replayed data.
        if t >= 1 and not config.training_with_hnet:
            fake_loss = get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device,
                                           config, writer, t, i, net_copy)
            loss_task = (1 - config.l_rew) * loss_task + config.l_rew * fake_loss

        # The graph is only retained if the regularizer is backpropagated
        # afterwards.
        loss_task.backward(retain_graph=calc_reg, create_graph=calc_reg and \
                                                               config.backprop_dt)

        # Compute the hypernet regularizer and accumulate its gradients on
        # top of the task-loss gradients.
        if calc_reg:
            if config.no_lookahead:
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(optimizer,
                                                 config.use_sgd_change, lr=config.class_lr,
                                                 detach_dt=not config.backprop_dt)
            loss_reg = config.class_beta * hreg.calc_fix_target_reg(net_hnet, t,
                                                                    targets=targets_C, mnet=net, dTheta=dTheta,
                                                                    dTembs=None,
                                                                    prev_theta=prev_theta,
                                                                    prev_task_embs=prev_task_embs,
                                                                    batch_size=hnet_reg_batch_size)
            loss_reg.backward()

        # Apply the accumulated gradients (unless the main model is frozen).
        if not config.dont_train_main_model:
            optimizer.step()

        if net_hnet is not None and config.train_class_embeddings:
            c_emb_optimizer.step()

        # Log some training statistics.
        if i % 50 == 0:
            # Compute the accuracy on the current batch for tracking. Note,
            # the forward pass is repeated after the optimizer step, using
            # the same "weights_c" as before.
            Y_hat_logits = net.forward(X_real, weights_c)
            Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
            Y_hat = F.softmax(Y_hat_logits, dim=1)
            classifier_accuracy = Classifier.accuracy(Y_hat, T_real) * 100.0
            writer.add_scalar('train/task_%d/class_accuracy' % t,
                              classifier_accuracy, i)
            writer.add_scalar('train/task_%d/loss_task' % t,
                              loss_task, i)
            if t >= 1 and not config.training_with_hnet:
                writer.add_scalar('train/task_%d/fake_loss' % t,
                                  fake_loss, i)

        # Log some gradient statistics (L2 norm over all gradients).
        if i % 200 == 0:
            if not config.dont_train_main_model:
                total_norm = 0
                if config.training_with_hnet:
                    params = net_hnet.theta
                else:
                    params = net.parameters()

                for p in params:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                # TODO write gradient histograms?
                writer.add_scalar('train/task_%d/main_params_grad_norms' % t,
                                  total_norm, i)

            if net_hnet is not None and config.train_class_embeddings:
                total_norm = 0
                for p in [net_hnet.get_task_emb(t)]:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                writer.add_scalar('train/task_%d/hnet_emb_grad_norms' % t,
                                  total_norm, i)

        # "classifier_accuracy" has been refreshed in this iteration, since
        # every multiple of 200 is also a multiple of 50.
        if i % 200 == 0:
            msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                  '(on current training batch).'
            print(msg.format(i, classifier_accuracy))