Exemplo n.º 1
0
    def distill_loss_fct(config, X, Y_logits, T_soft_logits, data):
        """Distillation loss evaluated at a single timestep per sample.

        For every sample, the ``config.ssmnist_seq_len`` largest entries of
        the end-of-digit feature ``X[:, :, 3]`` are located along dimension 0.
        The loss is computed on the logits one step before the latest of these
        positions (clipped at 0).

        Args:
            config: Reads ``ssmnist_seq_len`` and ``all_task_softmax``.
            X: Input batch. Dims 0 and 1 must match the logits; presumably
                ``(time, batch, features)`` -- TODO confirm.
            Y_logits: Student logits, indexed as ``[time, batch, class]``.
            T_soft_logits: Teacher logits with the same leading dims. The
                class dim may differ if a growing softmax is used.
            data: Unused.

        Returns:
            Whatever ``Classifier.knowledge_distillation_loss`` returns for
            the selected timesteps.
        """
        # Leading (time, batch) dims must agree between inputs, student and
        # teacher; only the class dimension may differ.
        assert np.all(np.equal(X.shape[:2], T_soft_logits.shape[:2]))
        assert np.all(np.equal(Y_logits.shape[:2], T_soft_logits.shape[:2]))

        temperature = 2.  # Distillation temperature.
        num_digits = config.ssmnist_seq_len

        # Each sample contains `num_digits` digits, each terminated by an
        # end-of-digit bit; the latest of the `num_digits` largest entries of
        # that feature marks where the unpadded sequence ends.
        eod = X[:, :, 3].cpu().numpy()
        top_rows = np.argsort(eod, axis=0)[-num_digits:]
        last_eod = top_rows.max(axis=0)
        ts_idx = np.maximum(last_eod - 1, 0)
        batch_idx = np.arange(ts_idx.size)

        student = Y_logits[ts_idx, batch_idx, :]
        teacher = T_soft_logits[ts_idx, batch_idx, :]

        mapping = (list(range(teacher.shape[1]))
                   if config.all_task_softmax else None)

        return Classifier.knowledge_distillation_loss(
            student,
            teacher,
            target_mapping=mapping,
            device=student.device,
            T=temperature)
Exemplo n.º 2
0
    def distill_loss_fct(config,
                         X,
                         Y_logits,
                         T_soft_logits,
                         data,
                         in_seq_lens=None):
        """Distillation loss averaged over all unpadded timesteps of a batch.

        Each sample's loss (over its first ``in_seq_lens[b]`` timesteps) is
        weighted by its sequence length. The weighted sum is then divided by
        the total number of timesteps, so every valid timestep in the batch
        contributes equally.

        Args:
            config: Reads ``all_task_softmax``.
            X: Input batch. Dims 0 and 1 must match the logits, and dim 1 is
                iterated as the batch dimension.
            Y_logits: Student logits, indexed as ``[time, batch, class]``.
            T_soft_logits: Teacher logits with the same leading dims. The
                class dim may differ if a growing softmax is used.
            data: Unused.
            in_seq_lens: Per-sample sequence lengths (required).

        Returns:
            The length-weighted mean distillation loss.

        Raises:
            NotImplementedError: If ``in_seq_lens`` is not provided.
            ValueError: If all provided sequence lengths are zero, i.e.
                there is no timestep to compute a loss on.
        """
        if in_seq_lens is None:
            raise NotImplementedError(
                'This distillation loss is currently ' +
                'only implemented if sequence lengths are provided, as they ' +
                'can\'t be inferred easily.')
        # Note, input and output sequence lengths are identical for the PoS
        # dataset.

        assert np.all(np.equal(X.shape[:2], T_soft_logits.shape[:2]))
        # Note, targets and predictions might have different head sizes if a
        # growing softmax is used.
        assert np.all(np.equal(Y_logits.shape[:2], T_soft_logits.shape[:2]))

        # Distillation temperature.
        T = 2.

        target_mapping = None
        if config.all_task_softmax:
            target_mapping = list(range(T_soft_logits.shape[2]))

        dloss = 0
        total_num_ts = 0

        for bid in range(X.shape[1]):
            sl = int(in_seq_lens[bid])
            if sl <= 0:
                # An empty sequence has no timesteps to distill. Passing an
                # empty slice to the loss presumably yields NaN (a mean over
                # zero elements), and NaN * 0 stays NaN -- that would poison
                # the loss of the whole batch.
                continue
            total_num_ts += sl

            Y_logits_i = Y_logits[:sl, bid, :]
            T_soft_logits_i = T_soft_logits[:sl, bid, :]

            # Weight by sequence length so that the final division yields a
            # per-timestep average across the batch.
            dloss += Classifier.knowledge_distillation_loss(
                Y_logits_i,
                T_soft_logits_i,
                target_mapping=target_mapping,
                device=Y_logits.device,
                T=T) * sl

        if total_num_ts == 0:
            raise ValueError('Cannot compute the distillation loss: all ' +
                             'sequence lengths in the batch are zero.')

        return dloss / total_num_ts
Exemplo n.º 3
0
    def distill_loss_fct(config, X, Y_logits, T_soft_logits, data):
        """Distillation loss computed on the final timestep only.

        Args:
            config: Reads ``all_task_softmax``.
            X: Unused.
            Y_logits: Student logits, indexed as ``[time, batch, class]``.
            T_soft_logits: Teacher logits with the same leading dims. The
                class dim may differ if a growing softmax is used.
            data: Unused.

        Returns:
            Whatever ``Classifier.knowledge_distillation_loss`` returns for
            the last timestep (temperature 2).
        """
        # Only the class dimension may differ between student and teacher.
        assert np.all(np.equal(Y_logits.shape[:2], T_soft_logits.shape[:2]))

        student_last = Y_logits[-1]
        teacher_last = T_soft_logits[-1]

        mapping = None
        if config.all_task_softmax:
            mapping = list(range(teacher_last.shape[1]))

        return Classifier.knowledge_distillation_loss(
            student_last,
            teacher_last,
            target_mapping=mapping,
            device=student_last.device,
            T=2.)
