Example #1
    def distill_loss_fct(config, X, Y_logits, T_soft_logits, data):
        assert np.all(np.equal(X.shape[:2], T_soft_logits.shape[:2]))
        # Note, targets and predictions might have different head sizes if a
        # growing softmax is used.
        assert np.all(np.equal(Y_logits.shape[:2], T_soft_logits.shape[:2]))

        # Distillation temperature.
        T = 2.

        n_digs = config.ssmnist_seq_len

        # Note, SMNIST samples have the end-of-sequence bit set at the last
        # timestep; the rest is padding. Since there are `n_digs` digits per
        # sample, we take the last end-of-digit timestep to determine the
        # unpadded sequence length.
        eod_features = X[:, :, 3].cpu().numpy()
        seq_lengths = np.argsort(eod_features, axis=0)[-n_digs:].max(axis=0)
        inds = seq_lengths - 1
        inds[inds < 0] = 0

        # Only compute loss for last timestep.
        Y_logits = Y_logits[inds, np.arange(inds.size), :]
        T_soft_logits = T_soft_logits[inds, np.arange(inds.size), :]

        target_mapping = None
        if config.all_task_softmax:
            target_mapping = list(range(T_soft_logits.shape[1]))

        return Classifier.knowledge_distillation_loss(
            Y_logits,
            T_soft_logits,
            target_mapping=target_mapping,
            device=Y_logits.device,
            T=T)
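
The examples on this page delegate the actual loss computation to `Classifier.knowledge_distillation_loss`, which is not shown here. As a point of reference, below is a minimal sketch of a standard Hinton-style tempered-softmax distillation loss. It is an assumption about what such a call computes, not the library's actual implementation (which additionally accepts `target_mapping` and `device` arguments, as seen above).

    import torch
    import torch.nn.functional as F

    def kd_loss_sketch(student_logits, teacher_logits, T=2.):
        """Tempered-softmax cross-entropy between teacher and student logits."""
        soft_targets = F.softmax(teacher_logits / T, dim=1)
        log_probs = F.log_softmax(student_logits / T, dim=1)
        # The T**2 factor keeps gradient magnitudes comparable across temperatures.
        return -(soft_targets * log_probs).sum(dim=1).mean() * T**2

    # Toy usage with random logits of shape [batch_size, num_classes].
    loss = kd_loss_sketch(torch.randn(8, 10), torch.randn(8, 10), T=2.)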
Example #2
    def distill_loss_fct(config,
                         X,
                         Y_logits,
                         T_soft_logits,
                         data,
                         in_seq_lens=None):
        if in_seq_lens is None:
            raise NotImplementedError(
                'This distillation loss is currently ' +
                'only implemented if sequence lengths are provided, as they ' +
                'can\'t be inferred easily.')
        # Note, input and output sequence lengths are identical for the PoS
        # dataset.

        assert np.all(np.equal(X.shape[:2], T_soft_logits.shape[:2]))
        # Note, targets and predictions might have different head sizes if a
        # growing softmax is used.
        assert np.all(np.equal(Y_logits.shape[:2], T_soft_logits.shape[:2]))

        # Distillation temperature.
        T = 2.

        target_mapping = None
        if config.all_task_softmax:
            target_mapping = list(range(T_soft_logits.shape[2]))

        dloss = 0
        total_num_ts = 0

        for bid in range(X.shape[1]):
            sl = int(in_seq_lens[bid])
            total_num_ts += sl

            Y_logits_i = Y_logits[:sl, bid, :]
            T_soft_logits_i = T_soft_logits[:sl, bid, :]

            dloss += Classifier.knowledge_distillation_loss(
                Y_logits_i,
                T_soft_logits_i,
                target_mapping=target_mapping,
                device=Y_logits.device,
                T=T) * sl

        return dloss / total_num_ts
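
The loop above weights each sample's distillation loss by its unpadded sequence length and divides by the total number of timesteps. Assuming the underlying `knowledge_distillation_loss` call returns a mean over the timesteps it is given, this is the same as averaging over all unpadded timesteps at once, as the small self-contained check below (with made-up numbers) illustrates.

    import numpy as np

    # Per-sample vectors of per-timestep losses for sequences of length 10, 7, 3, 9
    # (purely illustrative values).
    per_timestep_losses = [np.random.rand(sl) for sl in (10, 7, 3, 9)]

    weighted = sum(l.mean() * len(l) for l in per_timestep_losses)
    total_ts = sum(len(l) for l in per_timestep_losses)
    flat_mean = np.concatenate(per_timestep_losses).mean()

    # The two reductions agree, i.e. every unpadded timestep is weighted equally.
    assert np.isclose(weighted / total_ts, flat_mean)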
Example #3
    def distill_loss_fct(config, X, Y_logits, T_soft_logits, data):
        # Note, targets and predictions might have different head sizes if a
        # growing softmax is used.
        assert np.all(np.equal(Y_logits.shape[:2], T_soft_logits.shape[:2]))

        # Only compute loss for last timestep.
        Y_logits = Y_logits[-1, :, :]
        T_soft_logits = T_soft_logits[-1, :, :]

        target_mapping = None
        if config.all_task_softmax:
            target_mapping = list(range(T_soft_logits.shape[1]))

        return Classifier.knowledge_distillation_loss(
            Y_logits,
            T_soft_logits,
            target_mapping=target_mapping,
            device=Y_logits.device,
            T=2.)
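
This variant takes the final timestep of every sequence directly, whereas Example #1 picks each sample's last unpadded timestep via NumPy fancy indexing. For reference, here is a hedged sketch of an equivalent per-sample selection done entirely in PyTorch with `torch.gather`; shapes and names are illustrative only.

    import torch

    max_len, batch, n_cls = 8, 4, 10
    Y_logits = torch.randn(max_len, batch, n_cls)   # time-major logits
    inds = torch.tensor([7, 3, 5, 0])               # last valid timestep per sample

    # Expand indices to [1, batch, n_cls] and gather along the time dimension.
    idx = inds.view(1, batch, 1).expand(1, batch, n_cls)
    last = Y_logits.gather(0, idx).squeeze(0)       # shape [batch, n_cls]

    assert torch.equal(last, Y_logits[inds, torch.arange(batch), :])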
Example #4
    def distill_loss_fct(config, X, Y_logits, T_soft_logits, data):
        assert np.all(np.equal(X.shape[:2], T_soft_logits.shape[:2]))
        # Note, targets and predictions might have different head sizes if a
        # growing softmax is used.
        assert np.all(np.equal(Y_logits.shape[:2], T_soft_logits.shape[:2]))

        # Distillation temperature.
        T = 2.

        if config.distill_across_time:
            # End-of-digit feature.
            # Note, a softmax would be an easy way to obtain per-timestep
            # probabilities that a digit has ended. But since this feature
            # vector `X[:, :, 3]` is ideally a 1-hot encoding, it shouldn't
            # be squashed via a softmax, as that would blur out the
            # probabilities.
            #ts_weights = F.softmax(X[:, :, 3], dim=0).detach()
            ts_weights = X[:, :, 3].clone()
            ts_weights[ts_weights < 0] = 0
            # Avoid division by zero in case all elements of `X[:, :, 3]` are
            # negative.
            ts_weights /= ts_weights.sum() + 1e-5
            ts_weights = ts_weights.detach()

            # For distillation, we use a tempered softmax.
            T_soft = F.softmax(T_soft_logits / T, dim=2)
            if config.all_task_softmax and Y_logits.shape[2] != T_soft.shape[2]:
                # Pad new classes to soft targets.
                T_soft = F.pad(T_soft, (0, data.num_classes),
                               mode='constant',
                               value=0)
                assert Y_logits.shape[2] == T_soft.shape[2]

            # Distillation loss.
            loss = -(T_soft * F.log_softmax(Y_logits / T, dim=2)).sum(dim=2) * \
                T**2
            loss *= ts_weights

            # Sum across time (note, weights sum to 1) and mean across batch
            # dimension.
            return loss.sum(dim=0).mean()
        else:
            # Note, SMNIST samples have the end-of-sequence bit set at the
            # last timestep; the rest is padding.
            seq_lengths = X[:, :, 3].argmax(dim=0)
            inds = seq_lengths.cpu().numpy() - 1
            inds[inds < 0] = 0

            # Only compute loss for last timestep.
            Y_logits = Y_logits[inds, np.arange(inds.size), :]
            T_soft_logits = T_soft_logits[inds, np.arange(inds.size), :]

            target_mapping = None
            if config.all_task_softmax:
                target_mapping = list(range(T_soft_logits.shape[1]))

            return Classifier.knowledge_distillation_loss(
                Y_logits,
                T_soft_logits,
                target_mapping=target_mapping,
                device=Y_logits.device,
                T=T)
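
In the `distill_across_time` branch, the teacher's soft targets are zero-padded when the student uses a grown softmax. The small standalone illustration below (with made-up head sizes) shows why zero probability mass on the padded classes leaves their cross-entropy contribution at exactly zero, so the student is only matched on the old classes.

    import torch
    import torch.nn.functional as F

    T = 2.
    old_cls, new_cls = 5, 5
    Y_logits = torch.randn(4, old_cls + new_cls)             # student, grown head
    T_soft = F.softmax(torch.randn(4, old_cls) / T, dim=1)   # teacher, old head only

    # Pad the soft targets with zeros for the new classes (mirrors the F.pad
    # call above, where the pad width is `data.num_classes`).
    T_soft_padded = F.pad(T_soft, (0, new_cls), mode='constant', value=0)

    log_q = F.log_softmax(Y_logits / T, dim=1)
    loss_full = -(T_soft_padded * log_q).sum(dim=1)
    loss_old = -(T_soft * log_q[:, :old_cls]).sum(dim=1)
    assert torch.allclose(loss_full, loss_old)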
Example #5
def get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device, config, writer,
                       t, i, net_copy):
    """ Sample fake data from generator for tasks up to t and compute a loss
    compared to predictions of a checkpointed network.
    
    Care must be taken with the different learning scenarios, methods and
    training stages; see the detailed comments in the code.
    
    In general, we build a batch of replayed data from all previous tasks.
    Since we do not know the labels of the replayed data, we consider the
    output of the checkpointed network as ground truth, i.e., we must compute
    a loss between two sets of logits. See
    :class:`mnets.classifier_interface.Classifier` for a detailed description
    of the different loss functions.
        
    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
        i: Current training iteration.
        net_copy: Copy/checkpoint of the classifier network before 
            learning task ``t``.
    Returns:
        The loss between the current network's predictions and the
        checkpointed network's predictions on the replayed data.
    
    """

    all_Y_hat_ls = []
    all_targets = []

    # Choose which embedding(s) to sample from.
    if config.class_incremental or config.single_class_replay:
        # if we trained every class with a different generator
        emb_num = t * config.out_dim
    else:
        # here samples from the whole task come from one generator
        emb_num = t

    if config.fake_data_full_range:
        ran = range(0, emb_num)
        bs_per_task = int(np.ceil(config.batch_size / emb_num))
    else:
        random_t = np.random.randint(0, emb_num)
        ran = range(random_t, random_t + 1)
        bs_per_task = config.batch_size

    for re in ran:

        # exchange replay data with real data to compute upper bounds 
        if config.upper_bound:
            real_batch = dhandlers_rp[re].next_train_batch(bs_per_task)
            X_fake = dhandlers_rp[re].input_to_torch_tensor(real_batch[0],
                                                            device, mode='train')
        else:
            # get fake data
            if config.replay_method == 'gan':
                X_fake = sample_gan(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)
            else:
                X_fake = sample_vae(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)

        # save some fake data to the writer
        if i % 100 == 0:
            if X_fake.shape[0] >= 15:
                fig_fake = _plotImages(X_fake, config, bs_per_task)
                writer.add_figure('train_class_' + str(re) + '_fake',
                                  fig_fake, global_step=i)

        # compute soft targets with copied network
        target_logits = net_copy.forward(X_fake).detach()
        Y_hat_ls = net.forward(X_fake.detach())

        ###############
        # BUILD TARGETS
        ###############
        od = config.out_dim

        if config.class_incremental or config.training_task_infer:
            # This is a bit complicated: If we train class/task incrementally
            # we skip training the classifier on the first task.
            # So when starting to train the classifier on task 2, we have to
            # build a hard target for this first output neuron trained by
            # replay data. A soft target (on an untrained output) would not 
            # make sense.

            # output head over all output neurons already available
            task_out = [0, (t + 1) * od]
            # Create targets that are zero everywhere except for the current `re`.
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(device)

            if config.hard_targets or (t == 1 and re == 0):
                zeros[:, re] = 1
            else:
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]

            targets = zeros
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # take the task-specific output neurons
                task_out = [re * od, re * od + od]
            else:
                # always all output neurons, only one head is used
                task_out = [0, od]

            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]
            target_logits = target_logits[:, task_out[0]:task_out[1]]
            # Build hard targets, i.e. one-hots, if this option is chosen.
            if config.hard_targets:
                soft_targets = torch.sigmoid(target_logits)
                zeros = torch.zeros(Y_hat_ls.shape).to(device)
                _, argmax = torch.max(soft_targets, 1)
                targets = zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # loss expects logits
                targets = target_logits
        else:
            # take all neurons used up until now

            # output head over all output neurons already available
            task_out = [0, (t + 1) * od]
            # Create targets that are zero everywhere except for the current `re`.
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(device)

            # sigmoid over the output head(s) from all previous tasks
            soft_targets = torch.sigmoid(target_logits[:, 0:t * od])

            # compute one hots
            if config.hard_targets:
                _, argmax = torch.max(soft_targets, 1)
                zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # loss expects logits
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]
            targets = zeros
            # choose the correct output size for the actual output head
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        # add to list
        all_targets.append(targets)
        all_Y_hat_ls.append(Y_hat_ls)

    # cat to one tensor
    all_targets = torch.cat(all_targets)
    Y_hat_ls = torch.cat(all_Y_hat_ls)

    if i % 200 == 0:
        classifier_accuracy = Classifier.accuracy(Y_hat_ls, all_targets) * 100.0
        msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
              '(on current FAKE DATA training batch).'
        print(msg.format(i, classifier_accuracy))

    # Depending on the target softness, the appropriate loss function is chosen.
    if config.hard_targets or (config.class_incremental and t == 1):
        return Classifier.logit_cross_entropy_loss(Y_hat_ls, all_targets)
    else:
        return Classifier.knowledge_distillation_loss(Y_hat_ls, all_targets)
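
For reference, the hard-target branches above turn the checkpointed network's logits into one-hot targets via `argmax` and `scatter_`. A minimal standalone illustration (random values, illustrative shapes):

    import torch

    target_logits = torch.randn(4, 3)      # outputs of the checkpointed network
    # Sigmoid is monotonic, so the argmax equals the argmax over the raw logits.
    soft_targets = torch.sigmoid(target_logits)
    zeros = torch.zeros_like(soft_targets)
    _, argmax = torch.max(soft_targets, 1)
    one_hots = zeros.scatter_(1, argmax.view(-1, 1), 1)   # one 1 per row

    assert torch.equal(one_hots.argmax(dim=1), argmax)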
Example #6
def get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device, config, writer,
                       t, i, net_copy):
    """ Sample fake data from generator for tasks up to t and compute a loss
    compared to predictions of a checkpointed network.
    
    Care must be taken with the different learning scenarios, methods and
    training stages; see the detailed comments in the code.
    
    In general, we build a batch of replayed data from all previous tasks.
    Since we do not know the labels of the replayed data, we consider the
    output of the checkpointed network as ground truth, i.e., we must compute
    a loss between two sets of logits. See
    :class:`mnets.classifier_interface.Classifier` for a detailed description
    of the different loss functions.
        
    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
        i: Current training iteration.
        net_copy: Copy/checkpoint of the classifier network before 
            learning task ``t``.
    Returns:
        The loss between the current network's predictions and the
        checkpointed network's predictions on the replayed data.
    
    """

    all_Y_hat_ls = []
    all_targets = []

    # Choose which embedding(s) to sample from.
    if config.class_incremental or config.single_class_replay:
        # if we trained every class with a different generator
        emb_num = t * config.out_dim
    else:
        # here samples from the whole task come from one generator
        emb_num = t

    if config.fake_data_full_range:
        ran = range(0, emb_num)
        bs_per_task = int(np.ceil(config.batch_size / emb_num))
    else:
        random_t = np.random.randint(0, emb_num)
        ran = range(random_t, random_t + 1)
        bs_per_task = config.batch_size

    for re in ran:

        # exchange replay data with real data to compute upper bounds
        if config.upper_bound:
            real_batch = dhandlers_rp[re].next_train_batch(bs_per_task)
            X_fake = dhandlers_rp[re].input_to_torch_tensor(
                real_batch[0], device, mode='train')
        else:
            # get fake data
            if config.replay_method == 'gan':
                X_fake = sample_gan(dec,
                                    d_hnet,
                                    config,
                                    re,
                                    device,
                                    bs=bs_per_task)
            else:
                X_fake = sample_vae(dec,
                                    d_hnet,
                                    config,
                                    re,
                                    device,
                                    bs=bs_per_task)

        # save some fake data to the writer
        # if i % 100 == 0:
        #     if X_fake.shape[0] >= 15:
        #         fig_fake = _plotImages(X_fake, config, bs_per_task)
        #         writer.add_figure('train_class_' + str(re) + '_fake',
        #                                         fig_fake, global_step=i)

        # compute soft targets with copied network
        target_logits = net_copy.forward(X_fake).detach()
        Y_hat_ls = net.forward(X_fake.detach())

        ###############
        # BUILD TARGETS
        ###############
        if config.cl_scenario == 1:
            # Take the task-specific output neurons.
            task_out = [sum(config.dims[:re]), sum(config.dims[:re + 1])]
        else:
            # Other CL scenarios are not handled by this variant; fail loudly
            # instead of hitting an undefined `task_out` below.
            raise NotImplementedError('Only CL scenario 1 is supported here.')

        Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]
        target_logits = target_logits[:, task_out[0]:task_out[1]]

        # Build hard targets, i.e. one-hots, if this option is chosen.
        if config.hard_targets:
            soft_targets = torch.sigmoid(target_logits)
            zeros = torch.zeros(Y_hat_ls.shape).to(device)
            _, argmax = torch.max(soft_targets, 1)
            targets = zeros.scatter_(1, argmax.view(-1, 1), 1)
        else:
            # loss expects logits
            targets = target_logits

        # add to list
        all_targets.append(targets)
        all_Y_hat_ls.append(Y_hat_ls)

    # Concatenate into one tensor (the accuracy and loss functions below
    # expect tensors, not lists).
    all_targets = torch.cat(all_targets)
    Y_hat_ls = torch.cat(all_Y_hat_ls)

    if i % 200 == 0:
        classifier_accuracy = Classifier.accuracy(Y_hat_ls,
                                                  all_targets) * 100.0
        msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
            '(on current FAKE DATA training batch).'
        print(msg.format(i, classifier_accuracy))

    # Depending on the target softness, the appropriate loss function is chosen.
    if config.hard_targets or (config.class_incremental and t == 1):
        return Classifier.logit_cross_entropy_loss(Y_hat_ls, all_targets)
    else:
        return Classifier.knowledge_distillation_loss(Y_hat_ls, all_targets)
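
Example #6 slices the task-specific output head using cumulative sums over per-task output dimensions (`config.dims`). A tiny illustration with made-up head sizes:

    # Made-up per-task head sizes; `re` is the replayed task id.
    dims = [2, 5, 3]
    re = 1
    task_out = [sum(dims[:re]), sum(dims[:re + 1])]
    assert task_out == [2, 7]   # task 1 owns output columns 2..6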