Example #1
def train(task_id, data, mnet, hnet, device, config, shared, writer, logger):
    """Train the hyper network using the task-specific loss plus a regularizer
    that should overcome catastrophic forgetting.

    :code:`loss = task_loss + beta * regularizer`.

    Args:
        task_id: The index of the task on which we train.
        data: The dataset handler.
        mnet: The model of the main network.
        hnet: The model of the hypernetwork. May be ``None``.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        shared (argparse.Namespace): Set of variables shared between functions.
        writer: The tensorboard summary writer.
        logger: The logger that should be used rather than the print method.
    """
    start_time = time()

    logger.debug('data: %s' % str(data))
    logger.debug('data.num_classes: %s' % str(data.num_classes))
    logger.debug('data.num_train_samples: %s' % str(data.num_train_samples))

    logger.info('Training network ...')

    mnet.train()
    if hnet is not None:
        hnet.train()

    #################
    ### Optimizer ###
    #################
    # Define the optimizers used to train main network and hypernet.
    if hnet is not None:
        theta_params = list(hnet.theta)
        if config.continue_emb_training:
            for i in range(task_id):  # for all previous task embeddings
                theta_params.append(hnet.get_task_emb(i))

        # Only for the current task embedding.
        # Important that this embedding is in a different optimizer in case
        # we use the lookahead.
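        # (The lookahead ``dTheta`` below is computed from the gradients stored
        # in ``theta_optimizer``, so the current task embedding must not be
        # part of that optimizer.)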
        emb_optimizer = get_optimizer([hnet.get_task_emb(task_id)],
                                      config.lr,
                                      momentum=config.momentum,
                                      weight_decay=config.weight_decay,
                                      use_adam=config.use_adam,
                                      adam_beta1=config.adam_beta1,
                                      use_rmsprop=config.use_rmsprop)
    else:
        theta_params = mnet.weights
        emb_optimizer = None

    theta_optimizer = get_optimizer(theta_params,
                                    config.lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay,
                                    use_adam=config.use_adam,
                                    adam_beta1=config.adam_beta1,
                                    use_rmsprop=config.use_rmsprop)

    ################################
    ### Learning rate schedulers ###
    ################################
    if config.plateau_lr_scheduler:
        assert (config.epochs != -1)
        # The scheduler config has been taken from here:
        # https://keras.io/examples/cifar10_resnet/
        # Note, we use 'max' instead of 'min' as we look at accuracy rather
        # than validation loss!
        plateau_scheduler_theta = optim.lr_scheduler.ReduceLROnPlateau( \
            theta_optimizer, 'max', factor=np.sqrt(0.1), patience=5,
            min_lr=0.5e-6, cooldown=0)
        plateau_scheduler_emb = None
        if emb_optimizer is not None:
            plateau_scheduler_emb = optim.lr_scheduler.ReduceLROnPlateau( \
                emb_optimizer, 'max', factor=np.sqrt(0.1), patience=5,
                min_lr=0.5e-6, cooldown=0)

    if config.lambda_lr_scheduler:
        assert (config.epochs != -1)

        def lambda_lr(epoch):
            """Multiplicative Factor for Learning Rate Schedule.

            Computes a multiplicative factor for the initial learning rate based
            on the current epoch. This method can be used as argument
            ``lr_lambda`` of class :class:`torch.optim.lr_scheduler.LambdaLR`.

            The schedule is inspired by the Resnet CIFAR-10 schedule suggested
            here https://keras.io/examples/cifar10_resnet/.

            Args:
                epoch (int): The current epoch.

            Returns:
                lr_scale (float32): The multiplicative factor for the initial
                    learning rate.
            """
            lr_scale = 1.
            if epoch > 180:
                lr_scale = 0.5e-3
            elif epoch > 160:
                lr_scale = 1e-3
            elif epoch > 120:
                lr_scale = 1e-2
            elif epoch > 80:
                lr_scale = 1e-1
            return lr_scale
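
        # Example: the initial learning rate is kept for the first 80 epochs
        # and afterwards scaled by 1e-1, 1e-2, 1e-3 and 0.5e-3 once epochs 80,
        # 120, 160 and 180 have been passed, respectively.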

        lambda_scheduler_theta = optim.lr_scheduler.LambdaLR(
            theta_optimizer, lambda_lr)
        lambda_scheduler_emb = None
        if emb_optimizer is not None:
            lambda_scheduler_emb = optim.lr_scheduler.LambdaLR(
                emb_optimizer, lambda_lr)

    ##############################
    ### Prepare CL Regularizer ###
    ##############################
    # Whether we will calculate the regularizer.
    calc_reg = task_id > 0 and not config.mnet_only and config.beta > 0 and \
        not config.train_from_scratch

    # Compute the targets for the regularizer if it is activated and we are not
    # training the first task.
    if calc_reg:
        if config.online_target_computation:
            # Compute targets for the regularizer whenever they are needed.
            # -> Computationally expensive.
            targets_hypernet = None
            prev_theta = [p.detach().clone() for p in hnet.theta]
            prev_task_embs = [p.detach().clone() for p in hnet.get_task_embs()]
        else:
            # Compute targets for the regularizer once and keep them all in
            # memory -> Memory expensive.
            targets_hypernet = hreg.get_current_targets(task_id, hnet)
            prev_theta = None
            prev_task_embs = None

        # If we do not want to regularize all outputs (in a multi-head setup).
        # Note, we don't care whether output heads other than the current one
        # change.
        regged_outputs = None
        if config.cl_scenario != 2:
            # FIXME We assume here that all tasks have the same output size.
            n_y = data.num_classes
            regged_outputs = [
                list(range(i * n_y, (i + 1) * n_y)) for i in range(task_id)
            ]
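            # E.g., with ``n_y = 10`` and ``task_id = 2``, only outputs
            # [0, ..., 9] and [10, ..., 19] are regularized for tasks 0 and 1,
            # respectively.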

    # We need to tell the main network which batch statistics to use in case
    # batchnorm is used and we checkpoint the batchnorm stats.
    mnet_kwargs = {}
    if mnet.batchnorm_layers is not None:
        if config.bn_distill_stats:
            raise NotImplementedError()
        elif not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Specify current task as condition to select correct
            # running stats.
            mnet_kwargs['condition'] = task_id
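            # The main network then uses the running statistics associated with
            # the current task in its batchnorm layers.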

    ######################
    ### Start training ###
    ######################

    iter_per_epoch = -1
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert (config.epochs > 0)
        iter_per_epoch = int(np.ceil(data.num_train_samples / \
                                     config.batch_size))
        training_iterations = config.epochs * iter_per_epoch

    summed_iter_runtime = 0

    for i in range(training_iterations):
        ### Evaluate network.
        # We test the network before we run the training iteration.
        # That way, we can see the initial performance of the untrained network.
        if i % config.val_iter == 0:
            test(task_id,
                 data,
                 mnet,
                 hnet,
                 device,
                 shared,
                 config,
                 writer,
                 logger,
                 train_iter=i)
            mnet.train()
            if hnet is not None:
                hnet.train()

        if i % 200 == 0:
            logger.info('Training step: %d ...' % i)

        iter_start_time = time()

        theta_optimizer.zero_grad()
        if emb_optimizer is not None:
            emb_optimizer.zero_grad()

        #######################################
        ### Data for current task and batch ###
        #######################################
        batch = data.next_train_batch(config.batch_size)
        X = data.input_to_torch_tensor(batch[0], device, mode='train')
        T = data.output_to_torch_tensor(batch[1], device, mode='train')

        # Get the output neurons depending on the continual learning scenario.
        n_y = data.num_classes
        if config.cl_scenario == 1:
            # Choose current head.
            task_out = [task_id * n_y, (task_id + 1) * n_y]
        elif config.cl_scenario == 2:
            # Always all output neurons, only one head is used.
            task_out = [0, n_y]
        else:
            # Choose the current head; at test time, the correct head has to be
            # inferred.
            task_out = [task_id * n_y, (task_id + 1) * n_y]
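        # E.g., with ``n_y = 10`` and ``task_id = 2``, scenarios 1 and 3 use
        # logits 20:30, whereas scenario 2 always uses logits 0:10.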

        ########################
        ### Loss computation ###
        ########################
        if config.mnet_only:
            weights = None
        else:
            weights = hnet.forward(task_id=task_id)
        Y_hat_logits = mnet.forward(X, weights, **mnet_kwargs)

        # Restrict output neurons
        Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
        assert (T.shape[1] == Y_hat_logits.shape[1])
        # compute loss on task and compute gradients
        if config.soft_targets:
            soft_label = 0.95
            num_classes = data.num_classes
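            # E.g., with 10 classes, the correct class receives probability
            # 0.95 and every other class (1 - 0.95) / 9 (roughly 0.0056).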
            soft_targets = torch.where(
                T == 1, torch.Tensor([soft_label]).to(device),
                torch.Tensor([(1 - soft_label) / (num_classes - 1)]).to(device))
            loss_task = Classifier.softmax_and_cross_entropy(
                Y_hat_logits, soft_targets)
        else:
            loss_task = Classifier.logit_cross_entropy_loss(Y_hat_logits, T)

        # Compute gradients based on task loss (those might be used in the CL
        # regularizer).
        loss_task.backward(retain_graph=calc_reg, create_graph=calc_reg and \
                           config.backprop_dt)
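        # Note, the graph is retained here because the CL regularizer below may
        # reuse it, e.g., when backpropagating through the lookahead ``dTheta``
        # (see ``--backprop_dt``).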

        # The current task embedding only depends on the task loss, so we can
        # update it already.
        if emb_optimizer is not None:
            emb_optimizer.step()

        #############################
        ### CL (HNET) Regularizer ###
        #############################
        loss_reg = 0
        dTheta = None

        if calc_reg:
            if config.no_lookahead:
                dTembs = None
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(
                    theta_optimizer,
                    False,
                    lr=config.lr,
                    detach_dt=not config.backprop_dt)

                if config.continue_emb_training:
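                    # The previous task embeddings were appended at the end of
                    # ``theta_params`` above, so they make up the last
                    # ``task_id`` entries of ``dTheta``.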
                    dTembs = dTheta[-task_id:]
                    dTheta = dTheta[:-task_id]
                else:
                    dTembs = None

            loss_reg = hreg.calc_fix_target_reg(
                hnet,
                task_id,
                targets=targets_hypernet,
                dTheta=dTheta,
                dTembs=dTembs,
                mnet=mnet,
                inds_of_out_heads=regged_outputs,
                prev_theta=prev_theta,
                prev_task_embs=prev_task_embs,
                batch_size=config.cl_reg_batch_size)
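            # ``calc_fix_target_reg`` penalizes changes of the hypernetwork
            # outputs for previous task embeddings (optionally evaluated at the
            # lookahead ``theta + dTheta``) w.r.t. the targets computed above.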

            loss_reg *= config.beta

            loss_reg.backward()

        # Now that we have computed the regularizer, we can use the accumulated
        # gradients to update the hnet (or mnet) parameters.
        theta_optimizer.step()

        Y_hat = F.softmax(Y_hat_logits, dim=1)
        classifier_accuracy = Classifier.accuracy(Y_hat, T) * 100.0

        ################################
        ### Learning rate schedulers ###
        ################################
        if config.plateau_lr_scheduler:
            assert (iter_per_epoch != -1)
            if i % iter_per_epoch == 0 and i > 0:
                curr_epoch = i // iter_per_epoch
                logger.info('Computing test accuracy for plateau LR ' +
                            'scheduler (epoch %d).' % curr_epoch)
                # We need a validation quantity for the plateau LR scheduler.
                # FIXME we should use an actual validation set rather than the
                # test set.
                # Note, https://keras.io/examples/cifar10_resnet/ uses the test
                # set to compute the validation loss. We use the "validation"
                # accuracy instead.
                # FIXME We pass `train_iter=i+1`, as the print messages in the
                # test method would otherwise suggest that testing was executed
                # before this training step.
                test_acc, _ = test(task_id,
                                   data,
                                   mnet,
                                   hnet,
                                   device,
                                   shared,
                                   config,
                                   writer,
                                   logger,
                                   train_iter=i + 1)
                mnet.train()
                if hnet is not None:
                    hnet.train()

                plateau_scheduler_theta.step(test_acc)
                if plateau_scheduler_emb is not None:
                    plateau_scheduler_emb.step(test_acc)

        if config.lambda_lr_scheduler:
            assert (iter_per_epoch != -1)
            if i % iter_per_epoch == 0 and i > 0:
                curr_epoch = i // iter_per_epoch
                logger.info('Applying Lambda LR scheduler (epoch %d).' %
                            curr_epoch)

                lambda_scheduler_theta.step()
                if lambda_scheduler_emb is not None:
                    lambda_scheduler_emb.step()

        ###########################
        ### Tensorboard summary ###
        ###########################
        # We don't want to slow down training by producing too much output.
        if i % 50 == 0:
            writer.add_scalar('train/task_%d/class_accuracy' % task_id,
                              classifier_accuracy, i)
            writer.add_scalar('train/task_%d/loss_task' % task_id, loss_task,
                              i)
            writer.add_scalar('train/task_%d/loss_reg' % task_id, loss_reg, i)

        ### Show the current training progress to the user.
        if i % config.val_iter == 0:
            msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                  '(on current training batch).'
            logger.debug(msg.format(i, classifier_accuracy))

        iter_end_time = time()
        summed_iter_runtime += (iter_end_time - iter_start_time)

        if i % 200 == 0:
            logger.info('Training step: %d ... Done -- (runtime: %f sec)' % \
                        (i, iter_end_time - iter_start_time))

    if mnet.batchnorm_layers is not None:
        if not config.bn_distill_stats and \
                not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Checkpoint the current running statistics (that have been
            # estimated while training the current task).
            for bn_layer in mnet.batchnorm_layers:
                assert (bn_layer.num_stats == task_id + 1)
                bn_layer.checkpoint_stats()

    avg_iter_time = summed_iter_runtime / training_iterations
    logger.info('Average runtime per training iteration: %f sec.' % \
                avg_iter_time)

    logger.info('Elapsed time for training task %d: %f sec.' % \
                (task_id+1, time()-start_time))
Example #2
def train_class_one_t(dhandler_class, dhandlers_rp, dec, d_hnet, net,
                      device, config, writer, t):
    """Train continual learning experiments on MNIST dataset for one task.
    In this function the main training logic is implemented. 
    After setting the optimizers for the network and hypernetwork if 
    applicable, the training is structured as follows: 
    First, we get the a training batch of the current task. Depending on 
    the learning scenario, we choose output heads and build targets 
    accordingly. 
    Second, if ``t`` is greater than 1, we add a loss term concerning 
    predictions of replayed data. See :func:`get_fake_data_loss` for 
    details. Third, to protect the hypernetwork from forgetting, we add an 
    additional L2 loss term namely the difference between its current output 
    given an embedding and checkpointed targets.
    Finally, we track some training statistics.

    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
    """

    # If we do CL with task inference, the classifier is equipped with a hnet;
    # ``net`` is then a tuple (main network, hypernetwork).
    if config.training_with_hnet:
        net_hnet = net[1]
        net = net[0]
        net.train()
        net_hnet.train()
        params_to_regularize = list(net_hnet.theta)
        optimizer = optim.Adam(params_to_regularize,
                               lr=config.class_lr, betas=(0.9, 0.999))

        c_emb_optimizer = optim.Adam([net_hnet.get_task_emb(t)],
                                     lr=config.class_lr_emb, betas=(0.9, 0.999))
    else:
        net.train()
        net_hnet = None
        optimizer = optim.Adam(net.parameters(),
                               lr=config.class_lr, betas=(0.9, 0.999))

    # Don't train the replay model(s) if available.
    if dec is not None:
        dec.eval()
    if d_hnet is not None:
        d_hnet.eval()

    # compute targets if classifier is trained with hnet
    if t > 0 and config.training_with_hnet:
        if config.online_target_computation:
            # Compute targets for the regularizer whenever they are needed.
            # -> Computationally expensive.
            targets_C = None
            prev_theta = [p.detach().clone() for p in net_hnet.theta]
            prev_task_embs = [p.detach().clone() for p in \
                              net_hnet.get_task_embs()]
        else:
            # Compute targets for the regularizer once and keep them all in
            # memory -> Memory expensive.
            targets_C = hreg.get_current_targets(t, net_hnet)
            prev_theta = None
            prev_task_embs = None

    dhandler_class.reset_batch_generator()

    # make copy of network
    if t >= 1:
        net_copy = copy.deepcopy(net)

    # set training_iterations if epochs are set
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert (config.epochs > 0)
        training_iterations = config.epochs * \
                              int(np.ceil(dhandler_class.num_train_samples / config.batch_size))

    if config.class_incremental:
        training_iterations = int(training_iterations / config.out_dim)

    # Whether we will calculate the regularizer.
    calc_reg = t > 0 and config.class_beta > 0 and config.training_with_hnet

    # Set whether the regularizer should only be computed for a subset of the
    # previous tasks.
    if config.hnet_reg_batch_size != -1:
        hnet_reg_batch_size = config.hnet_reg_batch_size
    else:
        hnet_reg_batch_size = None
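    # ``hnet_reg_batch_size = None`` means that all previous tasks enter the
    # regularizer; otherwise only a subset of them is used per training step
    # (cf. the ``batch_size`` argument of ``hreg.calc_fix_target_reg``).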

    for i in range(training_iterations):

        # set optimizer to zero
        optimizer.zero_grad()
        if net_hnet is not None:
            c_emb_optimizer.zero_grad()

        # Get real data
        real_batch = dhandler_class.next_train_batch(config.batch_size)
        X_real = dhandler_class.input_to_torch_tensor(real_batch[0], device,
                                                      mode='train')
        T_real = dhandler_class.output_to_torch_tensor(real_batch[1], device,
                                                       mode='train')

        if i % 100 == 0 and config.show_plots:
            fig_real = _plotImages(X_real, config)
            writer.add_figure('train_class_' + str(t) + '_real',
                              fig_real, global_step=i)

        #################################################
        # Choosing output heads and constructing targets
        ################################################# 

        # If we train a task-inference net or do class-incremental learning,
        # we construct a target for every single class/task.
        if config.class_incremental or config.training_task_infer:
            # The number of considered output neurons grows with the task id;
            # here we only look at the first t + 1 output neurons.
            task_out = [0, t + 1]
            T_real = torch.zeros((config.batch_size, task_out[1])).to(device)
            T_real[:, task_out[1] - 1] = 1
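            # E.g., for ``t = 2`` the target of every sample in the batch is a
            # one-hot vector of length 3 with the one at index 2.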

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # take the task specific output neuron
                task_out = [t * config.out_dim, t * config.out_dim + config.out_dim]
            else:
                # always all output neurons, only one head is used
                task_out = [0, config.out_dim]
        else:
            # The number of output neurons is generic and can grow, i.e., we do
            # not have to know the number of tasks before we start learning.
            if not config.infer_output_head:
                task_out = [0, (t + 1) * config.out_dim]
                T_real = torch.cat((torch.zeros((config.batch_size,
                                                 t * config.out_dim)).to(device),
                                    T_real), dim=1)
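                # E.g., for ``t = 2`` and ``out_dim = 10``, the original 10-dim
                # targets are padded with 20 leading zeros, matching the 30
                # output neurons considered so far.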
            # This is a special case where the task id will be inferred by
            # another neural network, so we can train on the correct output head
            # directly and use the inferred output head to compute predictions.
            else:
                task_out = [t * config.out_dim, t * config.out_dim + config.out_dim]

        # compute loss of current data
        if config.training_with_hnet:
            weights_c = net_hnet.forward(t)
        else:
            weights_c = None

        Y_hat_logits = net.forward(X_real, weights_c)
        Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]

        if config.soft_targets:
            soft_label = 0.95
            num_classes = T_real.shape[1]
            soft_targets = torch.where(T_real == 1,
                                       torch.Tensor([soft_label]).to(device),
                                       torch.Tensor([(1 - soft_label) / (num_classes - 1)]).to(device))
            soft_targets = soft_targets.to(device)
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits,
                                                             soft_targets)
        else:
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits, T_real)

        ############################
        # compute loss for fake data
        ############################

        # Get fake data of all tasks up until now (merged into one list).
        if t >= 1 and not config.training_with_hnet:
            fake_loss = get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device,
                                           config, writer, t, i, net_copy)
            loss_task = (1 - config.l_rew) * loss_task + config.l_rew * fake_loss
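            # I.e., the total classification loss is a convex combination of
            # the real-data loss and the replay (fake-data) loss, weighted by
            # ``config.l_rew``.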

        loss_task.backward(retain_graph=calc_reg, create_graph=calc_reg and \
                                                               config.backprop_dt)

        # Compute the hypernetwork regularizer: outputs for previous task
        # embeddings are kept close to fixed targets while only the current
        # embedding is trained freely.
        if calc_reg:
            if config.no_lookahead:
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(optimizer,
                                                 config.use_sgd_change, lr=config.class_lr,
                                                 detach_dt=not config.backprop_dt)
            loss_reg = config.class_beta * hreg.calc_fix_target_reg(net_hnet, t,
                                                                    targets=targets_C, mnet=net, dTheta=dTheta,
                                                                    dTembs=None,
                                                                    prev_theta=prev_theta,
                                                                    prev_task_embs=prev_task_embs,
                                                                    batch_size=hnet_reg_batch_size)
            loss_reg.backward()

        # Update the main model (or hypernetwork) parameters using the
        # accumulated gradients.
        if not config.dont_train_main_model:
            optimizer.step()

        if net_hnet is not None and config.train_class_embeddings:
            c_emb_optimizer.step()

        # Save some training statistics.
        if i % 50 == 0:
            # compute accuracies for tracking
            Y_hat_logits = net.forward(X_real, weights_c)
            Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
            Y_hat = F.softmax(Y_hat_logits, dim=1)
            classifier_accuracy = Classifier.accuracy(Y_hat, T_real) * 100.0
            writer.add_scalar('train/task_%d/class_accuracy' % t,
                              classifier_accuracy, i)
            writer.add_scalar('train/task_%d/loss_task' % t,
                              loss_task, i)
            if t >= 1 and not config.training_with_hnet:
                writer.add_scalar('train/task_%d/fake_loss' % t,
                                  fake_loss, i)

        # plot some gradient statistics
        if i % 200 == 0:
            if not config.dont_train_main_model:
                total_norm = 0
                if config.training_with_hnet:
                    params = net_hnet.theta
                else:
                    params = net.parameters()

                for p in params:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
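                # ``total_norm`` is the global L2 norm over all parameter
                # gradients, i.e., the square root of the sum of squared
                # per-parameter gradient norms.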
                # TODO write gradient histograms?
                writer.add_scalar('train/task_%d/main_params_grad_norms' % t,
                                  total_norm, i)

            if net_hnet is not None and config.train_class_embeddings:
                total_norm = 0
                for p in [net_hnet.get_task_emb(t)]:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                writer.add_scalar('train/task_%d/hnet_emb_grad_norms' % t,
                                  total_norm, i)

        if i % 200 == 0:
            msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                  '(on current training batch).'
            print(msg.format(i, classifier_accuracy))