Exemplo n.º 4
0
def test(task_id,
         data,
         mnet,
         hnet,
         device,
         shared,
         config,
         writer,
         logger,
         train_iter=None,
         task_emb=None,
         cl_scenario=None,
         test_size=None):
    """Evaluate the current performance using the test set.

    Note:
        The hypernetwork ``hnet`` may be ``None``, in which case it is assumed
        that the main network ``mnet`` has internal weights.

    Args:
        (....): See docstring of function :func:`train`.
        train_iter (int, optional): The current training iteration. If given, it
            is used for tensorboard logging.
        task_emb (torch.Tensor, optional): Task embedding. If given, no task ID
            will be provided to the hypernetwork. This might be useful if the
            performance of other than the trained task embeddings should be
            tested.

            .. note::
                This option may only be used for ``cl_scenario=1``. It doesn't
                make sense if the task ID has to be inferred.
        cl_scenario (int, optional): In case the system should be tested on
            another CL scenario than the one user-defined in ``config``. The
            given scenario is used consistently for output-head selection,
            task prediction and logging.

            .. note::
                It is up to the user to ensure that the CL scenarios are
                compatible in this implementation.
        test_size (int, optional): In case the testing shouldn't be performed
            on the entire test set, this option can be used to specify the
            number of test samples to be used.

    Returns:
        (tuple): Tuple containing:

        - **test_acc**: Test accuracy on classification task.
        - **task_acc**: Task prediction accuracy (always 100% for **CL1**).
    """
    if cl_scenario is None:
        cl_scenario = config.cl_scenario
    else:
        assert cl_scenario in [1, 2, 3]

    # `task_emb` ignored for other cl scenarios!
    assert task_emb is None or cl_scenario == 1, \
        '"task_emb" may only be specified for CL1, as we infer the ' + \
        'embedding for other scenarios.'

    mnet.eval()
    if hnet is not None:
        hnet.eval()

    if train_iter is None:
        logger.info('### Test run ...')
    else:
        logger.info('# Testing network before running training step %d ...' % \
                    train_iter)

    # We need to tell the main network, which batch statistics to use, in case
    # batchnorm is used and we checkpoint the batchnorm stats.
    mnet_kwargs = {}
    if mnet.batchnorm_layers is not None:
        if config.bn_distill_stats:
            raise NotImplementedError()
        elif not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Specify current task as condition to select correct
            # running stats.
            mnet_kwargs['condition'] = task_id

            if task_emb is not None:
                # NOTE `task_emb` might have nothing to do with `task_id`.
                logger.warning('Using batch statistics accumulated for task ' +
                               '%d for batchnorm, but testing is ' % task_id +
                               'performed using a given task embedding.')

    with torch.no_grad():
        batch_size = config.val_batch_size
        # FIXME Assuming all output heads have the same size.
        n_head = data.num_classes

        if test_size is None or test_size >= data.num_test_samples:
            test_size = data.num_test_samples
        else:
            # Make sure that we always use the same test samples.
            data.reset_batch_generator(train=False, test=True, val=False)
            logger.info('Note, only part of test set is used for this test ' +
                        'run!')

        test_loss = 0.0

        # We store all predicted labels and tasks while going over individual
        # test batches. Note, the alias `np.int` was removed in NumPy 1.24,
        # hence an explicit integer dtype is used.
        correct_labels = np.empty(test_size, dtype=np.int64)
        pred_labels = np.empty(test_size, dtype=np.int64)
        correct_tasks = np.full(test_size, task_id, dtype=np.int64)
        pred_tasks = np.empty(test_size, dtype=np.int64)

        curr_bs = batch_size
        N_processed = 0

        # Sweep through the test set.
        while N_processed < test_size:
            if N_processed + curr_bs > test_size:
                curr_bs = test_size - N_processed
            N_processed += curr_bs
            # Positions of the current batch within the result arrays.
            batch_slice = slice(N_processed - curr_bs, N_processed)

            batch = data.next_test_batch(curr_bs)
            X = data.input_to_torch_tensor(batch[0], device)
            T = data.output_to_torch_tensor(batch[1], device)

            ############################
            ### Get main net weights ###
            ############################
            if hnet is None:
                weights = None
            elif cl_scenario > 1:
                raise NotImplementedError()
            elif task_emb is not None:
                weights = hnet.forward(task_emb=task_emb)
            else:
                weights = hnet.forward(task_id=task_id)

            #######################
            ### Get predictions ###
            #######################
            Y_hat_logits = mnet.forward(X, weights=weights, **mnet_kwargs)

            # Use the *effective* scenario (which may have been overridden via
            # the `cl_scenario` argument), not `config.cl_scenario`.
            if cl_scenario == 1:
                # Select current head.
                task_out = [task_id * n_head, (task_id + 1) * n_head]
            elif cl_scenario == 2:
                # Only 1 output head.
                task_out = [0, n_head]
            else:
                # TODO Choose the predicted output head per sample, e.g.
                # task_out = [predicted_task_id[0]*n_head,
                #             (predicted_task_id[0]+1)*n_head]
                raise NotImplementedError()

            Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
            # We take the softmax after the output neurons are chosen.
            Y_hat = F.softmax(Y_hat_logits, dim=1).cpu().numpy()

            # Note, targets are 1-hot encoded.
            correct_labels[batch_slice] = \
                T.argmax(dim=1, keepdim=False).cpu().numpy()
            pred_labels[batch_slice] = Y_hat.argmax(axis=1)

            # Set task prediction to 100% if we do not infer it.
            if cl_scenario > 1:
                # TODO pred_tasks[batch_slice] = \
                #     predicted_task_id.cpu().numpy()
                raise NotImplementedError()
            else:
                pred_tasks[batch_slice] = task_id

            test_loss += Classifier.logit_cross_entropy_loss(Y_hat_logits,
                                                             T,
                                                             reduction='sum')

        class_n_correct = (correct_labels == pred_labels).sum()
        test_acc = 100.0 * class_n_correct / test_size

        task_n_correct = (correct_tasks == pred_tasks).sum()
        task_acc = 100.0 * task_n_correct / test_size

        test_loss /= test_size

        msg = '### Test accuracy of task %d' % (task_id+1) \
            + (' (before training iteration %d)' % train_iter if \
               train_iter is not None else '') \
            + ': %.3f' % (test_acc) \
            + (' (using a given task embedding)' if task_emb is not None \
               else '') \
            + (' - task prediction accuracy: %.3f' % task_acc if \
               cl_scenario > 1 else '')
        logger.info(msg)
        logger.debug('Test loss of task %d: %f' % (task_id+1,
                                                  float(test_loss)))

        if train_iter is not None:
            writer.add_scalar('test/task_%d/class_accuracy' % task_id,
                              test_acc, train_iter)

            if cl_scenario > 1:
                writer.add_scalar('test/task_%d/task_pred_accuracy' % \
                                  task_id, task_acc, train_iter)

        return test_acc, task_acc
