Example #1
def validate(dataset, gen_model, disc_model, criterion, epoch, device, args,
             save_path_pictures):
    batch_time = AverageMeter()
    losses = AverageMeter()
    gen_losses = AverageMeter()
    disc_losses = AverageMeter()

    gen_model.eval()
    disc_model.eval()

    end = time.time()

    for i, (input, target) in enumerate(dataset.val_loader):
        with torch.no_grad():
            input, target = input.to(device), target.to(device)

            if args.no_gpus > 1:
                input_size = gen_model.module.input_size
            else:
                input_size = gen_model.input_size

            # set inputs and targets
            z = torch.randn((input.size(0), input_size)).to(device)
            y_real, y_fake = torch.ones(input.size(0),
                                        1).to(device), torch.zeros(
                                            input.size(0), 1).to(device)

            disc_real = disc_model(input)
            gen_out = gen_model(z)
            disc_fake = disc_model(gen_out)

            disc_real_loss = criterion(disc_real, y_real)
            disc_fake_loss = criterion(disc_fake, y_fake)
            disc_loss = disc_real_loss + disc_fake_loss

            disc_losses.update(disc_loss.item(), input.size(0))

            gen_out = gen_model(z)
            disc_fake = disc_model(gen_out)
            gen_loss = criterion(disc_fake, y_real)

            gen_losses.update(gen_loss.item(), input.size(0))

            if i % args.print_freq == 0:
                save_image((gen_out.data.view(-1, input.size(1), input.size(2),
                                              input.size(3))),
                           os.path.join(
                               save_path_pictures, 'sample_epoch_' +
                               str(epoch) + '_ite_' + str(i + 1) + '.png'))
            del input, target, z, y_real, y_fake, disc_real, gen_out, disc_fake

    print(' * Validate: Generator Loss {gen_losses.avg:.3f} Discriminator Loss {disc_losses.avg:.3f}'\
        .format(gen_losses=gen_losses, disc_losses=disc_losses))
    print('-' * 80)

    return disc_losses.avg, gen_losses.avg
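All of these examples accumulate scalar metrics with an AverageMeter helper that is not shown. Below is a minimal sketch of such a running-average accumulator, assuming only the update(value, n) call and the val/avg attributes used above; it is an illustration, not the original utility's code.

class AverageMeter:
    """Tracks the most recent value and the running average of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total weight (typically the number of samples)
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        # val is assumed to be an average over n samples, so weight it by n
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count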
Example #2
def validate(Dataset, model, criterion, epoch, writer, device, save_path, args):
    """
    Evaluates/validates the model

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        save_path (str): path to save data to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int), epochs (int), incremental_data (bool), autoregression (bool),
            visualization_epoch (int), cross_dataset (bool), num_base_tasks (int), num_increment_tasks (int) and
            patch_size (int).

    Returns:
        float: top1 precision/accuracy
        float: average loss
    """

    # initialize average meters to accumulate values
    losses = AverageMeter()
    class_losses = AverageMeter()
    inos_losses = AverageMeter()
    batch_time = AverageMeter()
    top1 = AverageMeter()

    # confusion matrix
    confusion = ConfusionMeter(model.module.num_classes, normalized=True)

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # evaluate the entire validation dataset
    with torch.no_grad():
        for i, (inp, target) in enumerate(Dataset.val_loader):
            inp = inp.to(device)
            target = target.to(device)

            recon_target = inp
            class_target = target[0]

            # compute output
            output, score = model(inp)

            # compute loss
            cl, rl = criterion(output, target, score, device, args)
            loss = cl + rl

            # measure accuracy, record loss, fill confusion matrix
            prec1 = accuracy(output, class_target)[0]
            top1.update(prec1.item(), inp.size(0))
            class_losses.update(cl.item(), inp.size(0))
            inos_losses.update(rl.item(), inp.size(0))
            confusion.add(output.data, target)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            losses.update(loss.item(), inp.size(0))

            # Print progress
            if i % args.print_freq == 0:
                print('Validate: [{0}][{1}/{2}]\t' 
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                       .format(
                       epoch+1, i, len(Dataset.val_loader), batch_time=batch_time, loss=losses, 
                       top1=top1))

    # TensorBoard summary logging
    writer.add_scalar('validation/precision@1', top1.avg, epoch)
    writer.add_scalar('validation/average_loss', losses.avg, epoch)
    writer.add_scalar('validation/class_loss', class_losses.avg, epoch)
    writer.add_scalar('validation/inos_loss', inos_losses.avg, epoch)

    print(' * Validation: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))

    # At the end of training isolated, or at the end of every task visualize the confusion matrix
    if (epoch + 1) % args.epochs == 0 and epoch > 0:
        # visualize the confusion matrix
        visualize_confusion(writer, epoch + 1, confusion.value(), Dataset.class_to_idx, save_path)

    return top1.avg, losses.avg
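Several of these examples call an accuracy helper that returns precision@k values. A sketch of the commonly used top-k implementation, assuming the (output, target, topk) call signature seen above; illustrative, not the project's exact code.

import torch

def accuracy(output, target, topk=(1,)):
    """Computes precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the top-k highest scoring classes per sample, transposed to (maxk, batch)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res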
Example #3
def train(train_loader, model, criterion, epoch, optimizer, lr_scheduler,
          device, args, split_batch_size):
    """
    trains the model of a net for one epoch on the train set
    
    Parameters:
        train_loader (torch.utils.data.DataLoader): data loader for the train set
        model (lib.Models.network.Net): model of the net to be trained
        criterion (torch.nn.BCELoss): loss criterion to be optimized
        epoch (int): continuous epoch counter
        optimizer (torch.optim.SGD): optimizer instance like SGD or Adam
        lr_scheduler (lib.Training.learning_rate_scheduling.LearningRateScheduler): class implementing learning rate
                                                                                    schedules
        device (torch.device): computational device (cpu or gpu)
        args (argparse.ArgumentParser): parsed command line arguments
        split_batch_size (int):  smaller batch size after splitting the original batch size for fitting the device
                                 memory
    """
    # performance and computational overhead metrics
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    hard_prec = AverageMeter()
    soft_prec = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    factor = args.batch_size // split_batch_size
    last_batch = int(
        math.ceil(len(train_loader.dataset) / float(split_batch_size)))

    optimizer.zero_grad()

    print('training')

    for i, (input_, target) in enumerate(train_loader):
        # hacky way to deal with terminal batch-size of 1
        if input_.size(0) == 1:
            print('skip last training batch of size 1')
            continue

        input_, target = input_.to(device), target.to(device)

        data_time.update(time.time() - end)

        # adjust the learning rate every 'factor' sub-batches (i.e. once per original batch, had the batch size
        # not been split)
        if i % factor == 0:
            lr_scheduler.adjust_learning_rate(optimizer, i // factor + 1)

        output = model(input_)

        # scale the loss by the ratio of the split batch size and the original
        loss = criterion(output, target) * input_.size(0) / float(
            args.batch_size)

        # update the 'losses' meter with the actual measure of the loss
        losses.update(loss.item() * args.batch_size / float(input_.size(0)),
                      input_.size(0))

        # compute performance measures
        output = output >= 0.5  # binarizing sigmoid output by thresholding with 0.5
        equality_matrix = (output.float() == target).float()
        hard = torch.mean(torch.prod(equality_matrix, dim=1)) * 100.
        soft = torch.mean(equality_matrix) * 100.

        # update performance meters
        hard_prec.update(hard.item(), input_.size(0))
        soft_prec.update(soft.item(), input_.size(0))

        loss.backward()

        # update the weights every 'factor' sub-batches (i.e. once per original batch, had the batch size
        # not been split)
        if (i + 1) % factor == 0 or i == (last_batch - 1):
            optimizer.step()
            optimizer.zero_grad()

        del output, input_, target
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print performance and computational overhead measures every 'factor' times 'print_freq' sub-batches
        # (i.e. at the original print frequency, had the batch size not been split)
        if i % (args.print_freq * factor) == 0:
            print(
                'epoch: [{0}][{1}/{2}]\t'
                'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'loss {losses.val:.3f} ({losses.avg:.3f})\t'
                'hard prec {hard_prec.val:.3f} ({hard_prec.avg:.3f})\t'
                'soft prec {soft_prec.val:.3f} ({soft_prec.avg:.3f})\t'.format(
                    epoch,
                    i,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    losses=losses,
                    hard_prec=hard_prec,
                    soft_prec=soft_prec))

    lr_scheduler.scheduler_epoch += 1

    print(
        ' * train: loss {losses.avg:.3f} hard prec {hard_prec.avg:.3f} soft prec {soft_prec.avg:.3f}'
        .format(losses=losses, hard_prec=hard_prec, soft_prec=soft_prec))
    print('*' * 80)
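The split-batch handling above is essentially gradient accumulation: losses from 'factor' sub-batches are scaled down and their gradients summed before a single optimizer step, so the update matches what the full batch size would have produced. A stripped-down sketch of the same pattern with hypothetical data and model, for illustration only:

import torch
import torch.nn as nn

model = nn.Linear(10, 6)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch_size = 64          # desired effective batch size
split_batch_size = 16    # what actually fits into memory
factor = batch_size // split_batch_size

optimizer.zero_grad()
for i in range(8):  # pretend sub-batches coming from a loader
    inp = torch.randn(split_batch_size, 10)
    target = torch.randint(0, 2, (split_batch_size, 6)).float()

    # scale the loss so the accumulated gradient matches the full-batch gradient
    loss = criterion(model(inp), target) * inp.size(0) / float(batch_size)
    loss.backward()

    # update the weights only after 'factor' sub-batches have contributed
    if (i + 1) % factor == 0:
        optimizer.step()
        optimizer.zero_grad()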
Example #4
def validate(Dataset, model, criterion, epoch, writer, device, save_path, args):
    """
    Evaluates/validates the model

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        save_path (str): path to save data to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int), epochs (int), incremental_data (bool), autoregression (bool),
            visualization_epoch (int), num_base_tasks (int), num_increment_tasks (int) and
            patch_size (int).

    Returns:
        float: top1 precision/accuracy
        float: average loss
    """

    # initialize average meters to accumulate values
    class_losses = AverageMeter()
    recon_losses_nat = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()

    # for autoregressive models add an additional instance for reconstruction loss in bits per dimension
    if args.autoregression:
        recon_losses_bits_per_dim = AverageMeter()

    # for continual learning settings also add instances for base and new reconstruction metrics
    # corresponding accuracy values are calculated directly from the confusion matrix below
    if args.incremental_data and ((epoch + 1) % args.epochs == 0 and epoch > 0):
        recon_losses_new_nat = AverageMeter()
        recon_losses_base_nat = AverageMeter()
        if args.autoregression:
            recon_losses_new_bits_per_dim = AverageMeter()
            recon_losses_base_bits_per_dim = AverageMeter()

    batch_time = AverageMeter()
    top1 = AverageMeter()

    # confusion matrix
    confusion = ConfusionMeter(model.module.num_classes, normalized=True)

    # switch to evaluate mode
    model.eval()

    end = time.time()

    # evaluate the entire validation dataset
    with torch.no_grad():
        for i, (inp, target) in enumerate(Dataset.val_loader):
            inp = inp.to(device)
            target = target.to(device)

            recon_target = inp
            class_target = target

            # compute output
            class_samples, recon_samples, mu, std = model(inp)

            # for autoregressive models convert the target to 0-255 integers and compute the autoregressive decoder
            # for each sample
            if args.autoregression:
                recon_target = (recon_target * 255).long()
                recon_samples_autoregression = torch.zeros(recon_samples.size(0), inp.size(0), 256, inp.size(1),
                                                           inp.size(2), inp.size(3)).to(device)
                for j in range(model.module.num_samples):
                    recon_samples_autoregression[j] = model.module.pixelcnn(
                        inp, torch.sigmoid(recon_samples[j])).contiguous()
                recon_samples = recon_samples_autoregression

            # compute loss
            class_loss, recon_loss, kld_loss = criterion(class_samples, class_target, recon_samples, recon_target, mu,
                                                         std, device, args)

            # For autoregressive models also update the bits per dimension value, converted from the obtained nats
            if args.autoregression:
                recon_losses_bits_per_dim.update(recon_loss.item() * math.log2(math.e), inp.size(0))

            # take mean to compute accuracy
            # (does nothing if there isn't more than 1 sample per input other than removing dummy dimension)
            class_output = torch.mean(class_samples, dim=0)
            recon_output = torch.mean(recon_samples, dim=0)

            # measure accuracy, record loss, fill confusion matrix
            prec1 = accuracy(class_output, target)[0]
            top1.update(prec1.item(), inp.size(0))
            confusion.add(class_output.data, target)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # for autoregressive models generate reconstructions by sequential sampling from the
            # multinomial distribution (Reminder: the original output is a 256-way Softmax as PixelVAEs are posed as a
            # classification problem). This serves two purposes: visualization of reconstructions and computation of
            # a reconstruction loss in nats using a BCE loss, comparable to that of a regular VAE.
            recon_target = inp
            if args.autoregression:
                recon = torch.zeros((inp.size(0), inp.size(1), inp.size(2), inp.size(3))).to(device)
                for h in range(inp.size(2)):
                    for w in range(inp.size(3)):
                        for c in range(inp.size(1)):
                            probs = torch.softmax(recon_output[:, :, c, h, w], dim=1).data
                            pixel_sample = torch.multinomial(probs, 1).float() / 255.
                            recon[:, c, h, w] = pixel_sample.squeeze()

                if (epoch % args.visualization_epoch == 0) and (i == (len(Dataset.val_loader) - 1)) and (epoch > 0):
                    visualize_image_grid(recon, writer, epoch + 1, 'reconstruction_snapshot', save_path)

                recon_loss = F.binary_cross_entropy(recon, recon_target)
            else:
                # If not autoregressive simply apply the Sigmoid and visualize
                recon = torch.sigmoid(recon_output)
                if (i == (len(Dataset.val_loader) - 1)) and (epoch % args.visualization_epoch == 0) and (epoch > 0):
                    visualize_image_grid(recon, writer, epoch + 1, 'reconstruction_snapshot', save_path)

            # update the respective loss values. To be consistent with values reported in the literature we scale
            # our normalized losses back to un-normalized values.
            # For the KLD this also means the reported loss is not scaled by beta, to allow for a fair comparison
            # across potential weighting terms.
            class_losses.update(class_loss.item() * model.module.num_classes, inp.size(0))
            kld_losses.update(kld_loss.item() * model.module.latent_dim, inp.size(0))
            recon_losses_nat.update(recon_loss.item() * inp.size()[1:].numel(), inp.size(0))
            losses.update((class_loss + recon_loss + kld_loss).item(), inp.size(0))

            # if we are learning continually, we need to calculate the base and new reconstruction losses at the end
            # of each task increment.
            if args.incremental_data and ((epoch + 1) % args.epochs == 0 and epoch > 0):
                for j in range(inp.size(0)):
                    # get the number of classes for class incremental scenarios.
                    base_classes = model.module.seen_tasks[:args.num_base_tasks + 1]
                    new_classes = model.module.seen_tasks[-args.num_increment_tasks:]

                    if args.autoregression:
                        rec = recon_output[j].view(1, recon_output.size(1), recon_output.size(2),
                                                   recon_output.size(3), recon_output.size(4))
                        rec_tar = recon_target[j].view(1, recon_target.size(1), recon_target.size(2),
                                                       recon_target.size(3))

                    # If the input belongs to one of the base classes also update base metrics
                    if class_target[j].item() in base_classes:
                        if args.autoregression:
                            recon_losses_base_bits_per_dim.update(F.cross_entropy(rec, (rec_tar * 255).long()) *
                                                                  math.log2(math.e), 1)
                        recon_losses_base_nat.update(F.binary_cross_entropy(recon[j], recon_target[j]), 1)
                    # if the input belongs to one of the new classes also update new metrics
                    elif class_target[j].item() in new_classes:
                        if args.autoregression:
                            recon_losses_new_bits_per_dim.update(F.cross_entropy(rec, (rec_tar * 255).long()) *
                                                                 math.log2(math.e), 1)
                        recon_losses_new_nat.update(F.binary_cross_entropy(recon[j], recon_target[j]), 1)

            # If we are at the end of validation, create one mini-batch of example generations. Only do this every
            # other epoch specified by visualization_epoch to avoid generation of lots of images and computationally
            # expensive calculations of the autoregressive model's generation.
            if i == (len(Dataset.val_loader) - 1) and epoch % args.visualization_epoch == 0 and (epoch > 0):
                # generation
                gen = model.module.generate()

                if args.autoregression:
                    gen = model.module.pixelcnn.generate(gen)
                visualize_image_grid(gen, writer, epoch + 1, 'generation_snapshot', save_path)

            # Print progress
            if i % args.print_freq == 0:
                print('Validate: [{0}][{1}/{2}]\t' 
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                      'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                       epoch+1, i, len(Dataset.val_loader), batch_time=batch_time, loss=losses, cl_loss=class_losses,
                       top1=top1, recon_loss=recon_losses_nat, KLD_loss=kld_losses))

    # TensorBoard summary logging
    writer.add_scalar('validation/val_precision@1', top1.avg, epoch)
    writer.add_scalar('validation/val_average_loss', losses.avg, epoch)
    writer.add_scalar('validation/val_class_loss', class_losses.avg, epoch)
    writer.add_scalar('validation/val_recon_loss_nat', recon_losses_nat.avg, epoch)
    writer.add_scalar('validation/val_KLD', kld_losses.avg, epoch)

    if args.autoregression:
        writer.add_scalar('validation/val_recon_loss_bits_per_dim', recon_losses_bits_per_dim.avg, epoch)

    print(' * Validation: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))

    # At the end of training isolated, or at the end of every task visualize the confusion matrix
    if (epoch + 1) % args.epochs == 0 and epoch > 0:
        # visualize the confusion matrix
        visualize_confusion(writer, epoch + 1, confusion.value(), Dataset.class_to_idx, save_path)

        # If we are in a continual learning scenario, also use the confusion matrix to extract base and new precision.
        if args.incremental_data:
            prec1_base = 0.0
            prec1_new = 0.0
            # this has to be + 1 because the number of initial tasks is always one less than the number of classes,
            # i.e. 1 task corresponds to 2 classes, etc.
            for c in range(args.num_base_tasks + 1):
                prec1_base += confusion.value()[c][c]
            prec1_base = prec1_base / (args.num_base_tasks + 1)

            # For the first task "new" metrics are equivalent to "base"
            if (epoch + 1) / args.epochs == 1:
                prec1_new = prec1_base
                recon_losses_new_nat.avg = recon_losses_base_nat.avg
                if args.autoregression:
                    recon_losses_new_bits_per_dim.avg = recon_losses_base_bits_per_dim.avg
            else:
                for c in range(args.num_increment_tasks):
                    prec1_new += confusion.value()[-c-1][-c-1]
                prec1_new = prec1_new / args.num_increment_tasks

            # Add the continual learning metrics to TensorBoard
            writer.add_scalar('validation/base_precision@1', prec1_base, len(model.module.seen_tasks)-1)
            writer.add_scalar('validation/new_precision@1', prec1_new, len(model.module.seen_tasks)-1)
            writer.add_scalar('validation/base_rec_loss_nats', recon_losses_base_nat.avg * args.patch_size *
                              args.patch_size * model.module.num_colors, len(model.module.seen_tasks) - 1)
            writer.add_scalar('validation/new_rec_loss_nats', recon_losses_new_nat.avg * args.patch_size *
                              args.patch_size * model.module.num_colors, len(model.module.seen_tasks) - 1)

            if args.autoregression:
                writer.add_scalar('validation/base_rec_loss_bits_per_dim',
                                  recon_losses_base_bits_per_dim.avg, len(model.module.seen_tasks) - 1)
                writer.add_scalar('validation/new_rec_loss_bits_per_dim',
                                  recon_losses_new_bits_per_dim.avg, len(model.module.seen_tasks) - 1)

            print(' * Incremental validation: Base Prec@1 {prec1_base:.3f} New Prec@1 {prec1_new:.3f}\t'
                  'Base Recon Loss {recon_losses_base_nat.avg:.3f} New Recon Loss {recon_losses_new_nat.avg:.3f}'
                  .format(prec1_base=100*prec1_base, prec1_new=100*prec1_new,
                          recon_losses_base_nat=recon_losses_base_nat, recon_losses_new_nat=recon_losses_new_nat))

    return top1.avg, losses.avg
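The base and new precision values above are read off the diagonal of the confusion matrix accumulated by ConfusionMeter. Below is a minimal accumulator with the add/value interface used here; this is a sketch under that assumption (the original likely comes from torchnet.meter), not the actual implementation.

import torch

class ConfusionMeter:
    """Accumulates a (num_classes x num_classes) confusion matrix, rows indexed by target."""

    def __init__(self, num_classes, normalized=False):
        self.num_classes = num_classes
        self.normalized = normalized
        self.conf = torch.zeros(num_classes, num_classes)

    def add(self, output, target):
        # output: (batch, num_classes) scores, target: (batch,) integer labels
        pred = output.argmax(dim=1).view(-1).cpu()
        target = target.view(-1).cpu()
        for p, t in zip(pred, target):
            self.conf[t.long(), p.long()] += 1

    def value(self):
        if self.normalized:
            # each row sums to one, so the diagonal holds per-class accuracy
            row_sums = self.conf.sum(dim=1, keepdim=True).clamp(min=1)
            return (self.conf / row_sums).numpy()
        return self.conf.numpy()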
Example #5
def train_var(Dataset, model, criterion, epoch, optimizer, writer, device,
              args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int), denoising_noise_value (float) and var_beta (float).
    """

    # Create instances to accumulate losses etc.
    cl_losses = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    # train
    for i, (inp, target) in enumerate(Dataset.train_loader):
        inp = inp.to(device)
        target = target.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute model forward
        output_samples, mu, std = model(inp)

        # calculate loss
        cl_loss, kld_loss = criterion(output_samples, target, mu, std, device)

        # add the individual loss components together and weight the KL term.
        loss = cl_loss + args.var_beta * kld_loss

        # take mean to compute accuracy. Note if variational samples are 1 this only gets rid of a dummy dimension.
        output = torch.mean(output_samples, dim=0)

        # record precision/accuracy and losses
        prec1 = accuracy(output, target)[0]
        top1.update(prec1.item(), inp.size(0))

        losses.update((cl_loss + kld_loss).item(), inp.size(0))
        cl_losses.update(cl_loss.item(), inp.size(0))
        kld_losses.update(kld_loss.item(), inp.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print progress
        if i % args.print_freq == 0:
            print('Training: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                      epoch + 1,
                      i,
                      len(Dataset.train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      cl_loss=cl_losses,
                      top1=top1,
                      KLD_loss=kld_losses))

    # TensorBoard summary logging
    writer.add_scalar('training/train_precision@1', top1.avg, epoch)
    writer.add_scalar('training/train_class_loss', cl_losses.avg, epoch)
    writer.add_scalar('training/train_average_loss', losses.avg, epoch)
    writer.add_scalar('training/train_KLD', kld_losses.avg, epoch)

    print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(
        loss=losses, top1=top1))
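The criterion used by train_var is assumed to return a classification term averaged over the variational samples plus an analytic Gaussian KL divergence. A hedged sketch of such a loss with the (output_samples, target, mu, std, device) signature seen above; standard formulas, not necessarily the repository's exact normalization.

import torch
import torch.nn.functional as F

def variational_classification_loss(output_samples, target, mu, std, device):
    """Cross-entropy averaged over variational samples plus KL(N(mu, std^2) || N(0, 1)).

    output_samples: (num_samples, batch, num_classes) logits
    mu, std: (batch, latent_dim) posterior parameters
    """
    num_samples = output_samples.size(0)

    # average the classification loss over the drawn latent samples
    cl_loss = torch.stack(
        [F.cross_entropy(output_samples[s], target) for s in range(num_samples)]
    ).mean()

    # analytic KL divergence between the posterior and a standard Gaussian prior,
    # averaged over batch and latent dimensions
    kld_loss = (-0.5 * (1 + torch.log(std.pow(2)) - mu.pow(2) - std.pow(2))).mean()

    # device is unused here and only kept to mirror the call signature above
    return cl_loss, kld_loss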
Example #6
def train(Dataset, model, criterion, epoch, optimizer, writer, device, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and log_weights (bool).
    """

    # Create instances to accumulate losses etc.
    losses = AverageMeter()
    class_losses = AverageMeter()
    inos_losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    # train
    for i, (inp, target) in enumerate(Dataset.train_loader):
        inp = inp.to(device)
        target = target.to(device)

        class_target = target[0]

        # measure data loading time
        data_time.update(time.time() - end)

        # compute model forward
        output, score = model(inp)

        # calculate loss
        cl, rl = criterion(output, target, score, device, args)
        loss = cl + rl

        # record precision/accuracy and losses
        prec1 = accuracy(output, class_target)[0]
        top1.update(prec1.item(), inp.size(0))
        class_losses.update(cl.item(), inp.size(0))
        inos_losses.update(rl.item(), inp.size(0))
        losses.update(loss.item(), inp.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print progress
        if i % args.print_freq == 0:
            print('Training: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                      epoch + 1,
                      i,
                      len(Dataset.train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1))

    # TensorBoard summary logging
    writer.add_scalar('train/precision@1', top1.avg, epoch)
    writer.add_scalar('train/average_loss', losses.avg, epoch)
    writer.add_scalar('train/class_loss', class_losses.avg, epoch)
    writer.add_scalar('train/inos_loss', inos_losses.avg, epoch)

    # If the log weights argument is specified also add parameter and gradient histograms to TensorBoard.
    if args.log_weights:
        # Histograms and distributions of network parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram(tag,
                                 value.data.cpu().numpy(),
                                 epoch,
                                 bins="auto")
            # second check required for buffers that appear in the parameters dict but don't receive gradients
            if value.requires_grad and value.grad is not None:
                writer.add_histogram(tag + '/grad',
                                     value.grad.data.cpu().numpy(),
                                     epoch,
                                     bins="auto")

    print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(
        loss=losses, top1=top1))
Example #7
def train(dataset, model, criterion, epoch, optimizer, lr_scheduler, device,
          args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        dataset (torch.utils.data.Dataset): The dataset providing the train_loader
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        lr_scheduler (Training.LearningRateScheduler): class implementing learning rate schedules
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain learning_rate (float), momentum (float),
            weight_decay (float), nesterov momentum (bool), lr_dropstep (int),
            lr_dropfactor (float), print_freq (int) and expand (bool).
    """

    batch_time = AverageMeter()
    data_time = AverageMeter()
    cl_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    for i, (input, target) in enumerate(dataset.train_loader):
        input, target = input.to(device), target.to(device)
        # measure data loading time
        data_time.update(time.time() - end)

        # adjust the learning rate if applicable
        lr_scheduler.adjust_learning_rate(optimizer, i + 1)

        # compute output
        output = model(input)

        # making targets one-hot for using BCEloss
        target_temp = target
        one_hot = torch.zeros(target.size(0), output.size(1)).to(device)
        one_hot.scatter_(1, target.long().view(target.size(0), -1), 1)
        target = one_hot

        # compute loss and accuracy
        loss = criterion(output, target)
        prec1, prec5 = accuracy(output, target_temp, (1, 5))

        # record loss and accuracy
        cl_losses.update(loss.item(), input.size(0))

        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del output, input, target
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                      epoch,
                      i,
                      len(dataset.train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=cl_losses,
                      top1=top1,
                      top5=top5))

    lr_scheduler.scheduler_epoch += 1

    print(' * Train: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))
    print('=' * 80)
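The scatter_ call above turns integer class labels into one-hot vectors so that BCELoss can be applied to the sigmoid outputs. The same pattern in isolation, with toy sizes for illustration:

import torch

target = torch.tensor([2, 0, 3])   # integer class labels for a batch of 3
num_classes = 5

one_hot = torch.zeros(target.size(0), num_classes)
# write a 1 into the column given by each label along dimension 1
one_hot.scatter_(1, target.long().view(target.size(0), -1), 1)

print(one_hot)
# tensor([[0., 0., 1., 0., 0.],
#         [1., 0., 0., 0., 0.],
#         [0., 0., 0., 1., 0.]])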
Example #8
def validate_var(Dataset, model, criterion, epoch, writer, device, args):
    """
    Evaluates/validates the model

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int), epochs (int) and patch_size (int).

    Returns:
        float: top1 precision/accuracy
        float: average loss
    """

    # initialize average meters to accumulate values
    cl_losses = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()
    batch_time = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    # evaluate the entire validation dataset
    with torch.no_grad():
        for i, (inp, target) in enumerate(Dataset.val_loader):
            inp = inp.to(device)
            target = target.to(device)

            # compute output
            output_samples, mu, std = model(inp)

            # compute loss
            cl_loss, kld_loss = criterion(output_samples, target, mu, std,
                                          device)

            # take mean to compute accuracy
            # (does nothing if there isn't more than 1 sample per input other than removing dummy dimension)
            output = torch.mean(output_samples, dim=0)

            # measure and update accuracy
            prec1 = accuracy(output, target)[0]
            top1.update(prec1.item(), inp.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # update the respective loss values. To be consistent with values reported in the literature we scale
            # our normalized losses back to un-normalized values.
            # For the KLD this also means the reported loss is not scaled by beta, to allow for a fair comparison
            # across potential weighting terms.
            cl_losses.update(cl_loss.item() * model.module.num_classes,
                             inp.size(0))
            kld_losses.update(kld_loss.item() * model.module.latent_dim,
                              inp.size(0))
            losses.update((cl_loss + kld_loss).item(), inp.size(0))

            # Print progress
            if i % args.print_freq == 0:
                print('Validate: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                          epoch + 1,
                          i,
                          len(Dataset.val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          cl_loss=cl_losses,
                          top1=top1,
                          KLD_loss=kld_losses))

    # TensorBoard summary logging
    writer.add_scalar('validation/val_precision@1', top1.avg, epoch)
    writer.add_scalar('validation/val_class_loss', cl_losses.avg, epoch)
    writer.add_scalar('validation/val_average_loss', losses.avg, epoch)
    writer.add_scalar('validation/val_KLD', kld_losses.avg, epoch)

    print(' * Validation: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(
        loss=losses, top1=top1))

    return top1.avg, losses.avg
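Assuming these variational train/validate routines, a typical epoch loop wiring them to a TensorBoard writer and keeping the best checkpoint could look like the following. The surrounding objects (Dataset, model, criterion, optimizer, device, args) are hypothetical; only the call signatures follow the examples above.

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/variational_experiment')
best_prec1 = 0.0

for epoch in range(args.epochs):
    # one pass over the training set; metrics are logged to TensorBoard inside
    train_var(Dataset, model, criterion, epoch, optimizer, writer, device, args)

    # one pass over the validation set; returns summary metrics
    prec1, val_loss = validate_var(Dataset, model, criterion, epoch, writer, device, args)

    # keep the weights with the best validation precision so far
    if prec1 > best_prec1:
        best_prec1 = prec1
        torch.save(model.state_dict(), 'best_model.pth')

writer.close()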
Example #9
def train(dataset, gen_model, disc_model, criterion, epoch, gen_optimizer,
          disc_optimizer, lr_scheduler, device, args):
    gen_losses = AverageMeter()
    disc_losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    gen_model.train()
    disc_model.train()

    end = time.time()

    for i, (input, target) in enumerate(dataset.train_loader):
        input, target = input.to(device), target.to(device)
        data_time.update(time.time() - end)

        lr_scheduler.adjust_learning_rate(gen_optimizer, i + 1)
        lr_scheduler.adjust_learning_rate(disc_optimizer, i + 1)

        if args.no_gpus > 1:
            input_size = gen_model.module.input_size
        else:
            input_size = gen_model.input_size

        # set inputs and targets
        z = torch.randn((input.size(0), input_size)).to(device)
        y_real, y_fake = torch.ones(input.size(0), 1).to(device), torch.zeros(
            input.size(0), 1).to(device)

        disc_real = disc_model(input)
        gen_out = gen_model(z)
        disc_fake = disc_model(gen_out)

        disc_real_loss = criterion(disc_real, y_real)
        disc_fake_loss = criterion(disc_fake, y_fake)
        disc_loss = disc_real_loss + disc_fake_loss

        disc_losses.update(disc_loss.item(), input.size(0))

        disc_optimizer.zero_grad()
        disc_loss.backward()
        disc_optimizer.step()

        gen_out = gen_model(z)
        disc_fake = disc_model(gen_out)
        gen_loss = criterion(disc_fake, y_real)

        gen_losses.update(gen_loss.item(), input.size(0))

        gen_optimizer.zero_grad()
        gen_loss.backward()
        gen_optimizer.step()
        del input, target, z, y_real, y_fake, disc_real, gen_out, disc_fake

        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print(
                'Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Generator Loss {gen_losses.val:.4f} ({gen_losses.avg:.4f})\t'
                'Discriminator Loss {disc_losses.val:.3f} ({disc_losses.avg:.3f})\t'
                .format(epoch,
                        i,
                        len(dataset.train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        gen_losses=gen_losses,
                        disc_losses=disc_losses))

    print(' * Train: Generator Loss {gen_losses.avg:.3f} Discriminator Loss {disc_losses.avg:.3f}'\
        .format(gen_losses=gen_losses, disc_losses=disc_losses))
    print('-' * 80)
    return disc_losses.avg, gen_losses.avg
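This training function implements the standard GAN objective with BCELoss and all-ones / all-zeros label targets; the generator is updated with the non-saturating trick of labelling its fakes as real. A minimal sketch of the objects it expects (hypothetical architectures and hyperparameters; only the interfaces matter):

import torch
import torch.nn as nn

latent_dim = 100
image_dim = 28 * 28

# generator maps latent noise z to a flattened image; exposes input_size as used above
gen_model = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(),
                          nn.Linear(256, image_dim), nn.Tanh())
gen_model.input_size = latent_dim

# discriminator maps an image to a single real/fake probability
disc_model = nn.Sequential(nn.Flatten(), nn.Linear(image_dim, 256), nn.LeakyReLU(0.2),
                           nn.Linear(256, 1), nn.Sigmoid())

# BCELoss against y_real / y_fake yields the discriminator and generator losses used above
criterion = nn.BCELoss()
gen_optimizer = torch.optim.Adam(gen_model.parameters(), lr=2e-4, betas=(0.5, 0.999))
disc_optimizer = torch.optim.Adam(disc_model.parameters(), lr=2e-4, betas=(0.5, 0.999))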
Example #10
def validate(dataset, model, criterion, epoch, device, args):
    """
    Evaluates/validates the model

    Parameters:
        dataset (torch.utils.data.TensorDataset): The dataset
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int).

    Returns:
        float: top1 accuracy
    """
    cl_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (input, target) in enumerate(dataset.val_loader):
            input, target = input.to(device), target.to(device)

            # compute output
            output = model(input)

            # make targets one-hot for using BCEloss
            target_temp = target
            one_hot = torch.zeros(target.size(0), output.size(1)).to(device)
            one_hot.scatter_(1, target.long().view(target.size(0), -1), 1)
            target = one_hot

            # compute loss and accuracy
            loss = criterion(output, target)
            prec1, prec5 = accuracy(output, target_temp, (1, 5))

            # measure accuracy and record loss
            cl_losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

    print(
        ' * Validation Task: \nLoss {loss.avg:.5f} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
        .format(loss=cl_losses, top1=top1, top5=top5))
    print('=' * 80)

    return top1.avg
Example #11
def validate(val_loader, model, criterion, epoch, device, args):
    """
    Evaluates/validates the model

    Parameters:
        val_loader (torch.utils.data.DataLoader): The validation or testset dataloader
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int).

    Returns:
        float: top1 accuracy
    """

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (inp, target) in enumerate(val_loader):
            inp, target = inp.to(device), target.to(device)

            # compute output
            output = model(inp)

            # compute loss
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), inp.size(0))
            top1.update(prec1.item(), inp.size(0))
            top5.update(prec5.item(), inp.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Validate: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch + 1,
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))

    print(' * Validation: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))

    return top1.avg
Example #12
def val(val_loader, model, criterion, device, is_val=True):
    """
    validates the model of a net for one epoch on the validation set

    Parameters:
        val_loader (torch.utils.data.DataLoader): data loader for the validation set
        model (lib.Models.network.Net): model of a net that has to be validated
        criterion (torch.nn.BCELoss): loss criterion
        device (torch.device): device name where data is transferred to
        is_val (bool): validation or testing mode

    Returns:
        losses.avg (float): average of the validation losses over the batches
        hard_prec.avg (float): average of the validation hard precision over the batches for all the defects
        soft_prec.avg (float): average of the validation soft precision over the batches for all the defects
        hard_prec_background.avg (float): average of the validation hard/soft precision over the batches for background
        hard_prec_crack.avg (float): average of the validation hard/soft precision over the batches for crack
        hard_prec_spallation.avg (float): average of the validation hard/soft precision over the batches for spallation
        hard_prec_exposed_bars.avg (float): average of the validation hard/soft precision over the batches for exposed
                                            bars
        hard_prec_efflorescence.avg (float): average of the validation hard/soft precision over the batches for
                                             efflorescence
        hard_prec_corrosion_stain.avg (float): average of the validation hard/soft precision over the batches for
                                               corrosion stain
    """
    # performance metrics
    losses = AverageMeter()
    hard_prec = AverageMeter()
    soft_prec = AverageMeter()
    hard_prec_background = AverageMeter()
    hard_prec_crack = AverageMeter()
    hard_prec_spallation = AverageMeter()
    hard_prec_exposed_bars = AverageMeter()
    hard_prec_efflorescence = AverageMeter()
    hard_prec_corrosion_stain = AverageMeter()

    # for computing correct prediction percentage for multi-label samples and single-label samples
    number_multi_label_samples = number_correct_multi_label_predictions = number_single_label_samples =\
        number_correct_single_label_predictions = 0

    # switch to evaluate mode
    model.eval()

    if is_val:
        print('validating')
    else:
        print('testing')

    # disable gradient tracking during evaluation
    with torch.no_grad():
        for i, (input_, target) in enumerate(val_loader):
            if input_.size(0) == 1:
                # hacky way to deal with terminal batch-size of 1
                print('skip last val/test batch of size 1')
                continue
            input_, target = input_.to(device), target.to(device)

            output = model(input_)

            loss = criterion(output, target)

            # update the 'losses' meter
            losses.update(loss.item(), input_.size(0))

            # compute performance measures
            output = output >= 0.5  # binarizing sigmoid output by thresholding with 0.5
            # offset all-zero predictions so they can never exactly match a target row and are therefore
            # counted as incorrect in the single-/multi-label statistics below
            temp_output = output
            temp_output = temp_output.float() + (((torch.sum(
                temp_output.float(), dim=1, keepdim=True) == 0).float()) * 0.5)

            # compute correct prediction percentage for multi-label/single-label samples
            sum_target_along_label_dimension = torch.sum(target,
                                                         dim=1,
                                                         keepdim=True)

            multi_label_samples = (sum_target_along_label_dimension >
                                   1).float()
            target_multi_label_samples = (target.float()) * multi_label_samples
            number_multi_label_samples += torch.sum(multi_label_samples).item()
            number_correct_multi_label_predictions += \
                torch.sum(torch.prod((temp_output.float() == target_multi_label_samples).float(), dim=1)).item()

            single_label_samples = (
                sum_target_along_label_dimension == 1).float()
            target_single_label_samples = (
                target.float()) * single_label_samples
            number_single_label_samples += torch.sum(
                single_label_samples).item()
            number_correct_single_label_predictions += \
                torch.sum(torch.prod((temp_output.float() == target_single_label_samples).float(), dim=1)).item()

            equality_matrix = (output.float() == target).float()
            hard = torch.mean(torch.prod(equality_matrix, dim=1)) * 100.
            soft = torch.mean(equality_matrix) * 100.
            hard_per_defect = torch.mean(equality_matrix, dim=0) * 100.

            # update performance meters
            hard_prec.update(hard.item(), input_.size(0))
            soft_prec.update(soft.item(), input_.size(0))
            hard_prec_background.update(hard_per_defect[0].item(),
                                        input_.size(0))
            hard_prec_crack.update(hard_per_defect[1].item(), input_.size(0))
            hard_prec_spallation.update(hard_per_defect[2].item(),
                                        input_.size(0))
            hard_prec_exposed_bars.update(hard_per_defect[3].item(),
                                          input_.size(0))
            hard_prec_efflorescence.update(hard_per_defect[4].item(),
                                           input_.size(0))
            hard_prec_corrosion_stain.update(hard_per_defect[5].item(),
                                             input_.size(0))

    percentage_single_labels = (100. * number_correct_single_label_predictions
                                ) / number_single_label_samples
    percentage_multi_labels = (100. * number_correct_multi_label_predictions
                               ) / number_multi_label_samples

    if is_val:
        print(
            ' * val: loss {losses.avg:.3f}, hard prec {hard_prec.avg:.3f}, soft prec {soft_prec.avg:.3f},\t'
            '% correct single-label predictions {percentage_single_labels:.3f}, % correct multi-label predictions'
            ' {percentage_multi_labels:.3f},\t'
            'hard prec background {hard_prec_background.avg:.3f}, hard prec crack {hard_prec_crack.avg:.3f},\t'
            'hard prec spallation {hard_prec_spallation.avg:.3f}, '
            'hard prec exposed bars {hard_prec_exposed_bars.avg:.3f},\t'
            'hard prec efflorescence {hard_prec_efflorescence.avg:.3f}, hard prec corrosion stain'
            ' {hard_prec_corrosion_stain.avg:.3f}\t'.format(
                losses=losses,
                hard_prec=hard_prec,
                soft_prec=soft_prec,
                percentage_single_labels=percentage_single_labels,
                percentage_multi_labels=percentage_multi_labels,
                hard_prec_background=hard_prec_background,
                hard_prec_crack=hard_prec_crack,
                hard_prec_spallation=hard_prec_spallation,
                hard_prec_exposed_bars=hard_prec_exposed_bars,
                hard_prec_efflorescence=hard_prec_efflorescence,
                hard_prec_corrosion_stain=hard_prec_corrosion_stain))
        print('*' * 80)
    else:
        print(
            ' * test: loss {losses.avg:.3f}, hard prec {hard_prec.avg:.3f}, soft prec {soft_prec.avg:.3f},\t'
            '% correct single-label predictions {percentage_single_labels:.3f}, % correct multi-label predictions'
            ' {percentage_multi_labels:.3f},\t'
            'hard prec background {hard_prec_background.avg:.3f}, hard prec crack {hard_prec_crack.avg:.3f},\t'
            'hard prec spallation {hard_prec_spallation.avg:.3f}, '
            'hard prec exposed bars {hard_prec_exposed_bars.avg:.3f},\t'
            'hard prec efflorescence {hard_prec_efflorescence.avg:.3f}, hard prec corrosion stain'
            ' {hard_prec_corrosion_stain.avg:.3f}\t'.format(
                losses=losses,
                hard_prec=hard_prec,
                soft_prec=soft_prec,
                percentage_single_labels=percentage_single_labels,
                percentage_multi_labels=percentage_multi_labels,
                hard_prec_background=hard_prec_background,
                hard_prec_crack=hard_prec_crack,
                hard_prec_spallation=hard_prec_spallation,
                hard_prec_exposed_bars=hard_prec_exposed_bars,
                hard_prec_efflorescence=hard_prec_efflorescence,
                hard_prec_corrosion_stain=hard_prec_corrosion_stain))
        print('*' * 80)

    if is_val:
        return losses.avg, hard_prec.avg, soft_prec.avg, hard_prec_background.avg, hard_prec_crack.avg,\
               hard_prec_spallation.avg, hard_prec_exposed_bars.avg, hard_prec_efflorescence.avg,\
               hard_prec_corrosion_stain.avg
    else:
        return None
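The hard and soft precision values above distinguish exact multi-label matches from per-label matches. A toy example makes the difference concrete (illustrative values only):

import torch

# two samples, three labels: thresholded predictions and ground truth
output = torch.tensor([[1., 0., 1.],
                       [0., 1., 0.]])
target = torch.tensor([[1., 0., 0.],
                       [0., 1., 0.]])

equality_matrix = (output == target).float()

# hard precision: a sample only counts if every label is predicted correctly -> 50.0
hard = torch.mean(torch.prod(equality_matrix, dim=1)) * 100.

# soft precision: fraction of individual labels predicted correctly -> ~83.3
soft = torch.mean(equality_matrix) * 100.

print(hard.item(), soft.item())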
Example #13
def train(Dataset, model, criterion, epoch, optimizer, writer, device, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and log_weights (bool).
    """

    # Create instances to accumulate losses etc.
    class_losses = AverageMeter()
    recon_losses = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    # train
    for i, (inp, target) in enumerate(Dataset.train_loader):
        inp = inp.to(device)
        target = target.to(device)

        recon_target = inp
        class_target = target

        # this needs to be below the line where the reconstruction target is set
        # sample and add noise to the input (but not to the target!).
        if args.denoising_noise_value > 0.0:
            noise = torch.randn(
                inp.size()).to(device) * args.denoising_noise_value
            inp = inp + noise

        # measure data loading time
        data_time.update(time.time() - end)

        # compute model forward
        class_samples, recon_samples, mu, std = model(inp)

        # if we have an autoregressive model variant, further calculate the corresponding layers.
        if args.autoregression:
            recon_samples_autoregression = torch.zeros(recon_samples.size(0),
                                                       inp.size(0), 256,
                                                       inp.size(1),
                                                       inp.size(2),
                                                       inp.size(3)).to(device)
            for j in range(model.module.num_samples):
                recon_samples_autoregression[j] = model.module.pixelcnn(
                    recon_target,
                    torch.sigmoid(recon_samples[j])).contiguous()
            recon_samples = recon_samples_autoregression
            # set the target to work with the 256-way Softmax
            recon_target = (recon_target * 255).long()

        # calculate loss
        class_loss, recon_loss, kld_loss = criterion(class_samples,
                                                     class_target,
                                                     recon_samples,
                                                     recon_target, mu, std,
                                                     device, args)

        # add the individual loss components together and weight the KL term.
        loss = class_loss + recon_loss + args.var_beta * kld_loss

        # take mean to compute accuracy. Note if variational samples are 1 this only gets rid of a dummy dimension.
        output = torch.mean(class_samples, dim=0)

        # record precision/accuracy and losses
        prec1 = accuracy(output, target)[0]
        top1.update(prec1.item(), inp.size(0))
        losses.update((class_loss + recon_loss + kld_loss).item(), inp.size(0))
        class_losses.update(class_loss.item(), inp.size(0))
        recon_losses.update(recon_loss.item(), inp.size(0))
        kld_losses.update(kld_loss.item(), inp.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print progress
        if i % args.print_freq == 0:
            print('Training: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                  'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                      epoch + 1,
                      i,
                      len(Dataset.train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      cl_loss=class_losses,
                      top1=top1,
                      recon_loss=recon_losses,
                      KLD_loss=kld_losses))

    # TensorBoard summary logging
    writer.add_scalar('training/train_precision@1', top1.avg, epoch)
    writer.add_scalar('training/train_average_loss', losses.avg, epoch)
    writer.add_scalar('training/train_KLD', kld_losses.avg, epoch)
    writer.add_scalar('training/train_class_loss', class_losses.avg, epoch)
    writer.add_scalar('training/train_recon_loss', recon_losses.avg, epoch)

    # If the log weights argument is specified also add parameter and gradient histograms to TensorBoard.
    if args.log_weights:
        # Histograms and distributions of network parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram(tag,
                                 value.data.cpu().numpy(),
                                 epoch,
                                 bins="auto")
            # second check required for buffers that appear in the parameters dict but don't receive gradients
            if value.requires_grad and value.grad is not None:
                writer.add_histogram(tag + '/grad',
                                     value.grad.data.cpu().numpy(),
                                     epoch,
                                     bins="auto")

    print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(
        loss=losses, top1=top1))
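
Every routine in these examples accumulates statistics through AverageMeter objects, calling update(value, n) per batch and reading back .val / .avg for logging. The class itself is defined elsewhere in the repository; the following is a minimal sketch consistent with how it is used here (an assumption, not the repository's exact implementation):

class AverageMeter:
    """Tracks the most recent value and a running, sample-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # last recorded value
        self.sum = 0.0    # weighted sum of all recorded values
        self.count = 0    # total weight (typically the number of samples)
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
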
def train(dataset, model, criterion, epoch, optimizer, lr_scheduler, device, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        dataset (torch.utils.data.Dataset): The dataset providing the train_loader
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        lr_scheduler (Training.LearningRateScheduler): class implementing learning rate schedules
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain learning_rate (float), momentum (float),
            weight_decay (float), nesterov momentum (bool), lr_dropstep (int),
            lr_dropfactor (float), print_freq (int) and expand (bool).
    """

    batch_time = AverageMeter()
    data_time = AverageMeter()
    elbo_losses = AverageMeter()
    rec_losses = AverageMeter()
    kl_losses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    for i, (input, target) in enumerate(dataset.train_loader):
        # for a VAE the reconstruction target is the input itself, so target is also set from input
        input, target = input.to(device), input.to(device)
        # measure data loading time
        data_time.update(time.time() - end)

        # adjust the learning rate if applicable
        lr_scheduler.adjust_learning_rate(optimizer, i + 1)

        # compute output, mean and std
        output, mean, std = model(input)

        # compute loss 
        kl_loss = -0.5*torch.mean(1+torch.log(1e-8+std**2)-(mean**2)-(std**2))
        rec_loss = criterion(output, target)
        elbo_loss = rec_loss + kl_loss

        # record loss
        rec_losses.update(rec_loss.item(), input.size(0))
        kl_losses.update(kl_loss.item(), input.size(0))
        elbo_losses.update(elbo_loss.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        elbo_loss.backward()
        optimizer.step()
        del output, input, target
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'ELBO Loss {elbo_losses.val:.4f} ({elbo_losses.avg:.4f})\t'
                  'Reconstruction Loss {rec_losses.val:.3f} ({rec_losses.avg:.3f})\t'
                  'KL Divergence {kl_losses.val:.3f} ({kl_losses.avg:.3f})\t'.format(
                   epoch, i, len(dataset.train_loader), batch_time=batch_time,
                   data_time=data_time, elbo_losses=elbo_losses, rec_losses=rec_losses, kl_losses=kl_losses))

    lr_scheduler.scheduler_epoch += 1

    print(' * Train: ELBO Loss {elbo_losses.avg:.3f} Reconstruction Loss {rec_losses.avg:.3f} KL Divergence {kl_losses.avg:.3f}'\
        .format(elbo_losses=elbo_losses, rec_losses=rec_losses, kl_losses=kl_losses))
    print('-' * 80)
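
The kl_loss above is the closed-form KL divergence between the diagonal Gaussian posterior N(mean, std^2) and a standard normal prior; the 1e-8 only guards the logarithm, and taking torch.mean instead of a sum merely rescales the term. In standard notation:

\[
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^{2}) \,\|\, \mathcal{N}(0, I)\big)
  = -\tfrac{1}{2} \sum_{j} \left(1 + \log \sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right)
\]
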
def validate(dataset, model, criterion, epoch, device, args,
             save_path_pictures):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        train_loader (torch.utils.data.DataLoader): The trainset dataloader
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain learning_rate (float), momentum (float),
            weight_decay (float), nesterov momentum (bool), lr_dropstep (int),
            lr_dropfactor (float), print_freq (int) and expand (bool).
    """

    elbo_losses = AverageMeter()
    rec_losses = AverageMeter()
    kl_losses = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for i, (input, target) in enumerate(dataset.val_loader):
        with torch.no_grad():
            input, target = input.to(device), input.to(device)

            # compute output, mean and std
            output, mean, std = model(input)

            # compute loss
            kl_loss = -0.5 * torch.mean(1 + torch.log(1e-8 + std**2) -
                                        (mean**2) - (std**2))
            rec_loss = criterion(output, target)
            elbo_loss = rec_loss + kl_loss

            # record loss
            rec_losses.update(rec_loss.item(), input.size(0))
            kl_losses.update(kl_loss.item(), input.size(0))
            elbo_losses.update(elbo_loss.item(), input.size(0))

            if i % args.print_freq == 0:
                save_image((input.data).view(-1, input.size(1), input.size(2),
                                             input.size(3)),
                           os.path.join(
                               save_path_pictures, 'input_epoch_' +
                               str(epoch) + '_ite_' + str(i + 1) + '.png'))
                save_image(
                    (output.data).view(-1, output.size(1), output.size(2),
                                       output.size(3)),
                    os.path.join(
                        save_path_pictures,
                        'epoch_' + str(epoch) + '_ite_' + str(i + 1) + '.png'))
                if args.no_gpus > 1:
                    sample = torch.randn(input.size(0),
                                         model.module.latent_size).to(device)
                    sample = model.module.sample(sample).cpu()
                else:
                    sample = torch.randn(input.size(0),
                                         model.latent_size).to(device)
                    sample = model.sample(sample).cpu()
                save_image((sample.view(-1, input.size(1), input.size(2),
                                        input.size(3))),
                           os.path.join(
                               save_path_pictures, 'sample_epoch_' +
                               str(epoch) + '_ite_' + str(i + 1) + '.png'))
            del output, input, target

    print(' * Validate: ELBO Loss {elbo_losses.avg:.3f} Reconstruction Loss {rec_losses.avg:.3f} KL Divergence {kl_losses.avg:.3f}'\
        .format(elbo_losses=elbo_losses, rec_losses=rec_losses, kl_losses=kl_losses))
    print('-' * 80)
    return 1. / elbo_losses.avg
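
Both routines assume the model draws its latent code by reparameterized sampling, z = mean + std * eps with eps ~ N(0, I), so that gradients flow through mean and std; model.sample() then only needs to push a prior sample through the decoder. A self-contained sketch of the reparameterization step, independent of the repository's model classes:

import torch

def reparameterize(mean, std):
    # draw eps ~ N(0, I) and shift/scale it so the sample stays differentiable
    # with respect to mean and std
    eps = torch.randn_like(std)
    return mean + eps * std

mu = torch.zeros(4, 16)
sigma = torch.ones(4, 16)
z = reparameterize(mu, sigma)  # latent batch of shape (4, 16)
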
Example #16
def train(Dataset, model, criterion, epoch, optimizer, writer, device,
          save_path, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and log_weights (bool).
    """

    # Create instances to accumulate losses etc.
    class_losses = AverageMeter()
    recon_losses = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    G_losses = AverageMeter()
    gp_losses = AverageMeter()
    D_losses = AverageMeter()
    D_losses_real = AverageMeter()
    D_losses_fake = AverageMeter()
    G_losses_fake = AverageMeter()
    fake_class_losses = AverageMeter()

    top1 = AverageMeter()

    # switch to train mode
    model.train()
    GAN_criterion = torch.nn.BCEWithLogitsLoss()

    end = time.time()

    # train
    for i, (inp, target) in enumerate(Dataset.train_loader):
        inp = inp.to(device)
        target = target.to(device)

        recon_target = inp
        class_target = target

        # this needs to be below the line where the reconstruction target is set
        # sample and add noise to the input (but not to the target!).
        if args.denoising_noise_value > 0.0:
            noise = torch.randn(
                inp.size()).to(device) * args.denoising_noise_value
            inp = inp + noise

        if args.blur:
            inp = blur_data(inp, args.patch_size, device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Conventional GAN training: update the discriminator (D) first, then the generator (G).
        #### D Update####
        class_samples, recon_samples, mu, std = model(inp)
        mu_label = None
        if args.proj_gan:
            # pred_label = torch.argmax(class_samples, dim=-1).squeeze()
            # mu_label = pred_label.to(device)
            mu_label = target

        real_z = model.module.forward_D(recon_target, mu_label)

        D_loss_real = torch.mean(torch.nn.functional.relu(1. - real_z))  #hinge
        # D_loss_real = - torch.mean(real_z)                                  #WGAN-GP
        D_losses_real.update(D_loss_real.item(), 1)

        n, b, c, x, y = recon_samples.shape
        fake_z = model.module.forward_D(
            (recon_samples.view(n * b, c, x, y)).detach(), mu_label)

        D_loss_fake = torch.mean(torch.nn.functional.relu(1. + fake_z))  #hinge
        # D_loss_fake = torch.mean(fake_z)                                #WGAN-GP
        D_losses_fake.update(D_loss_fake.item(), 1)

        GAN_D_loss = (D_loss_real + D_loss_fake)
        D_losses.update(GAN_D_loss.item(), inp.size(0))

        # Compute loss for gradient penalty
        alpha = torch.rand(recon_target.size(0), 1, 1, 1).to(device)
        x_hat = (alpha * recon_target + (1 - alpha) *
                 recon_samples.view(n * b, c, x, y)).requires_grad_(True)
        out_x_hat = model.module.forward_D(x_hat, mu_label)
        D_loss_gp = model.module.discriminator.gradient_penalty(
            out_x_hat, x_hat)
        gp_losses.update(D_loss_gp.item(), inp.size(0))

        GAN_D_loss += args.lambda_gp * D_loss_gp

        # compute gradient and do SGD step
        optimizer['enc'].zero_grad()
        optimizer['dec'].zero_grad()
        optimizer['disc'].zero_grad()
        GAN_D_loss.backward()
        optimizer['disc'].step()

        #### G Update####
        # generator/encoder update; `i % 1 == 0` is always true, so G is updated every
        # iteration (raise the modulus to update G less often than D)
        if i % 1 == 0:
            class_samples, recon_samples, mu, std = model(inp)
            mu_label = None
            if args.proj_gan:
                # pred_label = torch.argmax(class_samples, dim=-1).squeeze()
                # mu_label = pred_label.to(device)
                mu_label = target

            # OCDVAE calculate loss
            class_loss, recon_loss, kld_loss = criterion(
                class_samples, class_target, recon_samples, recon_target, mu,
                std, device, args)
            # add the individual loss components together and weight the KL term.
            loss = class_loss + args.l1_weight * recon_loss + args.var_beta * kld_loss

            output = torch.mean(class_samples, dim=0)
            # record precision/accuracy and losses
            prec1 = accuracy(output, target)[0]
            top1.update(prec1.item(), inp.size(0))

            losses.update(loss.item(), inp.size(0))
            class_losses.update(class_loss.item(), inp.size(0))
            recon_losses.update(recon_loss.item(), inp.size(0))
            kld_losses.update(kld_loss.item(), inp.size(0))

            # re-run the discriminator on the (non-detached) reconstructions for the generator's adversarial loss
            n, b, c, x, y = recon_samples.shape
            fake_z = model.module.forward_D(
                (recon_samples.view(n * b, c, x, y)), mu_label)

            GAN_G_loss = -torch.mean(fake_z)
            G_losses_fake.update(GAN_G_loss.item(), inp.size(0))

            GAN_G_loss = -args.var_gan_weight * torch.mean(fake_z)
            G_losses.update(GAN_G_loss.item(), inp.size(0))

            GAN_G_loss += loss

            optimizer['enc'].zero_grad()
            optimizer['dec'].zero_grad()
            optimizer['disc'].zero_grad()
            GAN_G_loss.backward()
            optimizer['enc'].step()
            optimizer['dec'].step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print progress
        if i % args.print_freq == 0:
            print("OCD_VAELoss: ")
            print('Training: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                  'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                      epoch + 1,
                      i,
                      len(Dataset.train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      cl_loss=class_losses,
                      top1=top1,
                      recon_loss=recon_losses,
                      KLD_loss=kld_losses))
            print("GANLoss: ")
            print('G Loss {G_loss.val:.4f} ({G_loss.avg:.4f})\t'
                  'D Loss {D_loss.val:.4f} ({D_loss.avg:.4f})'.format(
                      G_loss=G_losses, D_loss=D_losses))

        if (i == (len(Dataset.train_loader) -
                  2)) and (epoch % args.visualization_epoch == 0):
            visualize_image_grid(inp, writer, epoch + 1,
                                 'train_input_snapshot', save_path)
            visualize_image_grid(recon_samples.view(n * b, c, x,
                                                    y), writer, epoch + 1,
                                 'train_reconstruction_snapshot', save_path)

    # TensorBoard summary logging
    writer.add_scalar('training/train_precision@1', top1.avg, epoch)
    writer.add_scalar('training/train_average_loss', losses.avg, epoch)
    writer.add_scalar('training/train_KLD', kld_losses.avg, epoch)
    writer.add_scalar('training/train_class_loss', class_losses.avg, epoch)
    writer.add_scalar('training/train_recon_loss', recon_losses.avg, epoch)
    writer.add_scalar('training/train_G_loss', G_losses.avg, epoch)
    writer.add_scalar('training/train_D_loss', D_losses.avg, epoch)
    writer.add_scalar('training/train_D_loss_real', D_losses_real.avg, epoch)
    writer.add_scalar('training/train_D_loss_fake', D_losses_fake.avg, epoch)
    writer.add_scalar('training/train_G_loss_fake', G_losses_fake.avg, epoch)

    # If the log weights argument is specified also add parameter and gradient histograms to TensorBoard.
    if args.log_weights:
        # Histograms and distributions of network parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram(tag,
                                 value.data.cpu().numpy(),
                                 epoch,
                                 bins="auto")
            # second check required for buffers that appear in the parameters dict but don't receive gradients
            if value.requires_grad and value.grad is not None:
                writer.add_histogram(tag + '/grad',
                                     value.grad.data.cpu().numpy(),
                                     epoch,
                                     bins="auto")

    print(
        ' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}\t'
        ' GAN: Generator {G_losses.avg:.5f} Discriminator {D_losses.avg:.4f}'.
        format(loss=losses, top1=top1, G_losses=G_losses, D_losses=D_losses))
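
The discriminator terms above use the hinge formulation (the commented-out lines show the Wasserstein-style alternative): real logits are pushed above +1, fake logits below -1, and the generator simply maximizes the discriminator score on its samples. A self-contained sketch of the three terms, assuming raw (un-squashed) discriminator logits:

import torch
import torch.nn.functional as F

def hinge_d_loss(real_logits, fake_logits):
    # discriminator: penalize real logits below +1 and fake logits above -1
    loss_real = torch.mean(F.relu(1.0 - real_logits))
    loss_fake = torch.mean(F.relu(1.0 + fake_logits))
    return loss_real + loss_fake

def hinge_g_loss(fake_logits):
    # generator: raise the discriminator score on generated samples
    return -torch.mean(fake_logits)
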
Example #17
def train(Dataset, model, criterion, epoch, optimizer, writer, device, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and log_weights (bool).
    """

    # Create instances to accumulate losses etc.
    class_losses = AverageMeter()
    recon_losses = AverageMeter()

    if args.introspection:
        kld_real_losses = AverageMeter()
        kld_fake_losses = AverageMeter()
        kld_rec_losses = AverageMeter()

        criterion_enc = criterion[0]
        criterion_dec = criterion[1]

        optimizer_enc = optimizer[0]
        optimizer_dec = optimizer[1]
    else:
        kld_losses = AverageMeter()

    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    # train
    for i, (inp, target) in enumerate(Dataset.train_loader):
        if args.data_augmentation:
            inp = augmentation.augment_data(inp, args)

        inp = inp.to(device)
        target = target.to(device)

        recon_target = inp
        class_target = target

        # this needs to be below the line where the reconstruction target is set
        # sample and add noise to the input (but not to the target!).
        if args.denoising_noise_value > 0.0:
            noise = torch.randn(inp.size()).to(device) * args.denoising_noise_value
            inp = inp + noise

        if args.blur:
            inp = blur_data(inp, args.patch_size, device)

        # measure data loading time
        data_time.update(time.time() - end)

        if args.introspection:
            # Update encoder
            real_mu, real_std = model.module.encode(inp)
            z = model.module.reparameterize(real_mu, real_std)
            class_output = model.module.classifier(z)
            recon = torch.sigmoid(model.module.decode(z))

            model.eval()
            recon_mu, recon_std = model.module.encode(recon.detach())

            model.train()
            z_p = torch.randn(inp.size(0), model.module.latent_dim).to(model.module.device)
            recon_p = torch.sigmoid(model.module.decode(z_p))

            model.eval()
            mu_p, std_p = model.module.encode(recon_p.detach())

            model.train()

            cl, rl, kld_real, kld_rec, kld_fake = criterion_enc(class_output, class_target, recon, recon_target,
                                                                real_mu, real_std, recon_mu, recon_std, mu_p, std_p,
                                                                device, args)

            kld_real_losses.update(kld_real.item(), inp.size(0))

            alpha = 1
            if not args.gray_scale:
                alpha = 3
            loss_encoder = cl + alpha * rl + args.var_beta * (kld_real + 0.5 * (kld_rec + kld_fake) * args.gamma)

            optimizer_enc.zero_grad()
            loss_encoder.backward()
            optimizer_enc.step()

            # update decoder
            recon = torch.sigmoid(model.module.decode(z.detach()))
            model.eval()
            recon_mu, recon_std = model.module.encode(recon.detach())

            model.train()
            recon_p = torch.sigmoid(model.module.decode(z_p))

            model.eval()
            fake_mu, fake_std = model.module.encode(recon_p.detach())

            model.train()
            rl, kld_rec, kld_fake = criterion_dec(recon, recon_target, recon_mu, recon_std, fake_mu, fake_std, args)
            loss_decoder = 0.5 * (kld_rec + kld_fake) * args.gamma + rl * alpha

            optimizer_dec.zero_grad()
            loss_decoder.backward()
            optimizer_dec.step()

            losses.update((loss_encoder + loss_decoder).item(), inp.size(0))
            class_losses.update(cl.item(), inp.size(0))
            recon_losses.update(rl.item(), inp.size(0))
            kld_rec_losses.update(kld_rec.item(), inp.size(0))
            kld_fake_losses.update(kld_fake.item(), inp.size(0))

            # record precision/accuracy and losses
            prec1 = accuracy(class_output, target)[0]
            top1.update(prec1.item(), inp.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print progress
            if i % args.print_freq == 0:
                print('Training: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                      'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                    epoch + 1, i, len(Dataset.train_loader), batch_time=batch_time,
                    data_time=data_time, loss=losses, cl_loss=class_losses, top1=top1,
                    recon_loss=recon_losses, KLD_loss=kld_real_losses))
        else:
            # compute model forward
            class_samples, recon_samples, mu, std = model(inp)

            # if we have an autoregressive model variant, further calculate the corresponding layers.
            if args.autoregression:
                recon_samples_autoregression = torch.zeros(recon_samples.size(0), inp.size(0), 256, inp.size(1),
                                                           inp.size(2), inp.size(3)).to(device)
                for j in range(model.module.num_samples):
                    recon_samples_autoregression[j] = model.module.pixelcnn(recon_target,
                                                                            torch.sigmoid(recon_samples[j])).contiguous()
                recon_samples = recon_samples_autoregression
                # set the target to work with the 256-way Softmax
                recon_target = (recon_target * 255).long()

            # calculate loss
            class_loss, recon_loss, kld_loss = criterion(class_samples, class_target, recon_samples, recon_target, mu, std,
                                                         device, args)

            # add the individual loss components together and weight the KL term.
            loss = class_loss + recon_loss + args.var_beta * kld_loss

            # take mean to compute accuracy. Note if variational samples are 1 this only gets rid of a dummy dimension.
            class_output = torch.mean(class_samples, dim=0)

            # record precision/accuracy and losses
            losses.update((class_loss + recon_loss + kld_loss).item(), inp.size(0))
            class_losses.update(class_loss.item(), inp.size(0))
            recon_losses.update(recon_loss.item(), inp.size(0))
            kld_losses.update(kld_loss.item(), inp.size(0))

            prec1 = accuracy(class_output, target)[0]
            top1.update(prec1.item(), inp.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print progress
            if i % args.print_freq == 0:
                print('Training: [{0}][{1}/{2}]\t' 
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                      'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                       epoch+1, i, len(Dataset.train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses, cl_loss=class_losses, top1=top1,
                       recon_loss=recon_losses, KLD_loss=kld_losses))

    # TensorBoard summary logging
    if args.introspection:
        writer.add_scalar('training/train_precision@1', top1.avg, epoch)
        writer.add_scalar('training/train_average_loss', losses.avg, epoch)
        writer.add_scalar('training/train_KLD_real', kld_real_losses.avg, epoch)
        writer.add_scalar('training/train_class_loss', class_losses.avg, epoch)
        writer.add_scalar('training/train_KLD_rec', kld_rec_losses.avg, epoch)
        writer.add_scalar('training/train_KLD_fake', kld_fake_losses.avg, epoch)
    else:
        writer.add_scalar('training/train_precision@1', top1.avg, epoch)
        writer.add_scalar('training/train_average_loss', losses.avg, epoch)
        writer.add_scalar('training/train_KLD', kld_losses.avg, epoch)
        writer.add_scalar('training/train_class_loss', class_losses.avg, epoch)
        writer.add_scalar('training/train_recon_loss', recon_losses.avg, epoch)

    print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))
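
In the introspection branch above, criterion and optimizer arrive as (encoder, decoder) pairs and the two halves are stepped alternately, with reconstructions re-encoded under model.eval() in between. The following is a toy, self-contained sketch of the alternating two-optimizer pattern only; the linear modules and squared-error losses are placeholders, not the repository's networks:

import torch
import torch.nn as nn

encoder = nn.Linear(8, 4)
decoder = nn.Linear(4, 8)
optimizer = (torch.optim.Adam(encoder.parameters(), lr=1e-3),
             torch.optim.Adam(decoder.parameters(), lr=1e-3))

x = torch.randn(16, 8)

# encoder step: the decoder participates in the forward pass, but only encoder weights move
z = encoder(x)
loss_encoder = ((decoder(z) - x) ** 2).mean()
optimizer[0].zero_grad()
loss_encoder.backward()
optimizer[0].step()

# decoder step: decode from a detached latent so gradients stop at z
recon = decoder(z.detach())
loss_decoder = ((recon - x) ** 2).mean()
optimizer[1].zero_grad()
loss_decoder.backward()
optimizer[1].step()
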
def train(train_loader, model, criterion, epoch, optimizer, lr_scheduler,
          device, batch_split_size, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        train_loader (torch.utils.data.DataLoader): The trainset dataloader
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        lr_scheduler (Training.LearningRateScheduler): class implementing learning rate schedules
        device (str): device name where data is transferred to
        batch_split_size (int): size of smaller split of batch to
            calculate sequentially if too little memory is available
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and batch_size (int).
    """

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    factor = args.batch_size // batch_split_size
    last_batch = int(
        math.ceil(len(train_loader.dataset) / float(batch_split_size)))
    optimizer.zero_grad()

    for i, (inp, target) in enumerate(train_loader):
        inp, target = inp.to(device), target.to(device)
        # measure data loading time
        data_time.update(time.time() - end)

        # adjust the learning rate
        if i % factor == 0:
            lr_scheduler.adjust_learning_rate(optimizer, i // factor + 1)

        # compute output
        output = model(inp)

        loss = criterion(output, target) * inp.size(0) / float(args.batch_size)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item() * float(args.batch_size) / inp.size(0),
                      inp.size(0))
        top1.update(prec1.item(), inp.size(0))
        top5.update(prec5.item(), inp.size(0))

        # compute gradient and do SGD step
        loss.backward()

        if (i + 1) % factor == 0 or i == (last_batch - 1):
            optimizer.step()
            optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % (args.print_freq * factor) == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch + 1,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    print(' * Train: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))

    return top1.avg
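
This variant emulates a large batch on limited memory: each small split's loss is rescaled by inp.size(0) / batch_size, gradients are accumulated over factor = batch_size // batch_split_size backward passes, and the optimizer only steps once per accumulated batch. A minimal, self-contained sketch of the same accumulation pattern (the toy model and counts are illustrative only):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accumulation_steps = 4  # number of small splits per optimizer step
optimizer.zero_grad()

for step in range(12):
    inp = torch.randn(8, 10)            # one small split of the full batch
    target = torch.randint(0, 2, (8,))
    # rescale so the accumulated gradient equals a full-batch average
    loss = criterion(model(inp), target) / accumulation_steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
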
Example #19
def train(Dataset, model, criterion, epoch, iteration, optimizer, writer,
          device, args, save_path):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        iteration (list): mutable single-element iteration counter, incremented once per batch
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and log_weights (bool).
        save_path (str): path to save visualization snapshots to
    """

    # Create instances to accumulate losses etc.
    class_losses = AverageMeter()
    recon_losses = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()
    lwf_losses = AverageMeter()
    si_losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    top1 = AverageMeter()

    # switch to train mode
    model.train()

    if args.use_si and not args.is_multiheaded:
        if model.module.temp_classifier_weights is not None:
            # load unconsolidated classifier weights
            un_consolidate_classifier(model.module)
            #print("SI: loaded unconsolidated classifier weights")
            #print("requires grad check: ", model.module.classifier[-1].weight)

    end = time.time()

    # train
    if args.is_multiheaded and args.train_incremental_upper_bound:
        for trainset_index in range(len(Dataset.mh_trainsets)):
            # get a temporary trainset loader according to trainset_index
            train_loader = torch.utils.data.DataLoader(
                Dataset.mh_trainsets[trainset_index],
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=torch.cuda.is_available(),
                drop_last=True)

            for i, (inp, target) in enumerate(train_loader):
                if args.is_multiheaded and not args.incremental_instance:
                    # multiheaded incremental classes: move targets to head space
                    target = target.clone()
                    for t in range(target.size(0)):
                        target[t] = Dataset.maps_target_head[trainset_index][
                            target.numpy()[t]]

                # move data to device
                inp = inp.to(device)
                target = target.to(device)

                if epoch % args.epochs == 0 and i == 0:
                    visualize_image_grid(inp, writer, epoch + 1,
                                         'train_inp_snapshot', save_path)

                recon_target = inp
                class_target = target

                # this needs to be below the line where the reconstruction target is set
                # sample and add noise to the input (but not to the target!).
                if args.denoising_noise_value > 0.0 and not args.no_vae:
                    noise = torch.randn(
                        inp.size()).to(device) * args.denoising_noise_value
                    inp = inp + noise

                # measure data loading time
                data_time.update(time.time() - end)

                # compute model forward
                class_samples, recon_samples, mu, std = model(inp)

                # if we have an autoregressive model variant, further calculate the corresponding layers.
                if args.autoregression:
                    recon_samples_autoregression = torch.zeros(
                        recon_samples.size(0), inp.size(0), 256, inp.size(1),
                        inp.size(2), inp.size(3)).to(device)
                    for j in range(model.module.num_samples):
                        recon_samples_autoregression[
                            j] = model.module.pixelcnn(
                                recon_target,
                                torch.sigmoid(recon_samples[j])).contiguous()
                    recon_samples = recon_samples_autoregression
                    # set the target to work with the 256-way Softmax
                    recon_target = (recon_target * 255).long()

                # compute the loss for the respective head
                if not args.incremental_instance:
                    head_start = (trainset_index) * args.num_increment_tasks
                    head_end = (trainset_index + 1) * args.num_increment_tasks
                else:
                    head_start = (trainset_index) * Dataset.num_classes
                    head_end = (trainset_index + 1) * Dataset.num_classes

                if args.is_multiheaded:
                    # head_start/head_end already cover both the incremental class and
                    # incremental instance cases, so a single call suffices
                    class_loss, recon_loss, kld_loss = criterion(
                        class_samples[:, :, head_start:head_end], class_target,
                        recon_samples, recon_target, mu, std, device, args)
                else:
                    if not args.is_segmentation:
                        class_loss, recon_loss, kld_loss = criterion(
                            class_samples, class_target, recon_samples,
                            recon_target, mu, std, device, args)
                    else:
                        class_loss, recon_loss, kld_loss = criterion(
                            class_samples,
                            class_target,
                            recon_samples,
                            recon_target,
                            mu,
                            std,
                            device,
                            args,
                            weight=Dataset.class_pixel_weight)

                # add the individual loss components together and weight the KL term.
                if args.no_vae:
                    loss = class_loss
                else:
                    loss = class_loss + recon_loss + args.var_beta * kld_loss

                # take mean to compute accuracy. Note if variational samples are 1 this only gets rid of a dummy dimension.
                if args.is_multiheaded:
                    output = torch.mean(class_samples[:, :, head_start:head_end],
                                        dim=0)
                else:
                    output = torch.mean(class_samples, dim=0)

                # record precision/accuracy and losses
                prec1 = accuracy(output, target)[0]
                top1.update(prec1.item(), inp.size(0))
                if args.no_vae:
                    losses.update(class_loss.item(), inp.size(0))
                    class_losses.update(class_loss.item(), inp.size(0))
                else:
                    losses.update((class_loss + recon_loss + kld_loss).item(),
                                  inp.size(0))
                    class_losses.update(class_loss.item(), inp.size(0))
                    recon_losses.update(recon_loss.item(), inp.size(0))
                    kld_losses.update(kld_loss.item(), inp.size(0))

                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # print progress
                if i % args.print_freq == 0:
                    if args.use_lwf and model.module.prev_model:
                        print(
                            'Training: [{0}][{1}/{2}]\t'
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                            'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                            'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                            'LwF Loss {lwf_loss:.4f}\t'
                            'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                            'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                            'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.
                            format(epoch,
                                   i,
                                   len(Dataset.train_loader),
                                   batch_time=batch_time,
                                   data_time=data_time,
                                   loss=losses,
                                   cl_loss=class_losses,
                                   top1=top1,
                                   recon_loss=recon_losses,
                                   KLD_loss=kld_losses,
                                   lwf_loss=cl_lwf.item()))
                    elif args.use_si and model.module.si_storage.is_initialized:
                        print(
                            'Training: [{0}][{1}/{2}]\t'
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                            'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                            'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                            'SI Loss {si_loss:.4f}\t'
                            'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                            'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                            'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.
                            format(epoch,
                                   i,
                                   len(Dataset.train_loader),
                                   batch_time=batch_time,
                                   data_time=data_time,
                                   loss=losses,
                                   cl_loss=class_losses,
                                   top1=top1,
                                   recon_loss=recon_losses,
                                   KLD_loss=kld_losses,
                                   si_loss=loss_si.item()))

                    else:
                        print(
                            'Training: [{0}][{1}/{2}]\t'
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                            'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                            'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                            'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                            'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                            'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.
                            format(epoch,
                                   i,
                                   len(Dataset.train_loader),
                                   batch_time=batch_time,
                                   data_time=data_time,
                                   loss=losses,
                                   cl_loss=class_losses,
                                   top1=top1,
                                   recon_loss=recon_losses,
                                   KLD_loss=kld_losses))

                # increase iteration
                iteration[0] += 1

    else:
        for i, (inp, target) in enumerate(Dataset.train_loader):
            if args.is_multiheaded and not args.incremental_instance:
                # multiheaded incremental classes: move targets to head space
                target = target.clone()
                for t in range(target.size(0)):
                    target[t] = Dataset.maps_target_head[-1][target.numpy()[t]]

            # move data to device
            inp = inp.to(device)
            target = target.to(device)

            #print("inp:", inp.shape)
            #print("target:", target.shape)

            if epoch % args.epochs == 0 and i == 0:
                visualize_image_grid(inp, writer, epoch + 1,
                                     'train_inp_snapshot', save_path)

            if not args.is_segmentation:
                recon_target = inp
            else:
                # Split target to one hot encoding
                target_one_hot = to_one_hot(target.clone(),
                                            model.module.num_classes)
                # concat input and one_hot_target
                recon_target = torch.cat([inp, target_one_hot], dim=1)
            class_target = target

            # this needs to be below the line where the reconstruction target is set
            # sample and add noise to the input (but not to the target!).
            if args.denoising_noise_value > 0.0 and not args.no_vae:
                noise = torch.randn(
                    inp.size()).to(device) * args.denoising_noise_value
                inp = inp + noise

            # measure data loading time
            data_time.update(time.time() - end)

            # compute model forward
            class_samples, recon_samples, mu, std = model(inp)

            # if we have an autoregressive model variant, further calculate the corresponding layers.
            if args.autoregression:
                recon_samples_autoregression = torch.zeros(
                    recon_samples.size(0), inp.size(0), 256, inp.size(1),
                    inp.size(2), inp.size(3)).to(device)
                for j in range(model.module.num_samples):
                    recon_samples_autoregression[j] = model.module.pixelcnn(
                        recon_target,
                        torch.sigmoid(recon_samples[j])).contiguous()
                recon_samples = recon_samples_autoregression
                # set the target to work with the 256-way Softmax
                recon_target = (recon_target * 255).long()

            if args.is_multiheaded:
                if not args.incremental_instance:
                    class_loss, recon_loss, kld_loss = criterion(
                        class_samples[:, :, -args.num_increment_tasks:],
                        class_target, recon_samples, recon_target, mu, std,
                        device, args)
                else:
                    class_loss, recon_loss, kld_loss = criterion(
                        class_samples[:, :,
                                      -Dataset.num_classes:], class_target,
                        recon_samples, recon_target, mu, std, device, args)
            else:
                if not args.is_segmentation:
                    class_loss, recon_loss, kld_loss = criterion(
                        class_samples, class_target, recon_samples,
                        recon_target, mu, std, device, args)
                else:
                    class_loss, recon_loss, kld_loss = criterion(
                        class_samples,
                        class_target,
                        recon_samples,
                        recon_target,
                        mu,
                        std,
                        device,
                        args,
                        weight=Dataset.class_pixel_weight)

            # add the individual loss components together and weight the KL term.
            if args.no_vae:
                loss = class_loss
            else:
                loss = class_loss + recon_loss + args.var_beta * kld_loss

            # calculate lwf loss (if there is a previous model stored)
            if args.use_lwf and model.module.prev_model:
                # get prediction from previous model
                with torch.no_grad():
                    prev_pred_class_samples, _, _, _ = model.module.prev_model(
                        inp)
                prev_cl_losses = torch.zeros(
                    prev_pred_class_samples.size(0)).to(device)

                # loop through each sample for each input and calculate the correspond loss. Normalize the losses.
                for s in range(prev_pred_class_samples.size(0)):
                    if args.is_multiheaded:
                        if not args.incremental_instance:
                            prev_cl_losses[s] = loss_fn_kd_multihead(
                                class_samples[s]
                                [:, :-args.num_increment_tasks],
                                prev_pred_class_samples[s],
                                task_sizes=args.num_increment_tasks)
                        else:
                            prev_cl_losses[s] = loss_fn_kd_multihead(
                                class_samples[s][:, :-Dataset.num_classes],
                                prev_pred_class_samples[s],
                                task_sizes=Dataset.num_classes)
                    else:
                        if not args.is_segmentation:
                            prev_cl_losses[s] = loss_fn_kd(
                                class_samples[s], prev_pred_class_samples[s]
                            )  #/ torch.numel(target)
                        else:
                            prev_cl_losses[s] = loss_fn_kd_2d(
                                class_samples[s], prev_pred_class_samples[s]
                            )  #/ torch.numel(target)
                # average the loss over all samples per input
                cl_lwf = torch.mean(prev_cl_losses, dim=0)
                # add lwf loss to overall loss
                loss += args.lmda * cl_lwf
                # record lwf losses
                lwf_losses.update(cl_lwf.item(), inp.size(0))

            # calculate SI loss (if SI is initialized)
            if args.use_si and model.module.si_storage.is_initialized:
                if not args.is_segmentation:
                    loss_si = args.lmda * (
                        SI.surrogate_loss(model.module.encoder,
                                          model.module.si_storage) +
                        SI.surrogate_loss(model.module.latent_mu,
                                          model.module.si_storage_mu) +
                        SI.surrogate_loss(model.module.latent_std,
                                          model.module.si_storage_std))
                else:
                    loss_si = args.lmda * (
                        SI.surrogate_loss(model.module.encoder,
                                          model.module.si_storage) +
                        SI.surrogate_loss(model.module.bottleneck,
                                          model.module.si_storage_btn) +
                        SI.surrogate_loss(model.module.decoder,
                                          model.module.si_storage_dec))
                    #print("SI_Loss:", loss_si)

                loss += loss_si
                si_losses.update(loss_si.item(), inp.size(0))

            # take mean to compute accuracy. Note if variational samples are 1 this only gets rid of a dummy dimension.
            if args.is_multiheaded:
                if not args.incremental_instance:
                    output = torch.mean(
                        class_samples[:, :, -args.num_increment_tasks:], dim=0)
                else:
                    output = torch.mean(class_samples[:, :,
                                                      -Dataset.num_classes:],
                                        dim=0)
            else:
                output = torch.mean(class_samples, dim=0)

            # record precision/accuracy and losses
            if not args.is_segmentation:
                prec1 = accuracy(output, target)[0]
                top1.update(prec1.item(), inp.size(0))
            else:
                ious_cc = iou_class_condtitional(pred=output.clone(),
                                                 target=target.clone())
                #print("iou", ious_cc)
                prec1 = iou_to_accuracy(ious_cc)
                #print("prec1", prec1)
                top1.update(prec1.item(), inp.size(0))

            if args.no_vae:
                losses.update(class_loss.item(), inp.size(0))
                class_losses.update(class_loss.item(), inp.size(0))
            else:
                losses.update((class_loss + recon_loss + kld_loss).item(),
                              inp.size(0))
                class_losses.update(class_loss.item(), inp.size(0))
                recon_losses.update(recon_loss.item(), inp.size(0))
                kld_losses.update(kld_loss.item(), inp.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # SI: update running si paramters
            if args.use_si:
                if not args.is_segmentation:
                    SI.update_si_parameters(model.module.encoder,
                                            model.module.si_storage)
                    SI.update_si_parameters(model.module.latent_mu,
                                            model.module.si_storage_mu)
                    SI.update_si_parameters(model.module.latent_std,
                                            model.module.si_storage_std)
                else:
                    SI.update_si_parameters(model.module.encoder,
                                            model.module.si_storage)
                    SI.update_si_parameters(model.module.bottleneck,
                                            model.module.si_storage_btn)
                    SI.update_si_parameters(model.module.decoder,
                                            model.module.si_storage_dec)
                #print("SI: Updated running parameters")

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print progress
            if i % args.print_freq == 0:
                if args.use_lwf and model.module.prev_model:
                    print(
                        'Training: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                        'LwF Loss {lwf_loss:.4f}\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                        'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                            epoch,
                            i,
                            len(Dataset.train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            cl_loss=class_losses,
                            top1=top1,
                            recon_loss=recon_losses,
                            KLD_loss=kld_losses,
                            lwf_loss=cl_lwf.item()))
                elif args.use_si and model.module.si_storage.is_initialized:
                    print(
                        'Training: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                        'SI Loss {si_loss:.4f}\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                        'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                            epoch,
                            i,
                            len(Dataset.train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            cl_loss=class_losses,
                            top1=top1,
                            recon_loss=recon_losses,
                            KLD_loss=kld_losses,
                            si_loss=loss_si.item()))
                else:
                    print(
                        'Training: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                        'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                            epoch,
                            i,
                            len(Dataset.train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            cl_loss=class_losses,
                            top1=top1,
                            recon_loss=recon_losses,
                            KLD_loss=kld_losses))

            # increment the global iteration counter
            iteration[0] += 1

    # TensorBoard summary logging
    writer.add_scalar('training/train_precision@1', top1.avg, epoch)
    writer.add_scalar('training/train_average_loss', losses.avg, epoch)
    writer.add_scalar('training/train_KLD', kld_losses.avg, epoch)
    writer.add_scalar('training/train_class_loss', class_losses.avg, epoch)
    writer.add_scalar('training/train_recon_loss', recon_losses.avg, epoch)

    writer.add_scalar('training/train_precision_itr@1', top1.avg, iteration[0])
    writer.add_scalar('training/train_average_loss_itr', losses.avg,
                      iteration[0])
    writer.add_scalar('training/train_KLD_itr', kld_losses.avg, iteration[0])
    writer.add_scalar('training/train_class_loss_itr', class_losses.avg,
                      iteration[0])
    writer.add_scalar('training/train_recon_loss_itr', recon_losses.avg,
                      iteration[0])

    if args.use_lwf:
        writer.add_scalar('training/train_lwf_loss', lwf_losses.avg,
                          iteration[0])
    if args.use_si:
        writer.add_scalar('training/train_si_loss', si_losses.avg,
                          iteration[0])

    # If the log_weights argument is specified, also add parameter and gradient histograms to TensorBoard.
    if args.log_weights:
        # Histograms and distributions of network parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram(tag,
                                 value.data.cpu().numpy(),
                                 epoch,
                                 bins="auto")
            # the gradient check is required because some entries in the parameter dict
            # (e.g. buffers or frozen parameters) never receive gradients
            if value.requires_grad and value.grad is not None:
                writer.add_histogram(tag + '/grad',
                                     value.grad.data.cpu().numpy(),
                                     epoch,
                                     bins="auto")

    print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(
        loss=losses, top1=top1))
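    # --- Note: AverageMeter, used throughout this example for the .val / .avg running
    # statistics, is not defined in this snippet. A minimal sketch consistent with the
    # update(value, n) calls above (an assumption, not the authors' exact implementation):
    #
    # class AverageMeter(object):
    #     """Keeps a running sum and count to expose the latest value and the mean."""
    #     def __init__(self):
    #         self.reset()
    #     def reset(self):
    #         self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0
    #     def update(self, val, n=1):
    #         self.val = val            # latest recorded value
    #         self.sum += val * n       # 'n' weights the value, e.g. by the batch size inp.size(0)
    #         self.count += n
    #         self.avg = self.sum / self.count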