Example no. 1
def test(enc,
         dec,
         d_hnet,
         device,
         config,
         writer,
         train_iter=None,
         condition=None):
    """ Test the MNIST VAE - here we only sample from a fixed noise to compare
    images qualitatively. One should also keep track of the reconstruction 
    error of e.g. a test set.

    Args:
        (....): See docstring of function :func:`train`.
        train_iter: The current training iteration.
        condition: Condition (class/task) we are currently training.
    """
    if train_iter is None:
        print('### Final test run ...')
        train_iter = config.n_iter
    else:
        print('# Testing network before running training step %d ...' % \
              train_iter)
    # if no condition is given, we iterate over all (trained) embeddings
    if condition is None:
        condition = config.num_embeddings - 1
    # eval all nets
    enc.eval()
    dec.eval()
    if d_hnet is not None:
        d_hnet.eval()

    with torch.no_grad():
        # iterate over all conditions
        for m in range(condition + 1):
            # Get the noise that was saved before training
            z = config.test_z[m]
            reconstructions = sample(dec, d_hnet, config, m, device, z=z)
            if config.show_plots:
                fig_real = _plotImages(reconstructions, config)
                writer.add_figure('test_cond_' + str(m) + '_sampled_after_' +
                                  str(condition),
                                  fig_real,
                                  global_step=train_iter)
                if train_iter == config.n_iter:
                    writer.add_figure('test_cond_final_' + str(m) +
                                      '_sampled_after_' + str(condition),
                                      fig_real,
                                      global_step=train_iter)
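
The test routine above samples from a fixed noise tensor per condition, stored in config.test_z, so that images drawn at different training iterations remain directly comparable. A minimal sketch of how such fixed test noise could be prepared before training (the helper name and the latent_dim and n_samples arguments are hypothetical, not part of this excerpt):

import torch

def prepare_test_noise(config, device, num_conditions, latent_dim, n_samples=8):
    # Hypothetical helper: draw one fixed noise batch per condition once and
    # store it in config.test_z, so every later call of test() reuses it.
    config.test_z = [torch.randn(n_samples, latent_dim).to(device)
                     for _ in range(num_conditions)]
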
Example no. 2
def train_class_one_t(dhandler_class, dhandlers_rp, dec, d_hnet, net,
                      device, config, writer, t):
    """Train continual learning experiments on MNIST dataset for one task.
    In this function the main training logic is implemented. 
    After setting the optimizers for the network and hypernetwork if 
    applicable, the training is structured as follows: 
    First, we get the a training batch of the current task. Depending on 
    the learning scenario, we choose output heads and build targets 
    accordingly. 
    Second, if ``t`` is greater than 1, we add a loss term concerning 
    predictions of replayed data. See :func:`get_fake_data_loss` for 
    details. Third, to protect the hypernetwork from forgetting, we add an 
    additional L2 loss term namely the difference between its current output 
    given an embedding and checkpointed targets.
    Finally, we track some training statistics.

    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
    """

    # if CL with task inference, the classifier is empowered with a hnet
    if config.training_with_hnet:
        net_hnet = net[1]
        net = net[0]
        net.train()
        net_hnet.train()
        params_to_regularize = list(net_hnet.theta)
        optimizer = optim.Adam(params_to_regularize,
                               lr=config.class_lr, betas=(0.9, 0.999))

        c_emb_optimizer = optim.Adam([net_hnet.get_task_emb(t)],
                                     lr=config.class_lr_emb, betas=(0.9, 0.999))
    else:
        net.train()
        net_hnet = None
        optimizer = optim.Adam(net.parameters(),
                               lr=config.class_lr, betas=(0.9, 0.999))

    # don't train the replay model, if available
    if dec is not None:
        dec.eval()
    if d_hnet is not None:
        d_hnet.eval()

    # compute targets if classifier is trained with hnet
    if t > 0 and config.training_with_hnet:
        if config.online_target_computation:
            # Compute targets for the regularizer whenever they are needed.
            # -> Computationally expensive.
            targets_C = None
            prev_theta = [p.detach().clone() for p in net_hnet.theta]
            prev_task_embs = [p.detach().clone() for p in \
                              net_hnet.get_task_embs()]
        else:
            # Compute targets for the regularizer once and keep them all in
            # memory -> Memory expensive.
            targets_C = hreg.get_current_targets(t, net_hnet)
            prev_theta = None
            prev_task_embs = None

    dhandler_class.reset_batch_generator()

    # make copy of network
    if t >= 1:
        net_copy = copy.deepcopy(net)

    # set training_iterations if epochs are set
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert (config.epochs > 0)
        training_iterations = config.epochs * \
                              int(np.ceil(dhandler_class.num_train_samples / config.batch_size))

    if config.class_incremental:
        training_iterations = int(training_iterations / config.out_dim)

    # Whether we will calculate the regularizer.
    calc_reg = t > 0 and config.class_beta > 0 and config.training_with_hnet

    # set whether the regularizer is only computed on a subset of the previous tasks
    if config.hnet_reg_batch_size != -1:
        hnet_reg_batch_size = config.hnet_reg_batch_size
    else:
        hnet_reg_batch_size = None

    for i in range(training_iterations):

        # reset the gradients
        optimizer.zero_grad()
        if net_hnet is not None:
            c_emb_optimizer.zero_grad()

        # Get real data
        real_batch = dhandler_class.next_train_batch(config.batch_size)
        X_real = dhandler_class.input_to_torch_tensor(real_batch[0], device,
                                                      mode='train')
        T_real = dhandler_class.output_to_torch_tensor(real_batch[1], device,
                                                       mode='train')

        if i % 100 == 0 and config.show_plots:
            fig_real = _plotImages(X_real, config)
            writer.add_figure('train_class_' + str(t) + '_real',
                              fig_real, global_step=i)

        #################################################
        # Choosing output heads and constructing targets
        ################################################# 

        # If we train a task inference net or do class-incremental learning,
        # we construct a target for every single class/task
        if config.class_incremental or config.training_task_infer:
            # consider one output neuron per class/task seen so far (neurons 0 .. t)
            task_out = [0, t + 1]
            T_real = torch.zeros((config.batch_size, task_out[1])).to(device)
            T_real[:, task_out[1] - 1] = 1

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # take the neurons of the task-specific output head
                task_out = [t * config.out_dim, t * config.out_dim + config.out_dim]
            else:
                # always all output neurons, only one head is used
                task_out = [0, config.out_dim]
        else:
            # The number of output neurons is generic and can grow i.e. we
            # do not have to know the number of tasks before we start 
            # learning.
            if not config.infer_output_head:
                task_out = [0, (t + 1) * config.out_dim]
                T_real = torch.cat((torch.zeros((config.batch_size,
                                                 t * config.out_dim)).to(device),
                                    T_real), dim=1)
            # this is a special case where we will infer the task id with
            # another neural network, so we can train on the correct output
            # head directly and use the inferred output head to compute the
            # prediction
            else:
                task_out = [t * config.out_dim, t * config.out_dim + config.out_dim]

        # compute loss of current data
        if config.training_with_hnet:
            weights_c = net_hnet.forward(t)
        else:
            weights_c = None

        Y_hat_logits = net.forward(X_real, weights_c)
        Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]

        if config.soft_targets:
            soft_label = 0.95
            num_classes = T_real.shape[1]
            soft_targets = torch.where(T_real == 1,
                                       torch.Tensor([soft_label]).to(device),
                                       torch.Tensor([(1 - soft_label) / (num_classes - 1)]).to(device))
            soft_targets = soft_targets.to(device)
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits,
                                                             soft_targets)
        else:
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits, T_real)

        ############################
        # compute loss for fake data
        ############################

        # Get fake data (of all tasks up until now and merge into list)
        if t >= 1 and not config.training_with_hnet:
            fake_loss = get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device,
                                           config, writer, t, i, net_copy)
            loss_task = (1 - config.l_rew) * loss_task + config.l_rew * fake_loss

        loss_task.backward(retain_graph=calc_reg, create_graph=calc_reg and \
                                                               config.backprop_dt)

        # compute the hypernetwork regularization loss (targets are fixed;
        # only theta and the current embedding change)
        if calc_reg:
            if config.no_lookahead:
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(optimizer,
                                                 config.use_sgd_change, lr=config.class_lr,
                                                 detach_dt=not config.backprop_dt)
            loss_reg = config.class_beta * hreg.calc_fix_target_reg(net_hnet, t,
                                                                    targets=targets_C, mnet=net, dTheta=dTheta,
                                                                    dTembs=None,
                                                                    prev_theta=prev_theta,
                                                                    prev_task_embs=prev_task_embs,
                                                                    batch_size=hnet_reg_batch_size)
            loss_reg.backward()

        # optimizer step for the main model (the backward pass was already computed above)
        if not config.dont_train_main_model:
            optimizer.step()

        if net_hnet is not None and config.train_class_embeddings:
            c_emb_optimizer.step()

        # save some training statistics
        if i % 50 == 0:
            # compute accuracies for tracking
            Y_hat_logits = net.forward(X_real, weights_c)
            Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
            Y_hat = F.softmax(Y_hat_logits, dim=1)
            classifier_accuracy = Classifier.accuracy(Y_hat, T_real) * 100.0
            writer.add_scalar('train/task_%d/class_accuracy' % t,
                              classifier_accuracy, i)
            writer.add_scalar('train/task_%d/loss_task' % t,
                              loss_task, i)
            if t >= 1 and not config.training_with_hnet:
                writer.add_scalar('train/task_%d/fake_loss' % t,
                                  fake_loss, i)

        # plot some gradient statistics
        if i % 200 == 0:
            if not config.dont_train_main_model:
                total_norm = 0
                if config.training_with_hnet:
                    params = net_hnet.theta
                else:
                    params = net.parameters()

                for p in params:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                # TODO write gradient histograms?
                writer.add_scalar('train/task_%d/main_params_grad_norms' % t,
                                  total_norm, i)

            if net_hnet is not None and config.train_class_embeddings:
                total_norm = 0
                for p in [net_hnet.get_task_emb(t)]:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                writer.add_scalar('train/task_%d/hnet_emb_grad_norms' % t,
                                  total_norm, i)

        if i % 200 == 0:
            msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                  '(on current training batch).'
            print(msg.format(i, classifier_accuracy))
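
The call to hreg.calc_fix_target_reg above refers to library code that is not part of this excerpt. As a rough illustration of the quantity such a fixed-target regularizer penalizes, here is a minimal sketch (a hypothetical function, assuming that hnet.forward(task_id) returns a list of generated weight tensors and that targets[j] holds the checkpointed outputs for task j, as suggested by how these objects are used in the training loop):

def fix_target_reg_sketch(hnet, current_task, targets):
    # Hedged sketch, not the library implementation: penalize the squared
    # difference between the hypernetwork's current outputs for previous
    # task embeddings and the outputs checkpointed before learning the
    # current task.
    reg = 0.
    for j in range(current_task):
        current_out = hnet.forward(j)
        reg = reg + sum((w - w_old.detach()).pow(2).sum()
                        for w, w_old in zip(current_out, targets[j]))
    return reg / max(current_task, 1)
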
Example no. 3
def get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device, config, writer,
                       t, i, net_copy):
    """ Sample fake data from generator for tasks up to t and compute a loss
    compared to predictions of a checkpointed network.
    
    We must take caution when considering the different learning scenarios
    and methods and training stages, see detailed comments in the code.
    
    In general, we build a batch of replayed data from all previous tasks.
    Since we do not know the labels of the replayed data, we consider the
    output of the checkpointed network as ground thruth i.e. we must compute
    a loss between two logits.See :class:`mnets.classifier_interface.Classifier`
    for a detailed describtion of the different loss functions.
        
    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
        i: Current training iteration.
        net_copy: Copy/checkpoint of the classifier network before 
            learning task ``t``.
    Returns:
        The loss between predictions and predictions of a 
        checkpointed network or replayed data.
    
    """

    all_Y_hat_ls = []
    all_targets = []

    # we have to choose which embeddings (possibly several) to sample from
    if config.class_incremental or config.single_class_replay:
        # if we trained every class with a different generator
        emb_num = t * config.out_dim
    else:
        # here, samples for the whole task come from a single generator
        emb_num = t

    if config.fake_data_full_range:
        ran = range(0, emb_num)
        bs_per_task = int(np.ceil(config.batch_size / emb_num))
    else:
        random_t = np.random.randint(0, emb_num)
        ran = range(random_t, random_t + 1)
        bs_per_task = config.batch_size

    for re in ran:

        # exchange replay data with real data to compute upper bounds 
        if config.upper_bound:
            real_batch = dhandlers_rp[re].next_train_batch(bs_per_task)
            X_fake = dhandlers_rp[re].input_to_torch_tensor(real_batch[0],
                                                            device, mode='train')
        else:
            # get fake data
            if config.replay_method == 'gan':
                X_fake = sample_gan(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)
            else:
                X_fake = sample_vae(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)

        # save some fake data to the writer
        if i % 100 == 0:
            if X_fake.shape[0] >= 15:
                fig_fake = _plotImages(X_fake, config, bs_per_task)
                writer.add_figure('train_class_' + str(re) + '_fake',
                                  fig_fake, global_step=i)

        # compute soft targets with copied network
        target_logits = net_copy.forward(X_fake).detach()
        Y_hat_ls = net.forward(X_fake.detach())

        ###############
        # BUILD TARGETS
        ###############
        od = config.out_dim

        if config.class_incremental or config.training_task_infer:
            # This is a bit complicated: if we train class/task incrementally,
            # we skip training the classifier on the first task.
            # So when starting to train the classifier on task 2, we have to
            # build a hard target for the first output neuron, which is only
            # trained via replay data. A soft target (on an untrained output)
            # would not make sense.
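            # Hypothetical illustration of the special case above (assuming
            # config.out_dim == 1, t == 1, re == 0): the replay target becomes
            # the hard one-hot [1, 0], because neuron 0 was never trained and
            # therefore provides no meaningful soft target.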

            # output head over all output neurons already available
            task_out = [0, (t + 1) * od]
            # create a target that is zero everywhere except for the current re
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(device)

            if config.hard_targets or (t == 1 and re == 0):
                zeros[:, re] = 1
            else:
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]

            targets = zeros
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # take the neurons of the task-specific output head
                task_out = [re * od, re * od + od]
            else:
                # always all output neurons, only one head is used
                task_out = [0, od]

            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]
            target_logits = target_logits[:, task_out[0]:task_out[1]]
            # build hard targets, i.e. one-hots, if this option is chosen
            if config.hard_targets:
                soft_targets = torch.sigmoid(target_logits)
                zeros = torch.zeros(Y_hat_ls.shape).to(device)
                _, argmax = torch.max(soft_targets, 1)
                targets = zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # loss expects logits
                targets = target_logits
        else:
            # take all neurons used up until now

            # output head over all output neurons already available
            task_out = [0, (t + 1) * od]
            # create a target tensor of zeros that will be filled in below
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(device)

            # sigmoid over the output head(s) of all previous tasks
            soft_targets = torch.sigmoid(target_logits[:, 0:t * od])

            # compute one hots
            if config.hard_targets:
                _, argmax = torch.max(soft_targets, 1)
                zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # loss expects logits
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]
            targets = zeros
            # restrict the predictions to the output neurons available so far
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        # add to list
        all_targets.append(targets)
        all_Y_hat_ls.append(Y_hat_ls)

    # cat to one tensor
    all_targets = torch.cat(all_targets)
    Y_hat_ls = torch.cat(all_Y_hat_ls)

    if i % 200 == 0:
        classifier_accuracy = Classifier.accuracy(Y_hat_ls, all_targets) * 100.0
        msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
              '(on current FAKE DATA training batch).'
        print(msg.format(i, classifier_accuracy))

    # the loss function is chosen depending on whether the targets are hard or soft
    if config.hard_targets or (config.class_incremental and t == 1):
        return Classifier.logit_cross_entropy_loss(Y_hat_ls, all_targets)
    else:
        return Classifier.knowledge_distillation_loss(Y_hat_ls, all_targets)
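
The loss functions chosen at the end, Classifier.logit_cross_entropy_loss and Classifier.knowledge_distillation_loss, live in mnets.classifier_interface.Classifier and are not shown in this excerpt. As an illustration of the standard knowledge distillation formulation between two sets of logits (a sketch under that assumption, not the exact library code):

import torch.nn.functional as F

def distillation_loss_sketch(student_logits, teacher_logits, T=2.0):
    # Sketch of a standard distillation loss: soften the teacher logits with
    # temperature T and use the result as soft targets for the student.
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    return -(teacher_probs * student_log_probs).sum(dim=1).mean() * (T ** 2)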