Exemplo n.º 5
0
def train(task_id, data, mnet, hnet, device, config, shared, writer, logger):
    """Train the hyper network using the task-specific loss plus a regularizer
    that should overcome catastrophic forgetting.

    :code:`loss = task_loss + beta * regularizer`.

    Args:
        task_id: The index of the task on which we train.
        data: The dataset handler.
        mnet: The model of the main network.
        hnet: The model of the hyper network. May be ``None``.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        shared (argparse.Namespace): Set of variables shared between functions.
        writer: The tensorboard summary writer.
        logger: The logger that should be used rather than the print method.
    """
    start_time = time()

    logger.debug('Data handler: %s' % data)
    logger.debug('Number of classes: %d' % data.num_classes)
    logger.debug('Number of training samples: %d' % data.num_train_samples)

    logger.info('Training network ...')

    mnet.train()
    if hnet is not None:
        hnet.train()

    #################
    ### Optimizer ###
    #################
    # Define the optimizers used to train main network and hypernet.
    if hnet is not None:
        theta_params = list(hnet.theta)
        if config.continue_emb_training:
            for i in range(task_id):  # for all previous task embeddings
                theta_params.append(hnet.get_task_emb(i))

        # Only for the current task embedding.
        # Important that this embedding is in a different optimizer in case
        # we use the lookahead.
        emb_optimizer = get_optimizer([hnet.get_task_emb(task_id)],
                                      config.lr,
                                      momentum=config.momentum,
                                      weight_decay=config.weight_decay,
                                      use_adam=config.use_adam,
                                      adam_beta1=config.adam_beta1,
                                      use_rmsprop=config.use_rmsprop)
    else:
        theta_params = mnet.weights
        emb_optimizer = None

    theta_optimizer = get_optimizer(theta_params,
                                    config.lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay,
                                    use_adam=config.use_adam,
                                    adam_beta1=config.adam_beta1,
                                    use_rmsprop=config.use_rmsprop)

    ################################
    ### Learning rate schedulers ###
    ################################
    if config.plateau_lr_scheduler:
        assert (config.epochs != -1)
        # The scheduler config has been taken from here:
        # https://keras.io/examples/cifar10_resnet/
        # Note, we use 'max' instead of 'min' as we look at accuracy rather
        # than validation loss!
        plateau_scheduler_theta = optim.lr_scheduler.ReduceLROnPlateau( \
            theta_optimizer, 'max', factor=np.sqrt(0.1), patience=5,
            min_lr=0.5e-6, cooldown=0)
        plateau_scheduler_emb = None
        if emb_optimizer is not None:
            plateau_scheduler_emb = optim.lr_scheduler.ReduceLROnPlateau( \
                emb_optimizer, 'max', factor=np.sqrt(0.1), patience=5,
                min_lr=0.5e-6, cooldown=0)

    if config.lambda_lr_scheduler:
        assert (config.epochs != -1)

        def lambda_lr(epoch):
            """Multiplicative Factor for Learning Rate Schedule.

            Computes a multiplicative factor for the initial learning rate based
            on the current epoch. This method can be used as argument
            ``lr_lambda`` of class :class:`torch.optim.lr_scheduler.LambdaLR`.

            The schedule is inspired by the Resnet CIFAR-10 schedule suggested
            here https://keras.io/examples/cifar10_resnet/.

            Args:
                epoch (int): The number of epochs

            Returns:
                lr_scale (float32): learning rate scale
            """
            lr_scale = 1.
            if epoch > 180:
                lr_scale = 0.5e-3
            elif epoch > 160:
                lr_scale = 1e-3
            elif epoch > 120:
                lr_scale = 1e-2
            elif epoch > 80:
                lr_scale = 1e-1
            return lr_scale

        lambda_scheduler_theta = optim.lr_scheduler.LambdaLR(
            theta_optimizer, lambda_lr)
        lambda_scheduler_emb = None
        if emb_optimizer is not None:
            lambda_scheduler_emb = optim.lr_scheduler.LambdaLR(
                emb_optimizer, lambda_lr)

    ##############################
    ### Prepare CL Regularizer ###
    ##############################
    # Whether we will calculate the regularizer.
    calc_reg = task_id > 0 and not config.mnet_only and config.beta > 0 and \
        not config.train_from_scratch

    # Compute targets when the reg is activated and we are not training
    # the first task
    if calc_reg:
        if config.online_target_computation:
            # Compute targets for the regularizer whenever they are needed.
            # -> Computationally expensive.
            targets_hypernet = None
            prev_theta = [p.detach().clone() for p in hnet.theta]
            prev_task_embs = [p.detach().clone() for p in hnet.get_task_embs()]
        else:
            # Compute targets for the regularizer once and keep them all in
            # memory -> Memory expensive.
            targets_hypernet = hreg.get_current_targets(task_id, hnet)
            prev_theta = None
            prev_task_embs = None

        # If we do not want to regularize all outputs (in a multi-head setup).
        # Note, we don't care whether output heads other than the current one
        # change.
        regged_outputs = None
        if config.cl_scenario != 2:
            # FIXME We assume here that all tasks have the same output size.
            n_y = data.num_classes
            regged_outputs = [
                list(range(i * n_y, (i + 1) * n_y)) for i in range(task_id)
            ]

    # We need to tell the main network, which batch statistics to use, in case
    # batchnorm is used and we checkpoint the batchnorm stats.
    mnet_kwargs = {}
    if mnet.batchnorm_layers is not None:
        if config.bn_distill_stats:
            raise NotImplementedError()
        elif not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Specify current task as condition to select correct
            # running stats.
            mnet_kwargs['condition'] = task_id

    ######################
    ### Start training ###
    ######################

    iter_per_epoch = -1
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert (config.epochs > 0)
        iter_per_epoch = int(np.ceil(data.num_train_samples / \
                                     config.batch_size))
        training_iterations = config.epochs * iter_per_epoch

    summed_iter_runtime = 0

    for i in range(training_iterations):
        ### Evaluate network.
        # We test the network before we run the training iteration.
        # That way, we can see the initial performance of the untrained network.
        if i % config.val_iter == 0:
            test(task_id,
                 data,
                 mnet,
                 hnet,
                 device,
                 shared,
                 config,
                 writer,
                 logger,
                 train_iter=i)
            mnet.train()
            if hnet is not None:
                hnet.train()

        if i % 200 == 0:
            logger.info('Training step: %d ...' % i)

        iter_start_time = time()

        theta_optimizer.zero_grad()
        if emb_optimizer is not None:
            emb_optimizer.zero_grad()

        #######################################
        ### Data for current task and batch ###
        #######################################
        batch = data.next_train_batch(config.batch_size)
        X = data.input_to_torch_tensor(batch[0], device, mode='train')
        T = data.output_to_torch_tensor(batch[1], device, mode='train')

        # Get the output neurons depending on the continual learning scenario.
        n_y = data.num_classes
        if config.cl_scenario == 1:
            # Choose current head.
            task_out = [task_id * n_y, (task_id + 1) * n_y]
        elif config.cl_scenario == 2:
            # Always all output neurons, only one head is used.
            task_out = [0, n_y]
        else:
            # Choose current head, which will be inferred during inference.
            task_out = [task_id * n_y, (task_id + 1) * n_y]

        ########################
        ### Loss computation ###
        ########################
        if config.mnet_only:
            weights = None
        else:
            weights = hnet.forward(task_id=task_id)
        Y_hat_logits = mnet.forward(X, weights, **mnet_kwargs)

        # Restrict output neurons
        Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
        assert (T.shape[1] == Y_hat_logits.shape[1])
        # compute loss on task and compute gradients
        if config.soft_targets:
            soft_label = 0.95
            num_classes = data.num_classes
            # Create the fill values directly on `device`: shape-(1,) CPU
            # tensors cannot be combined with CUDA tensors in `torch.where`
            # (only zero-dim CPU tensors are treated as scalars).
            soft_targets = torch.where(
                T == 1,
                torch.tensor(soft_label, device=device),
                torch.tensor((1 - soft_label) / (num_classes - 1),
                             device=device))
            loss_task = Classifier.softmax_and_cross_entropy(
                Y_hat_logits, soft_targets)
        else:
            loss_task = Classifier.logit_cross_entropy_loss(Y_hat_logits, T)

        # Compute gradients based on task loss (those might be used in the CL
        # regularizer).
        loss_task.backward(retain_graph=calc_reg, create_graph=calc_reg and \
                           config.backprop_dt)

        # The current task embedding only depends on the task loss, so we can
        # update it already.
        if emb_optimizer is not None:
            emb_optimizer.step()

        #############################
        ### CL (HNET) Regularizer ###
        #############################
        loss_reg = 0
        dTheta = None

        if calc_reg:
            if config.no_lookahead:
                dTembs = None
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(
                    theta_optimizer,
                    False,
                    lr=config.lr,
                    detach_dt=not config.backprop_dt)

                # With `continue_emb_training`, the previous task embeddings
                # were appended to `theta_params`, so their deltas sit at the
                # end of `dTheta`.
                if config.continue_emb_training:
                    dTembs = dTheta[-task_id:]
                    dTheta = dTheta[:-task_id]
                else:
                    dTembs = None

            loss_reg = hreg.calc_fix_target_reg(
                hnet,
                task_id,
                targets=targets_hypernet,
                dTheta=dTheta,
                dTembs=dTembs,
                mnet=mnet,
                inds_of_out_heads=regged_outputs,
                prev_theta=prev_theta,
                prev_task_embs=prev_task_embs,
                batch_size=config.cl_reg_batch_size)

            loss_reg *= config.beta

            loss_reg.backward()

        # Now, that we computed the regularizer, we can use the accumulated
        # gradients and update the hnet (or mnet) parameters.
        theta_optimizer.step()

        Y_hat = F.softmax(Y_hat_logits, dim=1)
        classifier_accuracy = Classifier.accuracy(Y_hat, T) * 100.0

        #########################
        # Learning rate scheduler
        #########################
        if config.plateau_lr_scheduler:
            assert (iter_per_epoch != -1)
            if i % iter_per_epoch == 0 and i > 0:
                curr_epoch = i // iter_per_epoch
                logger.info('Computing test accuracy for plateau LR ' +
                            'scheduler (epoch %d).' % curr_epoch)
                # We need a validation quantity for the plateau LR scheduler.
                # FIXME we should use an actual validation set rather than the
                # test set.
                # Note, https://keras.io/examples/cifar10_resnet/ uses the test
                # set to compute the validation loss. We use the "validation"
                # accuracy instead.
                # FIXME We increase `train_iter` as the print messages in the
                # test method suggest that the testing has been executed before
                test_acc, _ = test(task_id,
                                   data,
                                   mnet,
                                   hnet,
                                   device,
                                   shared,
                                   config,
                                   writer,
                                   logger,
                                   train_iter=i + 1)
                mnet.train()
                if hnet is not None:
                    hnet.train()

                plateau_scheduler_theta.step(test_acc)
                if plateau_scheduler_emb is not None:
                    plateau_scheduler_emb.step(test_acc)

        if config.lambda_lr_scheduler:
            assert (iter_per_epoch != -1)
            if i % iter_per_epoch == 0 and i > 0:
                curr_epoch = i // iter_per_epoch
                logger.info('Applying Lambda LR scheduler (epoch %d).' %
                            curr_epoch)

                lambda_scheduler_theta.step()
                if lambda_scheduler_emb is not None:
                    lambda_scheduler_emb.step()

        ###########################
        ### Tensorboard summary ###
        ###########################
        # We don't wanna slow down training by having too much output.
        if i % 50 == 0:
            writer.add_scalar('train/task_%d/class_accuracy' % task_id,
                              classifier_accuracy, i)
            writer.add_scalar('train/task_%d/loss_task' % task_id, loss_task,
                              i)
            writer.add_scalar('train/task_%d/loss_reg' % task_id, loss_reg, i)

        ### Show the current training progress to the user.
        if i % config.val_iter == 0:
            msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                  '(on current training batch).'
            logger.debug(msg.format(i, classifier_accuracy))

        iter_end_time = time()
        summed_iter_runtime += (iter_end_time - iter_start_time)

        if i % 200 == 0:
            logger.info('Training step: %d ... Done -- (runtime: %f sec)' % \
                        (i, iter_end_time - iter_start_time))

    if mnet.batchnorm_layers is not None:
        if not config.bn_distill_stats and \
                not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Checkpoint the current running statistics (that have been
            # estimated while training the current task).
            for bn_layer in mnet.batchnorm_layers:
                assert (bn_layer.num_stats == task_id + 1)
                bn_layer.checkpoint_stats()

    # Average over the iterations that were actually run. `config.n_iter` is
    # not the iteration count when training is specified via epochs.
    avg_iter_time = summed_iter_runtime / max(training_iterations, 1)
    logger.info('Average runtime per training iteration: %f sec.' % \
                avg_iter_time)

    logger.info('Elapsed time for training task %d: %f sec.' % \
                (task_id+1, time()-start_time))
    def distill_loss_fct(config, X, Y_logits, T_soft_logits, data):
        """Knowledge-distillation loss for sequences with an end-of-digit bit.

        Two modes, selected by ``config.distill_across_time``:

        * ``True``: the tempered cross-entropy between teacher and student is
          computed at every timestep. It is then summed over time using
          weights derived from the end-of-digit feature ``X[:, :, 3]``:
          negative entries are clipped to zero, and the rest is normalized
          *per sample* so that each sample's weights sum to (approximately)
          one. The result is averaged over the batch.
        * ``False``: only one timestep per sample is used, namely the one
          preceding the position of the maximum of ``X[:, :, 3]``.

        Args:
            config: Reads ``distill_across_time`` and ``all_task_softmax``.
            X: Input batch. Dims 0 and 1 must match the logits; presumably
                ``(time, batch, features)`` -- TODO confirm.
            Y_logits: Student logits, indexed as ``[time, batch, class]``.
            T_soft_logits: Teacher logits with the same leading dims. They
                may have fewer classes if a growing softmax is used.
            data: Unused; kept for interface compatibility.

        Returns:
            The distillation loss.
        """
        assert np.all(np.equal(X.shape[:2], T_soft_logits.shape[:2]))
        # Note, targets and predictions might have different head sizes if a
        # growing softmax is used.
        assert np.all(np.equal(Y_logits.shape[:2], T_soft_logits.shape[:2]))

        # Distillation temperature.
        T = 2.

        if config.distill_across_time:
            # End-of-digit feature.
            # Note, the softmax would be an easy way to obtain per-timestep
            # probabilities that a digit has ended. But since this feature
            # vector `X[:, :, 3]` is ideally a 1-hot encoding, it shouldn't
            # be squashed via a softmax; that would blur out the probabilities.
            ts_weights = X[:, :, 3].clone()
            ts_weights[ts_weights < 0] = 0
            # Normalize per sample (across time, dim 0) so that each sample's
            # weights sum to one. Normalizing over the whole tensor would make
            # the weights of the whole batch sum to one, which scales the
            # loss down by the batch size after the batch mean below. The
            # epsilon avoids division by zero for samples whose end-of-digit
            # feature is entirely non-positive.
            ts_weights /= ts_weights.sum(dim=0, keepdim=True) + 1e-5
            ts_weights = ts_weights.detach()

            # For distillation, we use a tempered softmax.
            T_soft = F.softmax(T_soft_logits / T, dim=2)
            n_missing = Y_logits.shape[2] - T_soft.shape[2]
            if config.all_task_softmax and n_missing != 0:
                # Pad the classes unknown to the teacher (e.g. added by a
                # growing softmax) with zero probability.
                assert n_missing > 0, 'Teacher has more classes than student.'
                T_soft = F.pad(T_soft, (0, n_missing),
                               mode='constant',
                               value=0)

            # Distillation loss.
            loss = -(T_soft * F.log_softmax(Y_logits / T, dim=2)).sum(dim=2) * \
                T**2
            loss *= ts_weights

            # Weighted sum across time (weights sum to 1 per sample) and mean
            # across the batch dimension.
            return loss.sum(dim=0).mean()
        else:
            # Note, smnist samples have the end-of-sequence bit as last
            # timestep, the rest is padded.
            seq_lengths = X[:, :, 3].argmax(dim=0)
            inds = seq_lengths.cpu().numpy() - 1
            inds[inds < 0] = 0

            # Only compute loss for last timestep.
            Y_logits = Y_logits[inds, np.arange(inds.size), :]
            T_soft_logits = T_soft_logits[inds, np.arange(inds.size), :]

            target_mapping = None
            if config.all_task_softmax:
                target_mapping = list(range(T_soft_logits.shape[1]))

            return Classifier.knowledge_distillation_loss(
                Y_logits,
                T_soft_logits,
                target_mapping=target_mapping,
                device=Y_logits.device,
                T=T)
Exemplo n.º 7
0
def _grad_l2_norm(params):
    """Return the global L2 norm over the gradients of ``params``.

    Args:
        params: Iterable of parameter tensors (e.g. ``net.parameters()``).

    Returns:
        (float): ``sqrt(sum_p ||p.grad||_2^2)``. Parameters whose ``grad`` is
        ``None`` (they did not take part in the backward pass) are skipped
        instead of raising an ``AttributeError``.
    """
    total_sq = 0.
    for p in params:
        if p.grad is None:
            continue
        total_sq += p.grad.detach().norm(2).item() ** 2
    return total_sq ** 0.5


def train_class_one_t(dhandler_class, dhandlers_rp, dec, d_hnet, net,
                      device, config, writer, t):
    """Train continual learning experiments on MNIST dataset for one task.

    In this function the main training logic is implemented.
    After setting the optimizers for the network and hypernetwork if
    applicable, the training is structured as follows:
    First, we get a training batch of the current task. Depending on
    the learning scenario, we choose output heads and build targets
    accordingly.
    Second, if ``t`` is greater than 0, we add a loss term concerning
    predictions of replayed data. See :func:`get_fake_data_loss` for
    details. Third, to protect the hypernetwork from forgetting, we add an
    additional L2 loss term namely the difference between its current output
    given an embedding and checkpointed targets.
    Finally, we track some training statistics.

    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
    """

    # If the classifier is trained through a hypernetwork, `net` is the pair
    # (main network, hypernetwork).
    if config.training_with_hnet:
        net_hnet = net[1]
        net = net[0]
        net.train()
        net_hnet.train()
        params_to_regularize = list(net_hnet.theta)
        optimizer = optim.Adam(params_to_regularize,
                               lr=config.class_lr, betas=(0.9, 0.999))

        c_emb_optimizer = optim.Adam([net_hnet.get_task_emb(t)],
                                     lr=config.class_lr_emb, betas=(0.9, 0.999))
    else:
        net.train()
        net_hnet = None
        optimizer = optim.Adam(net.parameters(),
                               lr=config.class_lr, betas=(0.9, 0.999))

    # The replay model (if available) is only sampled from, not trained here.
    if dec is not None:
        dec.eval()
    if d_hnet is not None:
        d_hnet.eval()

    # Compute targets for the hypernetwork regularizer (only needed if the
    # classifier is trained with a hnet and previous tasks exist).
    if t > 0 and config.training_with_hnet:
        if config.online_target_computation:
            # Compute targets for the regularizer whenever they are needed.
            # -> Computationally expensive.
            targets_C = None
            prev_theta = [p.detach().clone() for p in net_hnet.theta]
            prev_task_embs = [p.detach().clone() for p in
                              net_hnet.get_task_embs()]
        else:
            # Compute targets for the regularizer once and keep them all in
            # memory -> Memory expensive.
            targets_C = hreg.get_current_targets(t, net_hnet)
            prev_theta = None
            prev_task_embs = None

    dhandler_class.reset_batch_generator()

    # Checkpoint of the network before learning task `t`; it provides the
    # targets for replayed data (see `get_fake_data_loss`).
    if t >= 1:
        net_copy = copy.deepcopy(net)

    # Set training_iterations if epochs are set.
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert (config.epochs > 0)
        training_iterations = config.epochs * \
            int(np.ceil(dhandler_class.num_train_samples / config.batch_size))

    if config.class_incremental:
        training_iterations = int(training_iterations / config.out_dim)

    # Whether we will calculate the regularizer.
    calc_reg = t > 0 and config.class_beta > 0 and config.training_with_hnet

    # Optionally compute the regularizer only for a subset of previous tasks.
    if config.hnet_reg_batch_size != -1:
        hnet_reg_batch_size = config.hnet_reg_batch_size
    else:
        hnet_reg_batch_size = None

    for i in range(training_iterations):

        # Reset the gradients of all optimized parameters.
        optimizer.zero_grad()
        if net_hnet is not None:
            c_emb_optimizer.zero_grad()

        # Get real data.
        real_batch = dhandler_class.next_train_batch(config.batch_size)
        X_real = dhandler_class.input_to_torch_tensor(real_batch[0], device,
                                                      mode='train')
        T_real = dhandler_class.output_to_torch_tensor(real_batch[1], device,
                                                       mode='train')
        # Number of samples actually returned; hand-built targets below must
        # match it (rather than assuming exactly `config.batch_size`).
        n_real = T_real.shape[0]

        if i % 100 == 0 and config.show_plots:
            fig_real = _plotImages(X_real, config)
            writer.add_figure('train_class_' + str(t) + '_real',
                              fig_real, global_step=i)

        #################################################
        # Choosing output heads and constructing targets
        #################################################

        # If we train a task inference net or do class incremental learning,
        # we construct a target for every single class/task.
        if config.class_incremental or config.training_task_infer:
            # Output neurons of all tasks/classes seen so far.
            task_out = [0, t + 1]
            T_real = torch.zeros((n_real, task_out[1])).to(device)
            T_real[:, task_out[1] - 1] = 1

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # Take the task specific output neurons.
                task_out = [t * config.out_dim, t * config.out_dim + config.out_dim]
            else:
                # Always all output neurons, only one head is used.
                task_out = [0, config.out_dim]
        else:
            # The number of output neurons is generic and can grow, i.e., we
            # do not have to know the number of tasks before we start
            # learning.
            if not config.infer_output_head:
                task_out = [0, (t + 1) * config.out_dim]
                T_real = torch.cat((torch.zeros((n_real,
                                                 t * config.out_dim)).to(device),
                                    T_real), dim=1)
            # This is a special case where we will infer the task id by
            # another neural network, so we can train on the correct output
            # head directly and use the inferred output head to compute the
            # prediction.
            else:
                task_out = [t * config.out_dim, t * config.out_dim + config.out_dim]

        # Compute loss of current data.
        if config.training_with_hnet:
            weights_c = net_hnet.forward(t)
        else:
            weights_c = None

        Y_hat_logits = net.forward(X_real, weights_c)
        Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]

        if config.soft_targets:
            soft_label = 0.95
            num_classes = T_real.shape[1]
            if num_classes > 1:
                soft_targets = torch.where(
                    T_real == 1,
                    torch.Tensor([soft_label]).to(device),
                    torch.Tensor([(1 - soft_label) / (num_classes - 1)]).to(device))
            else:
                # A single-output head has no other classes to spread label
                # mass onto (the formula above would divide by zero).
                soft_targets = T_real
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits,
                                                             soft_targets)
        else:
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits, T_real)

        ############################
        # compute loss for fake data
        ############################

        # Get fake data (of all tasks up until now and merge into list).
        if t >= 1 and not config.training_with_hnet:
            fake_loss = get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device,
                                           config, writer, t, i, net_copy)
            loss_task = (1 - config.l_rew) * loss_task + config.l_rew * fake_loss

        # Keep the graph alive if the regularizer (and possibly its lookahead
        # through `dTheta`) still needs it.
        loss_task.backward(retain_graph=calc_reg,
                           create_graph=calc_reg and config.backprop_dt)

        # Compute hypernet loss and fix embedding -> change current embs.
        if calc_reg:
            if config.no_lookahead:
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(
                    optimizer, config.use_sgd_change, lr=config.class_lr,
                    detach_dt=not config.backprop_dt)
            loss_reg = config.class_beta * hreg.calc_fix_target_reg(
                net_hnet, t, targets=targets_C, mnet=net, dTheta=dTheta,
                dTembs=None, prev_theta=prev_theta,
                prev_task_embs=prev_task_embs, batch_size=hnet_reg_batch_size)
            loss_reg.backward()

        # Apply the parameter update (gradients were accumulated above).
        if not config.dont_train_main_model:
            optimizer.step()

        if net_hnet is not None and config.train_class_embeddings:
            c_emb_optimizer.step()

        # Save some statistics.
        if i % 50 == 0:
            # Evaluation-only forward pass for tracking: no autograd graph
            # is needed here.
            with torch.no_grad():
                Y_hat_logits = net.forward(X_real, weights_c)
                Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
                Y_hat = F.softmax(Y_hat_logits, dim=1)
                classifier_accuracy = Classifier.accuracy(Y_hat, T_real) * 100.0
            writer.add_scalar('train/task_%d/class_accuracy' % t,
                              classifier_accuracy, i)
            writer.add_scalar('train/task_%d/loss_task' % t,
                              loss_task, i)
            if t >= 1 and not config.training_with_hnet:
                writer.add_scalar('train/task_%d/fake_loss' % t,
                                  fake_loss, i)

        # Plot some gradient statistics.
        if i % 200 == 0:
            if not config.dont_train_main_model:
                if config.training_with_hnet:
                    params = net_hnet.theta
                else:
                    params = net.parameters()
                # TODO write gradient histograms?
                writer.add_scalar('train/task_%d/main_params_grad_norms' % t,
                                  _grad_l2_norm(params), i)

            if net_hnet is not None and config.train_class_embeddings:
                writer.add_scalar('train/task_%d/hnet_emb_grad_norms' % t,
                                  _grad_l2_norm([net_hnet.get_task_emb(t)]), i)

        # `classifier_accuracy` is always set here, since every multiple of
        # 200 is also a multiple of 50.
        if i % 200 == 0:
            msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                  '(on current training batch).'
            print(msg.format(i, classifier_accuracy))
