def test(task_id, data, mnet, hnet, device, shared, config, writer, logger,
         train_iter=None, task_emb=None, cl_scenario=None, test_size=None):
    """Evaluate the current performance using the test set.

    Note:
        The hypernetwork ``hnet`` may be ``None``, in which case it is assumed
        that the main network ``mnet`` has internal weights.

    Args:
        (....): See docstring of function :func:`train`.
        train_iter (int, optional): The current training iteration. If given,
            it is used for tensorboard logging.
        task_emb (torch.Tensor, optional): Task embedding. If given, no task
            ID will be provided to the hypernetwork. This might be useful if
            the performance of task embeddings other than the trained ones
            should be tested.

            .. note::
                This option may only be used for ``cl_scenario=1``. It doesn't
                make sense if the task ID has to be inferred.
        cl_scenario (int, optional): In case the system should be tested on
            another CL scenario than the one user-defined in ``config``.

            .. note::
                It is up to the user to ensure that the CL scenarios are
                compatible in this implementation.
        test_size (int, optional): In case the testing shouldn't be performed
            on the entire test set, this option can be used to specify the
            number of test samples to be used.

    Returns:
        (tuple): Tuple containing:

        - **test_acc**: Test accuracy on the classification task.
        - **task_acc**: Task prediction accuracy (always 100% for **CL1**).
    """
    if cl_scenario is None:
        cl_scenario = config.cl_scenario
    else:
        assert cl_scenario in [1, 2, 3]

    # `task_emb` is ignored for other CL scenarios!
    assert task_emb is None or cl_scenario == 1, \
        '"task_emb" may only be specified for CL1, as we infer the ' + \
        'embedding for other scenarios.'

    mnet.eval()
    if hnet is not None:
        hnet.eval()

    if train_iter is None:
        logger.info('### Test run ...')
    else:
        logger.info('# Testing network before running training step %d ...' %
                    train_iter)

    # We need to tell the main network which batch statistics to use, in case
    # batchnorm is used and we checkpoint the batchnorm stats.
    mnet_kwargs = {}
    if mnet.batchnorm_layers is not None:
        if config.bn_distill_stats:
            raise NotImplementedError()
        elif not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Specify the current task as condition to select the correct
            # running stats.
            mnet_kwargs['condition'] = task_id

            if task_emb is not None:
                # NOTE `task_emb` might have nothing to do with `task_id`.
                logger.warning('Using batch statistics accumulated for task ' +
                               '%d for batchnorm, but testing is ' % task_id +
                               'performed using a given task embedding.')

    with torch.no_grad():
        batch_size = config.val_batch_size
        # FIXME Assuming all output heads have the same size.
        n_head = data.num_classes

        if test_size is None or test_size >= data.num_test_samples:
            test_size = data.num_test_samples
        else:
            # Make sure that we always use the same test samples.
            data.reset_batch_generator(train=False, test=True, val=False)
            logger.info('Note, only part of the test set is used for this ' +
                        'test run!')

        test_loss = 0.0

        # We store all predicted labels and tasks while going over individual
        # test batches.
        correct_labels = np.empty(test_size, int)
        pred_labels = np.empty(test_size, int)
        correct_tasks = np.ones(test_size, int) * task_id
        pred_tasks = np.empty(test_size, int)

        curr_bs = batch_size
        N_processed = 0

        # Sweep through the test set.
        while N_processed < test_size:
            if N_processed + curr_bs > test_size:
                curr_bs = test_size - N_processed
            N_processed += curr_bs

            batch = data.next_test_batch(curr_bs)
            X = data.input_to_torch_tensor(batch[0], device)
            T = data.output_to_torch_tensor(batch[1], device)

            ############################
            ### Get main net weights ###
            ############################
            if hnet is None:
                weights = None
            elif cl_scenario > 1:
                raise NotImplementedError()
            elif task_emb is not None:
                weights = hnet.forward(task_emb=task_emb)
            else:
                weights = hnet.forward(task_id=task_id)

            #######################
            ### Get predictions ###
            #######################
            Y_hat_logits = mnet.forward(X, weights=weights, **mnet_kwargs)

            if config.cl_scenario == 1:
                # Select the current head.
                task_out = [task_id * n_head, (task_id + 1) * n_head]
            elif config.cl_scenario == 2:
                # Only one output head.
                task_out = [0, n_head]
            else:
                raise NotImplementedError()
                # TODO Choose the predicted output head per sample.
                #task_out = [predicted_task_id[0]*n_head,
                #            (predicted_task_id[0]+1)*n_head]

            Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
            # We take the softmax after the output neurons are chosen.
            Y_hat = F.softmax(Y_hat_logits, dim=1).cpu().numpy()

            correct_labels[N_processed-curr_bs:N_processed] = \
                T.argmax(dim=1, keepdim=False).cpu().numpy()
            pred_labels[N_processed-curr_bs:N_processed] = \
                Y_hat.argmax(axis=1)

            # Set the task prediction to 100% if we do not infer it.
            if cl_scenario > 1:
                raise NotImplementedError()
                #pred_tasks[N_processed-curr_bs:N_processed] = \
                #    predicted_task_id.cpu().numpy()
            else:
                pred_tasks[N_processed-curr_bs:N_processed] = task_id

            # Note, targets are 1-hot encoded.
            test_loss += Classifier.logit_cross_entropy_loss(Y_hat_logits, T,
                                                             reduction='sum')

        class_n_correct = (correct_labels == pred_labels).sum()
        test_acc = 100.0 * class_n_correct / test_size

        task_n_correct = (correct_tasks == pred_tasks).sum()
        task_acc = 100.0 * task_n_correct / test_size

        test_loss /= test_size

        msg = '### Test accuracy of task %d' % (task_id+1) \
            + (' (before training iteration %d)' % train_iter if
               train_iter is not None else '') \
            + ': %.3f' % test_acc \
            + (' (using a given task embedding)' if task_emb is not None
               else '') \
            + (' - task prediction accuracy: %.3f' % task_acc if
               cl_scenario > 1 else '')
        logger.info(msg)

        if train_iter is not None:
            writer.add_scalar('test/task_%d/class_accuracy' % task_id,
                              test_acc, train_iter)
            if config.cl_scenario > 1:
                writer.add_scalar('test/task_%d/task_pred_accuracy' % task_id,
                                  task_acc, train_iter)

        return test_acc, task_acc
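
# The head selection above is the crux of CL1 evaluation: the main network has
# one output head per task, and only the slice belonging to ``task_id`` enters
# the softmax. The helper below is an illustrative sketch of that step (it is
# not called by the code above; its name and arguments are hypothetical, and
# it assumes all heads share the same size ``n_head``).
def _example_select_cl1_head(logits, task_id, n_head):
    """Return per-class probabilities for the CL1 head of ``task_id``.

    ``logits`` is expected to have shape ``[batch_size, num_tasks * n_head]``.
    """
    head_logits = logits[:, task_id * n_head:(task_id + 1) * n_head]
    # The softmax is taken only after the output neurons have been chosen.
    return F.softmax(head_logits, dim=1)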

def get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device, config, writer,
                       t, i, net_copy):
    """Sample fake data from the generator for tasks up to ``t`` and compute a
    loss against the predictions of a checkpointed network.

    We must take caution when considering the different learning scenarios,
    methods and training stages; see the detailed comments in the code.

    In general, we build a batch of replayed data from all previous tasks.
    Since we do not know the labels of the replayed data, we consider the
    output of the checkpointed network as ground truth, i.e. we must compute a
    loss between two sets of logits. See
    :class:`mnets.classifier_interface.Classifier` for a detailed description
    of the different loss functions.

    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
        i: Current training iteration.
        net_copy: Copy/checkpoint of the classifier network before learning
            task ``t``.

    Returns:
        The loss between the predictions of the current network and the
        predictions of the checkpointed network on replayed data.
    """
    all_Y_hat_ls = []
    all_targets = []

    # We have to choose from which embeddings (possibly multiple) to sample.
    if config.class_incremental or config.single_class_replay:
        # We trained every class with a different generator.
        emb_num = t * config.out_dim
    else:
        # Here, all samples of a task come from one generator.
        emb_num = t

    # We have to choose from which embeddings to sample.
    if config.fake_data_full_range:
        ran = range(0, emb_num)
        bs_per_task = int(np.ceil(config.batch_size / emb_num))
    else:
        random_t = np.random.randint(0, emb_num)
        ran = range(random_t, random_t + 1)
        bs_per_task = config.batch_size

    for re in ran:
        # Exchange replay data with real data to compute upper bounds.
        if config.upper_bound:
            real_batch = dhandlers_rp[re].next_train_batch(bs_per_task)
            X_fake = dhandlers_rp[re].input_to_torch_tensor(real_batch[0],
                device, mode='train')
        else:
            # Get fake data.
            if config.replay_method == 'gan':
                X_fake = sample_gan(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)
            else:
                X_fake = sample_vae(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)

        # Save some fake data to the writer.
        if i % 100 == 0:
            if X_fake.shape[0] >= 15:
                fig_fake = _plotImages(X_fake, config, bs_per_task)
                writer.add_figure('train_class_' + str(re) + '_fake',
                                  fig_fake, global_step=i)

        # Compute soft targets with the copied (checkpointed) network.
        target_logits = net_copy.forward(X_fake).detach()
        Y_hat_ls = net.forward(X_fake.detach())

        ###############
        # BUILD TARGETS
        ###############
        od = config.out_dim
        if config.class_incremental or config.training_task_infer:
            # This is a bit complicated: If we train class-/task-incrementally,
            # we skip training the classifier on the first task. So when
            # starting to train the classifier on task 2, we have to build a
            # hard target for the first output neuron trained by replay data.
            # A soft target (on an untrained output) would not make sense.
            # Output head over all output neurons already available.
            task_out = [0, (t + 1) * od]

            # Create a target that is zero everywhere except for the current
            # ``re``.
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(
                device)
            if config.hard_targets or (t == 1 and re == 0):
                zeros[:, re] = 1
            else:
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]
            targets = zeros
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        elif config.cl_scenario == 1 or config.cl_scenario == 2:
            if config.cl_scenario == 1:
                # Take the task-specific output neurons.
                task_out = [re * od, re * od + od]
            else:
                # Always all output neurons; only one head is used.
                task_out = [0, od]

            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]
            target_logits = target_logits[:, task_out[0]:task_out[1]]

            # Build hard targets, i.e. one-hots, if this option is chosen.
            if config.hard_targets:
                soft_targets = torch.sigmoid(target_logits)
                zeros = torch.zeros(Y_hat_ls.shape).to(device)
                _, argmax = torch.max(soft_targets, 1)
                targets = zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # The loss expects logits.
                targets = target_logits
        else:
            # Take all neurons used up until now, i.e. the output heads over
            # all output neurons already available.
            task_out = [0, (t + 1) * od]

            # Create a target that is zero everywhere except for the outputs
            # of previous tasks.
            zeros = torch.zeros(target_logits[:, 0:(t + 1) * od].shape).to(
                device)
            # Sigmoid over the output head(s) of all previous tasks.
            soft_targets = torch.sigmoid(target_logits[:, 0:t * od])

            # Compute one-hots.
            if config.hard_targets:
                _, argmax = torch.max(soft_targets, 1)
                zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # The loss expects logits.
                zeros[:, 0:t * od] = target_logits[:, 0:t * od]
            targets = zeros

            # Choose the correct output slice for the current predictions.
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]

        # Add to the lists.
        all_targets.append(targets)
        all_Y_hat_ls.append(Y_hat_ls)

    # Concatenate into one tensor.
    all_targets = torch.cat(all_targets)
    Y_hat_ls = torch.cat(all_Y_hat_ls)

    if i % 200 == 0:
        classifier_accuracy = Classifier.accuracy(Y_hat_ls,
                                                  all_targets) * 100.0
        msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
              '(on current FAKE DATA training batch).'
        print(msg.format(i, classifier_accuracy))

    # Depending on the target softness, the loss function is chosen.
    if config.hard_targets or (config.class_incremental and t == 1):
        return Classifier.logit_cross_entropy_loss(Y_hat_ls, all_targets)
    else:
        return Classifier.knowledge_distillation_loss(Y_hat_ls, all_targets)
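
# Illustrative sketch of the hard-target branch used above (a hypothetical
# helper, not part of the replay pipeline itself): the checkpointed network's
# logits are squashed with a sigmoid, the per-sample argmax is taken, and a
# one-hot target is built via ``scatter_``.
def _example_hard_targets_from_logits(target_logits):
    """Turn checkpointed logits of shape ``[batch_size, n_out]`` into one-hot
    targets on the same device and with the same dtype."""
    soft_targets = torch.sigmoid(target_logits)
    one_hot = torch.zeros_like(target_logits)
    _, argmax = torch.max(soft_targets, 1)
    return one_hot.scatter_(1, argmax.view(-1, 1), 1)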

def train(task_id, data, mnet, hnet, device, config, shared, writer, logger):
    """Train the hypernetwork using the task-specific loss plus a regularizer
    that should overcome catastrophic forgetting:

    :code:`loss = task_loss + beta * regularizer`.

    Args:
        task_id: The index of the task on which we train.
        data: The dataset handler.
        mnet: The model of the main network.
        hnet: The model of the hypernetwork. May be ``None``.
        device: Torch device (cpu or gpu).
        config: The command line arguments.
        shared (argparse.Namespace): Set of variables shared between functions.
        writer: The tensorboard summary writer.
        logger: The logger that should be used rather than the print method.
    """
    start_time = time()

    logger.debug('data: %s' % str(data))
    logger.debug('data.num_classes: %d' % data.num_classes)
    logger.debug('data.num_train_samples: %d' % data.num_train_samples)

    logger.info('Training network ...')

    mnet.train()
    if hnet is not None:
        hnet.train()

    #################
    ### Optimizer ###
    #################
    # Define the optimizers used to train the main network and hypernet.
    if hnet is not None:
        theta_params = list(hnet.theta)
        if config.continue_emb_training:
            for i in range(task_id):  # All previous task embeddings.
                theta_params.append(hnet.get_task_emb(i))

        # Only for the current task embedding.
        # Important: this embedding is in a different optimizer in case we use
        # the lookahead.
        emb_optimizer = get_optimizer([hnet.get_task_emb(task_id)], config.lr,
            momentum=config.momentum, weight_decay=config.weight_decay,
            use_adam=config.use_adam, adam_beta1=config.adam_beta1,
            use_rmsprop=config.use_rmsprop)
    else:
        theta_params = mnet.weights
        emb_optimizer = None

    theta_optimizer = get_optimizer(theta_params, config.lr,
        momentum=config.momentum, weight_decay=config.weight_decay,
        use_adam=config.use_adam, adam_beta1=config.adam_beta1,
        use_rmsprop=config.use_rmsprop)

    ################################
    ### Learning rate schedulers ###
    ################################
    if config.plateau_lr_scheduler:
        assert config.epochs != -1
        # The scheduler config has been taken from here:
        # https://keras.io/examples/cifar10_resnet/
        # Note, we use 'max' instead of 'min' as we look at accuracy rather
        # than validation loss!
        plateau_scheduler_theta = optim.lr_scheduler.ReduceLROnPlateau(
            theta_optimizer, 'max', factor=np.sqrt(0.1), patience=5,
            min_lr=0.5e-6, cooldown=0)
        plateau_scheduler_emb = None
        if emb_optimizer is not None:
            plateau_scheduler_emb = optim.lr_scheduler.ReduceLROnPlateau(
                emb_optimizer, 'max', factor=np.sqrt(0.1), patience=5,
                min_lr=0.5e-6, cooldown=0)

    if config.lambda_lr_scheduler:
        assert config.epochs != -1

        def lambda_lr(epoch):
            """Multiplicative factor for the learning rate schedule.

            Computes a multiplicative factor for the initial learning rate
            based on the current epoch. This method can be used as argument
            ``lr_lambda`` of class :class:`torch.optim.lr_scheduler.LambdaLR`.

            The schedule is inspired by the ResNet CIFAR-10 schedule suggested
            here: https://keras.io/examples/cifar10_resnet/.

            Args:
                epoch (int): The number of epochs trained so far.

            Returns:
                lr_scale (float32): Learning rate scale.
            """
            lr_scale = 1.
            if epoch > 180:
                lr_scale = 0.5e-3
            elif epoch > 160:
                lr_scale = 1e-3
            elif epoch > 120:
                lr_scale = 1e-2
            elif epoch > 80:
                lr_scale = 1e-1
            return lr_scale

        lambda_scheduler_theta = optim.lr_scheduler.LambdaLR(theta_optimizer,
                                                             lambda_lr)
        lambda_scheduler_emb = None
        if emb_optimizer is not None:
            lambda_scheduler_emb = optim.lr_scheduler.LambdaLR(emb_optimizer,
                                                               lambda_lr)

    ##############################
    ### Prepare CL Regularizer ###
    ##############################
    # Whether we will calculate the regularizer.
    calc_reg = task_id > 0 and not config.mnet_only and config.beta > 0 and \
        not config.train_from_scratch

    # Compute targets when the reg is activated and we are not training
    # the first task.
    if calc_reg:
        if config.online_target_computation:
            # Compute targets for the regularizer whenever they are needed.
            # -> Computationally expensive.
            targets_hypernet = None
            prev_theta = [p.detach().clone() for p in hnet.theta]
            prev_task_embs = [p.detach().clone()
                              for p in hnet.get_task_embs()]
        else:
            # Compute targets for the regularizer once and keep them all in
            # memory -> Memory expensive.
            targets_hypernet = hreg.get_current_targets(task_id, hnet)
            prev_theta = None
            prev_task_embs = None

    # If we do not want to regularize all outputs (in a multi-head setup).
    # Note, we don't care whether output heads other than the current one
    # change.
    regged_outputs = None
    if config.cl_scenario != 2:
        # FIXME We assume here that all tasks have the same output size.
        n_y = data.num_classes
        regged_outputs = [list(range(i * n_y, (i + 1) * n_y))
                          for i in range(task_id)]

    # We need to tell the main network which batch statistics to use, in case
    # batchnorm is used and we checkpoint the batchnorm stats.
    mnet_kwargs = {}
    if mnet.batchnorm_layers is not None:
        if config.bn_distill_stats:
            raise NotImplementedError()
        elif not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Specify the current task as condition to select the correct
            # running stats.
            mnet_kwargs['condition'] = task_id

    ######################
    ### Start training ###
    ######################
    iter_per_epoch = -1
    if config.epochs == -1:
        training_iterations = config.n_iter
    else:
        assert config.epochs > 0
        iter_per_epoch = int(np.ceil(data.num_train_samples /
                                     config.batch_size))
        training_iterations = config.epochs * iter_per_epoch

    summed_iter_runtime = 0

    for i in range(training_iterations):

        ### Evaluate network.
        # We test the network before we run the training iteration.
        # That way, we can see the initial performance of the untrained
        # network.
        if i % config.val_iter == 0:
            test(task_id, data, mnet, hnet, device, shared, config, writer,
                 logger, train_iter=i)
            mnet.train()
            if hnet is not None:
                hnet.train()

        if i % 200 == 0:
            logger.info('Training step: %d ...' % i)

        iter_start_time = time()

        theta_optimizer.zero_grad()
        if emb_optimizer is not None:
            emb_optimizer.zero_grad()

        #######################################
        ### Data for current task and batch ###
        #######################################
        batch = data.next_train_batch(config.batch_size)
        X = data.input_to_torch_tensor(batch[0], device, mode='train')
        T = data.output_to_torch_tensor(batch[1], device, mode='train')

        # Get the output neurons depending on the continual learning scenario.
        n_y = data.num_classes
        if config.cl_scenario == 1:
            # Choose the current head.
            task_out = [task_id * n_y, (task_id + 1) * n_y]
        elif config.cl_scenario == 2:
            # Always all output neurons; only one head is used.
            task_out = [0, n_y]
        else:
            # Choose current head, which will be inferred during inference.
            task_out = [task_id * n_y, (task_id + 1) * n_y]

        ########################
        ### Loss computation ###
        ########################
        if config.mnet_only:
            weights = None
        else:
            weights = hnet.forward(task_id=task_id)
        Y_hat_logits = mnet.forward(X, weights, **mnet_kwargs)

        # Restrict output neurons.
        Y_hat_logits = Y_hat_logits[:, task_out[0]:task_out[1]]
        assert T.shape[1] == Y_hat_logits.shape[1]

        # Compute the loss on the task and compute gradients.
        if config.soft_targets:
            soft_label = 0.95
            num_classes = data.num_classes
            # The filler tensors are created with the same dtype/device as
            # ``T``.
            soft_targets = torch.where(T == 1,
                torch.full_like(T, soft_label),
                torch.full_like(T, (1 - soft_label) / (num_classes - 1)))
            loss_task = Classifier.softmax_and_cross_entropy(Y_hat_logits,
                                                             soft_targets)
        else:
            loss_task = Classifier.logit_cross_entropy_loss(Y_hat_logits, T)

        # Compute gradients based on the task loss (those might be used in
        # the CL regularizer).
        loss_task.backward(retain_graph=calc_reg,
                           create_graph=calc_reg and config.backprop_dt)

        # The current task embedding only depends on the task loss, so we can
        # update it already.
        if emb_optimizer is not None:
            emb_optimizer.step()

        #############################
        ### CL (HNET) Regularizer ###
        #############################
        loss_reg = 0
        dTheta = None

        if calc_reg:
            if config.no_lookahead:
                dTembs = None
                dTheta = None
            else:
                dTheta = opstep.calc_delta_theta(theta_optimizer, False,
                    lr=config.lr, detach_dt=not config.backprop_dt)

                if config.continue_emb_training:
                    dTembs = dTheta[-task_id:]
                    dTheta = dTheta[:-task_id]
                else:
                    dTembs = None

            loss_reg = hreg.calc_fix_target_reg(hnet, task_id,
                targets=targets_hypernet, dTheta=dTheta, dTembs=dTembs,
                mnet=mnet, inds_of_out_heads=regged_outputs,
                prev_theta=prev_theta, prev_task_embs=prev_task_embs,
                batch_size=config.cl_reg_batch_size)

            loss_reg *= config.beta

            loss_reg.backward()

        # Now that we have computed the regularizer, we can use the
        # accumulated gradients and update the hnet (or mnet) parameters.
        theta_optimizer.step()

        Y_hat = F.softmax(Y_hat_logits, dim=1)
        classifier_accuracy = Classifier.accuracy(Y_hat, T) * 100.0

        #########################
        # Learning rate scheduler
        #########################
        if config.plateau_lr_scheduler:
            assert iter_per_epoch != -1
            if i % iter_per_epoch == 0 and i > 0:
                curr_epoch = i // iter_per_epoch
                logger.info('Computing test accuracy for plateau LR ' +
                            'scheduler (epoch %d).' % curr_epoch)
                # We need a validation quantity for the plateau LR scheduler.
                # FIXME We should use an actual validation set rather than
                # the test set.
                # Note, https://keras.io/examples/cifar10_resnet/ uses the
                # test set to compute the validation loss. We use the
                # "validation" accuracy instead.
                # FIXME We increase `train_iter` as the print messages in the
                # test method suggest that the testing has been executed
                # before the given training step.
                test_acc, _ = test(task_id, data, mnet, hnet, device, shared,
                                   config, writer, logger, train_iter=i + 1)
                mnet.train()
                if hnet is not None:
                    hnet.train()

                plateau_scheduler_theta.step(test_acc)
                if plateau_scheduler_emb is not None:
                    plateau_scheduler_emb.step(test_acc)

        if config.lambda_lr_scheduler:
            assert iter_per_epoch != -1
            if i % iter_per_epoch == 0 and i > 0:
                curr_epoch = i // iter_per_epoch
                logger.info('Applying Lambda LR scheduler (epoch %d).'
                            % curr_epoch)
                lambda_scheduler_theta.step()
                if lambda_scheduler_emb is not None:
                    lambda_scheduler_emb.step()

        ###########################
        ### Tensorboard summary ###
        ###########################
        # We don't want to slow down training by having too much output.
        if i % 50 == 0:
            writer.add_scalar('train/task_%d/class_accuracy' % task_id,
                              classifier_accuracy, i)
            writer.add_scalar('train/task_%d/loss_task' % task_id,
                              loss_task, i)
            writer.add_scalar('train/task_%d/loss_reg' % task_id,
                              loss_reg, i)

        ### Show the current training progress to the user.
        if i % config.val_iter == 0:
            msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
                  '(on current training batch).'
            logger.debug(msg.format(i, classifier_accuracy))

        iter_end_time = time()
        summed_iter_runtime += (iter_end_time - iter_start_time)

        if i % 200 == 0:
            logger.info('Training step: %d ... Done -- (runtime: %f sec)' %
                        (i, iter_end_time - iter_start_time))

    if mnet.batchnorm_layers is not None:
        if not config.bn_distill_stats and \
                not config.bn_no_running_stats and \
                not config.bn_no_stats_checkpointing:
            # Checkpoint the current running statistics (that have been
            # estimated while training the current task).
            for bn_layer in mnet.batchnorm_layers:
                assert bn_layer.num_stats == task_id + 1
                bn_layer.checkpoint_stats()

    avg_iter_time = summed_iter_runtime / config.n_iter
    logger.info('Average runtime per training iteration: %f sec.' %
                avg_iter_time)

    logger.info('Elapsed time for training task %d: %f sec.' %
                (task_id + 1, time() - start_time))
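
# Illustrative sketch of the ``config.soft_targets`` branch in the loss
# computation above (a hypothetical helper, not called by ``train``): a 1-hot
# target matrix ``T`` is smoothed such that the correct class receives
# ``soft_label`` and the remaining probability mass is spread uniformly over
# the other classes.
def _example_smooth_targets(T, soft_label=0.95):
    """Return label-smoothed targets for a 1-hot encoded target matrix ``T``."""
    num_classes = T.shape[1]
    off_value = (1. - soft_label) / (num_classes - 1)
    return torch.where(T == 1, torch.full_like(T, soft_label),
                       torch.full_like(T, off_value))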

def get_fake_data_loss(dhandlers_rp, net, dec, d_hnet, device, config, writer,
                       t, i, net_copy):
    """Sample fake data from the generator for tasks up to ``t`` and compute a
    loss against the predictions of a checkpointed network.

    We must take caution when considering the different learning scenarios,
    methods and training stages; see the detailed comments in the code.

    In general, we build a batch of replayed data from all previous tasks.
    Since we do not know the labels of the replayed data, we consider the
    output of the checkpointed network as ground truth, i.e. we must compute a
    loss between two sets of logits. See
    :class:`mnets.classifier_interface.Classifier` for a detailed description
    of the different loss functions.

    Args:
        (....): See docstring of function :func:`train_tasks`.
        t: Task id.
        i: Current training iteration.
        net_copy: Copy/checkpoint of the classifier network before learning
            task ``t``.

    Returns:
        The loss between the predictions of the current network and the
        predictions of the checkpointed network on replayed data.
    """
    all_Y_hat_ls = []
    all_targets = []

    # We have to choose from which embeddings (possibly multiple) to sample.
    if config.class_incremental or config.single_class_replay:
        # We trained every class with a different generator.
        emb_num = t * config.out_dim
    else:
        # Here, all samples of a task come from one generator.
        emb_num = t

    # We have to choose from which embeddings to sample.
    if config.fake_data_full_range:
        ran = range(0, emb_num)
        bs_per_task = int(np.ceil(config.batch_size / emb_num))
    else:
        random_t = np.random.randint(0, emb_num)
        ran = range(random_t, random_t + 1)
        bs_per_task = config.batch_size

    for re in ran:
        # Exchange replay data with real data to compute upper bounds.
        if config.upper_bound:
            real_batch = dhandlers_rp[re].next_train_batch(bs_per_task)
            X_fake = dhandlers_rp[re].input_to_torch_tensor(real_batch[0],
                device, mode='train')
        else:
            # Get fake data.
            if config.replay_method == 'gan':
                X_fake = sample_gan(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)
            else:
                X_fake = sample_vae(dec, d_hnet, config, re, device,
                                    bs=bs_per_task)

        # Compute soft targets with the copied (checkpointed) network.
        target_logits = net_copy.forward(X_fake).detach()
        Y_hat_ls = net.forward(X_fake.detach())

        ###############
        # BUILD TARGETS
        ###############
        if config.cl_scenario == 1:
            # Take the task-specific output neurons (head sizes are given by
            # ``config.dims``).
            task_out = [sum(config.dims[:re]), sum(config.dims[:re + 1])]
            Y_hat_ls = Y_hat_ls[:, task_out[0]:task_out[1]]
            target_logits = target_logits[:, task_out[0]:task_out[1]]

            # Build hard targets, i.e. one-hots, if this option is chosen.
            if config.hard_targets:
                soft_targets = torch.sigmoid(target_logits)
                zeros = torch.zeros(Y_hat_ls.shape).to(device)
                _, argmax = torch.max(soft_targets, 1)
                targets = zeros.scatter_(1, argmax.view(-1, 1), 1)
            else:
                # The loss expects logits.
                targets = target_logits

        # Add to the lists.
        all_targets.append(targets)
        all_Y_hat_ls.append(Y_hat_ls)

    # Concatenate into one tensor, as expected by the loss functions below.
    all_targets = torch.cat(all_targets)
    Y_hat_ls = torch.cat(all_Y_hat_ls)

    if i % 200 == 0:
        classifier_accuracy = Classifier.accuracy(Y_hat_ls,
                                                  all_targets) * 100.0
        msg = 'Training step {}: Classifier Accuracy: {:.3f} ' + \
              '(on current FAKE DATA training batch).'
        print(msg.format(i, classifier_accuracy))

    # Depending on the target softness, the loss function is chosen.
    if config.hard_targets or (config.class_incremental and t == 1):
        return Classifier.logit_cross_entropy_loss(Y_hat_ls, all_targets)
    else:
        return Classifier.knowledge_distillation_loss(Y_hat_ls, all_targets)
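
# Illustrative sketch of a temperature-scaled knowledge-distillation loss.
# This is only an assumption of what ``Classifier.knowledge_distillation_loss``
# roughly computes; the helper and its ``temp`` parameter are hypothetical and
# not used by the code above.
def _example_distillation_loss(student_logits, teacher_logits, temp=2.):
    """KL divergence between softened teacher and student distributions."""
    log_p_student = F.log_softmax(student_logits / temp, dim=1)
    p_teacher = F.softmax(teacher_logits / temp, dim=1)
    # Scaling by temp**2 keeps gradient magnitudes comparable across
    # temperatures.
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * temp**2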