Exemplo n.º 8
0
def get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device, config, writer,
                       t, i, net_copy):
    """ Sample fake data from generator for tasks up to t and compute a loss
    compared to predictions of a checkpointed network.

    We must take caution when considering the different learning scenarios
    and methods and training stages, see detailed comments in the code.

    In general, we build a batch of replayed data from all previous tasks.
    Since we do not know the labels of the replayed data, we consider the
    output of the checkpointed network as ground truth, i.e., we must compute
    a loss between two logits. See :class:`mnets.classifier_interface.Classifier`
    for a detailed description of the different loss functions.

    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
        i: Current training iteration.
        net_copy: Copy/checkpoint of the classifier network before
            learning task ``t``.

    Returns:
        The loss between predictions and predictions of a
        checkpointed network on replayed data.

    Raises:
        ValueError: If there is no previously learned task/class to replay
            (e.g. ``t == 0``).
    """

    all_Y_hat_ls = []
    all_targets = []

    # We have to choose from which embeddings (generator conditions) to
    # sample from.
    if config.class_incremental or config.single_class_replay:
        # Every class was learned with a different generator condition.
        emb_num = t * config.out_dim
    else:
        # Here samples from the whole task come from one generator condition.
        emb_num = t

    if emb_num < 1:
        raise ValueError('Replay loss requires at least one previously ' +
                         'learned task/class to sample from (got t=%d).' % t)

    if config.fake_data_full_range:
        ran = range(0, emb_num)
        bs_per_task = int(np.ceil(config.batch_size / emb_num))
    else:
        random_t = np.random.randint(0, emb_num)
        ran = range(random_t, random_t + 1)
        bs_per_task = config.batch_size

    for re in ran:

        # Exchange replay data with real data to compute upper bounds.
        if config.upper_bound:
            real_batch = dhandlers_rp[re].next_train_batch(bs_per_task)
            X_fake = dhandlers_rp[re].input_to_torch_tensor(real_batch[0],
                                                            device, mode='train')
        else:
            # Get fake data.
            if config.replay_method == 'gan':
                X_fake = sample_gan(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)
            else:
                X_fake = sample_vae(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)

        # Save some fake data to the writer.
        if i % 100 == 0:
            if X_fake.shape[0] >= 15:
                fig_fake = _plotImages(X_fake, config, bs_per_task)
                writer.add_figure('train_class_' + str(re) + '_fake',
                                  fig_fake, global_step=i)

        # Compute soft targets with the copied (checkpointed) network.
        target_logits = net_copy.forward(X_fake).detach()
        Y_hat_ls = net.forward(X_fake.detach())

        ###############
        # BUILD TARGETS
        ###############
        od = config.out_dim

        if config.class_incremental or config.training_task_infer:
            # This is a bit complicated: If we train class/task incrementally
            # we skip training the classifier on the first task.
            # So when starting to train the classifier on task 2, we have to
            # build a hard target for this first output neuron trained by
            # replay data. A soft target (on an untrained output) would not
            # make sense.

            # Output head over all output neurons already available.
            task_out = [0, (t + 1) * od]
            # Create target with zero everywhere except from the current re.
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(device)

            if config.hard_targets or (t == 1 and re == 0):
                zeros[:, re] = 1
            else:
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]

            targets = zeros
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # Take the task specific output neurons.
                task_out = [re * od, re * od + od]
            else:
                # Always all output neurons, only one head is used.
                task_out = [0, od]

            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]
            target_logits = target_logits[:, task_out[0]:task_out[1]]
            # Build hard targets, i.e., one-hots, if this option is chosen.
            if config.hard_targets:
                soft_targets = torch.sigmoid(target_logits)
                zeros = torch.zeros(Y_hat_ls.shape).to(device)
                _, argmax = torch.max(soft_targets, 1)
                targets = zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # Loss expects logits.
                targets = target_logits
        else:
            # Take all neurons used up until now.

            # Output head over all output neurons already available.
            task_out = [0, (t + 1) * od]
            # Create target with zero everywhere except from the current re.
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(device)

            # Sigmoid over the output head(s) from all previous tasks.
            soft_targets = torch.sigmoid(target_logits[:, 0:t * od])

            # Compute one-hots.
            if config.hard_targets:
                _, argmax = torch.max(soft_targets, 1)
                zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # Loss expects logits.
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]
            targets = zeros
            # Choose the correct output size for the actual prediction.
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        # Add to list.
        all_targets.append(targets)
        all_Y_hat_ls.append(Y_hat_ls)

    # Concatenate to one tensor.
    all_targets = torch.cat(all_targets)
    Y_hat_ls = torch.cat(all_Y_hat_ls)

    if i % 200 == 0:
        classifier_accuracy = Classifier.accuracy(Y_hat_ls, all_targets) * 100.0
        msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
              '(on current FAKE DATA training batch).'
        print(msg.format(i, classifier_accuracy))

    # Dependent on the target softness, the loss function is chosen. In the
    # class-/task-incremental branch above, the targets at `t == 1` are
    # one-hot. They must therefore be paired with the hard-label
    # cross-entropy, not with the distillation loss, which interprets its
    # targets as logits. This holds for both `class_incremental` and
    # `training_task_infer`, matching the branch that builds the targets.
    incremental = config.class_incremental or config.training_task_infer
    if config.hard_targets or (incremental and t == 1):
        return Classifier.logit_cross_entropy_loss(Y_hat_ls, all_targets)
    else:
        return Classifier.knowledge_distillation_loss(Y_hat_ls, all_targets)
Exemplo n.º 9
0
def test(dhandlers, class_nets, infer_net, device, config, writer,
         task_id=None):
    """ Test continual learning experiments on MNIST dataset. This can either
    be splitMNIST or permutedMNIST.

    Depending on the method and cl scenario used, this method manages
    to measure the test accuracy of a given task or all tasks after
    training. In order to do so, correct targets need to be constructed
    and output heads need to be set (or inferred).
    Furthermore, this method distinguishes between classification accuracy
    on a task and the accuracy of inferring task ids, if applicable.

    Args:
        (....): See docstring of function :func:`train_tasks`.
        task_id: (optional) If not None, the method will compute and return
                   test acc for the given task id, not all tasks.

    Returns:
        Scalar representing the test accuracy for the given task id.
        If ``task_id`` is None, the accuracy of the last task of the cl
        experiment is returned.
    """

    # Get hnet if this option is given.
    if class_nets is not None:
        if config.training_with_hnet:
            c_net_hnet = class_nets[1]
            c_net = class_nets[0]
            c_net.eval()
            c_net_hnet.eval()
        else:
            c_net = class_nets

    if infer_net is not None:
        infer_net.eval()

    with torch.no_grad():

        overall_acc = 0
        overall_acc_list = []
        overall_task_infer_accuracy = 0
        overall_task_infer_accuracy_list = []
        # Defined up front so the return below never hits an unbound name.
        classifier_accuracy = 0

        # Choose tasks to test.
        if task_id is not None:
            task_range = range(task_id, task_id + 1)
        else:
            task_range = range(config.num_tasks)

        # Iterate through all old tasks.
        for t in task_range:
            print("Testing task: ", t)
            # Reset data.
            if task_id is not None:
                dhandler = dhandlers[0]
            else:
                dhandler = dhandlers[t]

            # Create some variables.
            N_processed = 0
            test_size = dhandler.num_test_samples

            # If the task id has to be inferred, we have to do that for
            # every x and therefore have one h(e) = W per data point. This
            # is only possible with batch size one, for now.
            if (config.infer_task_id and infer_net is not None) or \
                    config.infer_with_entropy:
                curr_bs = 1
            else:
                curr_bs = config.test_batch_size

            classifier_accuracy = 0
            task_infer_accuracy = 0
            Y_hat_all = []
            T_all = []

            # Go through test set.
            while N_processed < test_size:
                # Test size of tasks might be "arbitrary".
                if N_processed + curr_bs > test_size:
                    curr_bs = test_size - N_processed
                N_processed += curr_bs

                # Get data.
                real_batch = dhandler.next_test_batch(curr_bs)
                X_real = dhandler.input_to_torch_tensor(real_batch[0], device,
                                                        mode='inference')
                T_real = dhandler.output_to_torch_tensor(real_batch[1], device,
                                                         mode='inference')

                # Get short version of output dim.
                od = config.out_dim

                #######################################
                # SET THE OUTPUT HEAD / COMPUTE TARGETS
                #######################################

                # Get a dummy output for easy access to the output dim of
                # our main network.
                if class_nets is not None:
                    if config.training_with_hnet:
                        weights_dummy = c_net_hnet.forward(0)
                        Y_dummies = c_net.forward(X_real, weights_dummy)
                    else:
                        Y_dummies = c_net.forward(X_real)
                else:
                    Y_dummies = infer_net.forward(X_real)

                # Build one-hots if this option was chosen. Here we build
                # targets if we only have one neuron per task, which we set
                # to 1.
                if config.class_incremental:
                    task_out = [0, config.num_tasks]
                    T_real = torch.zeros((Y_dummies.shape[0],
                                          config.num_tasks)).to(device)
                    T_real[:, t] = 1

                # Compute targets - this is a bit inelegant, cl 3 requires
                # hacks.
                elif config.cl_scenario == 1 or config.cl_scenario == 2:
                    if config.cl_scenario == 1:
                        # Take the task specific output neurons.
                        task_out = [t * od, t * od + od]
                    else:
                        # Always all output neurons (only one head is used).
                        task_out = [0, od]
                else:
                    # This here is the classic CL 3 scenario.
                    # First we get the predictions, this is over all neurons.
                    task_out = [0, config.num_tasks * od]
                    # Here we build the targets, this is zero everywhere
                    # except for the current task - here the correct target
                    # is inserted.

                    # Build the two zero tensors that surround the targets.
                    zeros1 = torch.zeros(Y_dummies[:, 0:t * od].shape). \
                        to(device)
                    zeros2 = torch.zeros(Y_dummies[:, 0:(config.num_tasks
                                                         - 1 - t) * od].shape).to(device)
                    T_real = torch.cat([zeros1, T_real, zeros2], dim=-1)

                #################
                # TASK PREDICTION
                #################

                # Get task predictions.
                if config.cl_scenario != 1:
                    if infer_net is not None:
                        # Get infer net to predict the apparent task id.
                        task_pred = infer_net.forward(X_real)
                        task_pred = task_pred[:, 0:config.num_tasks]
                        task_pred = torch.sigmoid(task_pred)
                        _, inf_task_id = torch.max(task_pred, 1)

                        # Measure acc of prediction.
                        task_infer_accuracy += (inf_task_id == t).float()

                    elif config.infer_with_entropy and class_nets is not None \
                            and config.training_with_hnet:
                        entropies = []
                        if task_id is not None:
                            entrop_to_test = range(0, task_id + 1)
                        else:
                            entrop_to_test = range(config.num_tasks)
                        # Infer task id through entropy of softmax outputs of
                        # different models.
                        for e in entrop_to_test:
                            weights_c = c_net_hnet.forward(e)
                            Y_hat_logits = c_net.forward(X_real, weights_c)
                            if config.cl_scenario == 2:
                                task_out = [0, od]
                            else:
                                task_out = [e * od, e * od + od]
                            # Numerically stable entropy via `log_softmax`:
                            # `softmax` can underflow to exactly 0 for peaked
                            # (e.g. low-temperature) outputs, and
                            # `0 * log(0)` evaluates to NaN, which would
                            # corrupt the argmin below.
                            log_Y_hat = F.log_softmax(
                                Y_hat_logits[:, task_out[0]:task_out[1]] /
                                config.soft_temp, -1)
                            entropy = -1 * torch.sum(log_Y_hat.exp() * log_Y_hat)
                            entropies.append(entropy)
                        inf_task_id = torch.argmin(torch.stack(entropies))
                        task_infer_accuracy += (inf_task_id == t).float()

                    if config.cl_scenario == 3 and config.infer_output_head:
                        task_out = [inf_task_id * od, inf_task_id * od + od]
                else:
                    # If task id is known, task inference acc is 100%.
                    task_infer_accuracy += 1
                    inf_task_id = t

                if class_nets is not None:
                    # From the given inf_task_id we try to produce the
                    # correct model for that task.
                    if config.training_with_hnet:
                        weights_c = c_net_hnet.forward(inf_task_id)
                        Y_hat_logits = c_net.forward(X_real, weights_c)
                    else:
                        Y_hat_logits = c_net.forward(X_real)

                #################
                # CLASSIFICATION
                #################
                if class_nets is not None:
                    # Save predictions of current batch.
                    Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
                    Y_hat = F.softmax(Y_hat_logits, dim=1)
                    if config.cl_scenario == 3 and config.infer_output_head:
                        # This is the special case where the output head is
                        # inferred. Here we compute the argmax of the single
                        # head and add the number of previous neurons such
                        # that it coincides with the argmax of a one-hot
                        # target that is built for all heads. Example: we
                        # detect that task 3 is present, and every task
                        # consists of two classes. The argmax of Y_hat will
                        # either give us 0 or 1, since Y_hat_logits was
                        # already cut to two dimensions. Now we have to add
                        # 3*2 to the argmax of Y_hat to get a prediction
                        # between class 0 and num_tasks*class_per_task.

                        Y_hat = Y_hat.argmax(dim=1, keepdim=False) + \
                                inf_task_id * od
                    Y_hat_all.append(Y_hat)
                    T_all.append(T_real)

            if class_nets is not None:
                # Append predictions.
                Y_hat_all = torch.cat(Y_hat_all)
                T_all = torch.cat(T_all)
                # Check if all test samples are used.
                assert (Y_hat_all.shape[0] == dhandler.num_test_samples)

                # Compute class accuracies.
                if config.cl_scenario == 3 and config.infer_output_head:
                    # This is a special case: predictions are already class
                    # indices, so we compare them with the target indices.
                    targets = T_all.argmax(dim=1, keepdim=False)
                    classifier_accuracy = (Y_hat_all == targets).float().mean()
                else:
                    classifier_accuracy = Classifier.accuracy(Y_hat_all, T_all)

                classifier_accuracy *= 100.
                print("Accuracy of task: ", t, " % ", classifier_accuracy)
                overall_acc_list.append(classifier_accuracy)
                overall_acc += classifier_accuracy

            # Compute task inference accuracies.
            ti_accuracy = task_infer_accuracy / dhandler.num_test_samples * 100.
            if config.training_task_infer or config.infer_with_entropy:
                print("Accuracy of task inference: ", t, " % ", ti_accuracy)
            overall_task_infer_accuracy += ti_accuracy
            overall_task_infer_accuracy_list.append(ti_accuracy)

        # Testing all tasks.
        if task_id is None:
            if class_nets is not None:
                print("Overall mean acc: ", overall_acc / config.num_tasks)
            if config.training_task_infer or config.infer_with_entropy:
                print("Overall task inf acc: ", overall_task_infer_accuracy /
                      config.num_tasks)
            config.overall_acc_list = overall_acc_list
            config.acc_mean = overall_acc / config.num_tasks
            config.overall_task_infer_accuracy_list = \
                overall_task_infer_accuracy_list
            config.acc_task_infer_mean = \
                overall_task_infer_accuracy / config.num_tasks
            print(config.overall_task_infer_accuracy_list, config.acc_task_infer_mean)
    return classifier_accuracy
Exemplo n.º 10
0
def _fake_data_chunk_mean(fct, chunks):
    """Average ``fct(Y_hat, targets)`` over ``chunks``, weighted by chunk size.

    Args:
        fct: Callable taking a prediction tensor and a target tensor (batch
            along dim 0) and returning a scalar.
        chunks: Non-empty list of ``(Y_hat, targets)`` pairs.

    Returns:
        ``fct`` applied to the single chunk if there is exactly one, otherwise
        the mean of ``fct`` over all chunks weighted by their number of
        samples.
    """
    if len(chunks) == 1:
        return fct(*chunks[0])
    num_samples = sum(y.shape[0] for y, _ in chunks)
    return sum(fct(y, tgt) * y.shape[0] for y, tgt in chunks) / num_samples


def get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device, config, writer,
                       t, i, net_copy):
    """Sample replay data for tasks up to ``t`` and compute a loss w.r.t. the
    predictions of a checkpointed network.

    We must take caution when considering the different learning scenarios,
    methods and training stages, see the comments in the code.

    In general, we build a batch of replayed data from all previous tasks (or
    from a single randomly chosen embedding if
    ``config.fake_data_full_range`` is not set). Since we do not know the
    labels of the replayed data, we consider the output of the checkpointed
    network as ground truth, i.e., we compute a loss between two sets of
    logits. See :class:`mnets.classifier_interface.Classifier` for a detailed
    description of the different loss functions. If ``config.upper_bound`` is
    set, real training data from ``dhandlers_rp`` replaces generated samples.

    Args:
        (....): See docstring of function :func:`train_tasks`. Note,
            ``writer`` is currently not used.
        t: Task id. Must select at least one embedding to replay from (i.e.,
            ``t >= 1``).
        i: Current training iteration (used to throttle logging).
        net_copy: Copy/checkpoint of the classifier network before
            learning task ``t``.

    Returns:
        The loss between the predictions of ``net`` and the predictions of
        the checkpointed network on the replayed data.

    Raises:
        NotImplementedError: If ``config.cl_scenario`` is not 1, since the
            output head of a replayed task is only known for that scenario.
        ValueError: If there is no embedding to replay from.
    """
    # Only per-task output heads are handled below. Previously any other
    # scenario crashed later on with an unbound ``task_out``.
    if config.cl_scenario != 1:
        raise NotImplementedError(
            'Replay loss is only implemented for CL scenario 1 (got {}).'.format(
                config.cl_scenario))

    # we have to choose from which embeddings (multiple?!) to sample from
    if config.class_incremental or config.single_class_replay:
        # every class was trained with a different generator embedding
        emb_num = t * config.out_dim
    else:
        # samples from the whole task come from one generator embedding
        emb_num = t
    if emb_num < 1:
        raise ValueError('Nothing to replay for task {}.'.format(t))

    if config.fake_data_full_range:
        # spread the batch over all previous embeddings
        emb_range = range(emb_num)
        bs_per_task = int(np.ceil(config.batch_size / emb_num))
    else:
        # replay a full batch from a single, randomly chosen embedding
        random_t = np.random.randint(0, emb_num)
        emb_range = range(random_t, random_t + 1)
        bs_per_task = config.batch_size

    all_Y_hat_ls = []
    all_targets = []

    for emb_idx in emb_range:
        # exchange replay data with real data to compute upper bounds
        if config.upper_bound:
            real_batch = dhandlers_rp[emb_idx].next_train_batch(bs_per_task)
            X_fake = dhandlers_rp[emb_idx].input_to_torch_tensor(
                real_batch[0], device, mode='train')
        elif config.replay_method == 'gan':
            X_fake = sample_gan(dec, d_hnet, config, emb_idx, device,
                                bs=bs_per_task)
        else:
            X_fake = sample_vae(dec, d_hnet, config, emb_idx, device,
                                bs=bs_per_task)

        # compute soft targets with the copied network
        target_logits = net_copy.forward(X_fake).detach()
        Y_hat_ls = net.forward(X_fake.detach())

        # Take the task-specific output neurons.
        # NOTE(review): ``emb_idx`` indexes generator embeddings; with
        # class_incremental/single_class_replay it is a class index rather
        # than a task id, so slicing ``config.dims`` by it looks suspicious —
        # confirm.
        task_out = [sum(config.dims[:emb_idx]), sum(config.dims[:emb_idx + 1])]
        Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]
        target_logits = target_logits[:, task_out[0]:task_out[1]]

        if config.hard_targets:
            # one-hot of the teacher's predicted class (argmax of the logits
            # equals argmax of their sigmoid, but cannot tie on saturation)
            targets = torch.zeros_like(target_logits).scatter_(
                1, target_logits.argmax(dim=1, keepdim=True), 1)
        else:
            # loss expects logits
            targets = target_logits

        all_targets.append(targets)
        all_Y_hat_ls.append(Y_hat_ls)

    # The per-embedding chunks can only be concatenated if all heads have the
    # same width (``config.dims`` may differ between tasks). Otherwise they
    # are evaluated separately and averaged by sample count. Previously the
    # raw Python lists were handed to the tensor-based Classifier helpers.
    if len({y.shape[1] for y in all_Y_hat_ls}) == 1:
        chunks = [(torch.cat(all_Y_hat_ls), torch.cat(all_targets))]
    else:
        chunks = list(zip(all_Y_hat_ls, all_targets))

    if i % 200 == 0:
        classifier_accuracy = float(
            _fake_data_chunk_mean(Classifier.accuracy, chunks)) * 100.0
        msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
            '(on current FAKE DATA training batch).'
        print(msg.format(i, classifier_accuracy))

    # dependent on the target softness, the loss function is chosen
    if config.hard_targets or (config.class_incremental and t == 1):
        loss_fct = Classifier.logit_cross_entropy_loss
    else:
        loss_fct = Classifier.knowledge_distillation_loss
    return _fake_data_chunk_mean(loss_fct, chunks)
Exemplo n.º 11
0
def test(dhandlers,
         class_nets,
         infer_net,
         device,
         config,
         writer,
         task_id=None):
    """Test continual learning experiments on the MNIST dataset.

    This can either be splitMNIST or permutedMNIST. Depending on the method
    and CL scenario used, this method measures the test accuracy of a given
    task or of all tasks after training. In order to do so, correct targets
    need to be constructed and output heads need to be selected. Furthermore,
    this method distinguishes between the classification accuracy on a task
    and the accuracy of inferring task ids (if applicable).

    If ``task_id`` is None, the per-task and mean accuracies are additionally
    stored in ``config`` (``overall_acc_list``, ``acc_mean``,
    ``overall_task_infer_accuracy_list``, ``acc_task_infer_mean``).

    Args:
        (....): See docstring of function :func:`train_tasks`. Note,
            ``writer`` is currently not used.
        task_id: (optional) If not None, the method will compute and return
            the test accuracy for the given task id (using ``dhandlers[0]``),
            not for all tasks.

    Returns:
        Scalar representing the test accuracy (in percent) for the given task
        id. If ``task_id`` is None, the accuracy of the last task of the CL
        experiment is returned. If ``class_nets`` is None, 0 is returned.

    Raises:
        NotImplementedError: If ``class_nets`` is given and
            ``config.cl_scenario`` is not 1, since output heads are only
            selected for that scenario.
    """
    # Only per-task output heads are supported. Previously other scenarios
    # crashed with an unbound ``task_out``.
    if class_nets is not None and config.cl_scenario != 1:
        raise NotImplementedError(
            'Testing is only implemented for CL scenario 1 (got {}).'.format(
                config.cl_scenario))

    # get the classifier (and its hypernetwork, if used) in evaluation mode
    c_net = None
    c_net_hnet = None
    if class_nets is not None:
        if config.training_with_hnet:
            c_net = class_nets[0]
            c_net_hnet = class_nets[1]
            c_net_hnet.eval()
        else:
            c_net = class_nets
        # The classifier must be in eval mode in both cases (this used to be
        # skipped when no hypernetwork was used).
        c_net.eval()

    if infer_net is not None:
        infer_net.eval()

    # defined up front so an empty task range cannot leave it unbound
    classifier_accuracy = 0

    with torch.no_grad():

        overall_acc = 0
        overall_acc_list = []
        overall_task_infer_accuracy = 0
        overall_task_infer_accuracy_list = []

        # choose tasks to test
        if task_id is not None:
            task_range = range(task_id, task_id + 1)
        else:
            task_range = range(config.num_tasks)

        # iterate through all old tasks
        for t in task_range:
            print("Testing task: ", t)
            # a single data handler is passed when testing one given task
            dhandler = dhandlers[0] if task_id is not None else dhandlers[t]
            test_size = dhandler.num_test_samples

            # output neurons belonging to task ``t`` (constant for all batches)
            task_out = [sum(config.dims[:t]), sum(config.dims[:t + 1])]

            # If the task id has to be inferred, this has to be done for
            # every x, i.e., one h(e) = W per data point - this is only
            # possible with batch size one, for now.
            if (config.infer_task_id and infer_net is not None) or \
                    config.infer_with_entropy:
                curr_bs = 1
            else:
                curr_bs = config.test_batch_size

            N_processed = 0
            classifier_accuracy = 0
            task_infer_accuracy = 0
            Y_hat_all = []
            T_all = []

            # go through test set; the last batch may be smaller
            while N_processed < test_size:
                curr_bs = min(curr_bs, test_size - N_processed)
                N_processed += curr_bs

                # get data
                real_batch = dhandler.next_test_batch(curr_bs)
                X_real = dhandler.input_to_torch_tensor(real_batch[0],
                                                        device,
                                                        mode='inference')
                T_real = dhandler.output_to_torch_tensor(real_batch[1],
                                                         device,
                                                         mode='inference')

                #################
                # TASK PREDICTION
                #################
                # The task id is known, so task inference is correct for every
                # sample of the batch (previously only 1 was counted per batch
                # while dividing by the number of samples).
                task_infer_accuracy += curr_bs
                inf_task_id = t

                if class_nets is not None:
                    # produce the model for the given task id
                    if config.training_with_hnet:
                        weights_c = c_net_hnet.forward(inf_task_id)
                        Y_hat_logits = c_net.forward(X_real, weights_c)
                    else:
                        Y_hat_logits = c_net.forward(X_real)

                    #################
                    # CLASSIFICATION
                    #################
                    # keep only the head of the current task
                    Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
                    Y_hat_all.append(F.softmax(Y_hat_logits, dim=1))
                    T_all.append(T_real)

            if class_nets is not None:
                Y_hat_all = torch.cat(Y_hat_all)
                T_all = torch.cat(T_all)
                # check if all test samples are used
                assert Y_hat_all.shape[0] == dhandler.num_test_samples

                classifier_accuracy = Classifier.accuracy(Y_hat_all, T_all)
                classifier_accuracy *= 100.
                print("Accuracy of task: ", t, " % ", classifier_accuracy)
                overall_acc_list.append(classifier_accuracy)
                overall_acc += classifier_accuracy

            # compute task inference accuracies
            ti_accuracy = task_infer_accuracy / dhandler.num_test_samples * 100.
            if config.training_task_infer or config.infer_with_entropy:
                print("Accuracy of task inference: ", t, " % ", ti_accuracy)
            overall_task_infer_accuracy += ti_accuracy
            overall_task_infer_accuracy_list.append(ti_accuracy)

        # testing all tasks: report and store summary statistics
        if task_id is None:
            if class_nets is not None:
                print("Overall mean acc: ", overall_acc / config.num_tasks)
            if config.training_task_infer or config.infer_with_entropy:
                print("Overall task inf acc: ",
                      overall_task_infer_accuracy / config.num_tasks)
            config.overall_acc_list = overall_acc_list
            config.acc_mean = overall_acc / config.num_tasks
            config.overall_task_infer_accuracy_list = \
                overall_task_infer_accuracy_list
            config.acc_task_infer_mean = \
                overall_task_infer_accuracy / config.num_tasks
            print(config.overall_task_infer_accuracy_list,
                  config.acc_task_infer_mean)
    return classifier_accuracy