import math
import os
import time

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

# Repository-internal helpers such as AverageMeter, accuracy, ConfusionMeter, blur_data and the
# visualization utilities (visualize_image_grid, visualize_confusion) are assumed to be imported
# from the project's lib modules.


def validate(dataset, gen_model, disc_model, criterion, epoch, device, args, save_path_pictures):
    """
    Evaluates/validates the generator and discriminator and saves sample images.

    Parameters:
        dataset (torch.utils.data.Dataset): The dataset
        gen_model (torch.nn.module): Generator to be evaluated
        disc_model (torch.nn.module): Discriminator to be evaluated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and no_gpus (int).
        save_path_pictures (str): path to save the sample images to

    Returns:
        float: average discriminator loss
        float: average generator loss
    """

    batch_time = AverageMeter()
    losses = AverageMeter()
    gen_losses = AverageMeter()
    disc_losses = AverageMeter()

    # switch to evaluate mode
    gen_model.eval()
    disc_model.eval()

    end = time.time()

    for i, (input, target) in enumerate(dataset.val_loader):
        with torch.no_grad():
            input, target = input.to(device), target.to(device)

            if args.no_gpus > 1:
                input_size = gen_model.module.input_size
            else:
                input_size = gen_model.input_size

            # set inputs and targets
            z = torch.randn((input.size(0), input_size)).to(device)
            y_real = torch.ones(input.size(0), 1).to(device)
            y_fake = torch.zeros(input.size(0), 1).to(device)

            # discriminator loss on real and generated images
            disc_real = disc_model(input)
            gen_out = gen_model(z)
            disc_fake = disc_model(gen_out)

            disc_real_loss = criterion(disc_real, y_real)
            disc_fake_loss = criterion(disc_fake, y_fake)
            disc_loss = disc_real_loss + disc_fake_loss
            disc_losses.update(disc_loss.item(), input.size(0))

            # generator loss: fool the discriminator into predicting "real"
            gen_out = gen_model(z)
            disc_fake = disc_model(gen_out)
            gen_loss = criterion(disc_fake, y_real)
            gen_losses.update(gen_loss.item(), input.size(0))

            if i % args.print_freq == 0:
                save_image(gen_out.data.view(-1, input.size(1), input.size(2), input.size(3)),
                           os.path.join(save_path_pictures,
                                        'sample_epoch_' + str(epoch) + '_ite_' + str(i + 1) + '.png'))

            del input, target, z, y_real, y_fake, disc_real, gen_out, disc_fake

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print(' * Validate: Generator Loss {gen_losses.avg:.3f} Discriminator Loss {disc_losses.avg:.3f}'
          .format(gen_losses=gen_losses, disc_losses=disc_losses))
    print('-' * 80)

    return disc_losses.avg, gen_losses.avg
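# NOTE: every routine in this file only uses the AverageMeter interface `update(value, n)` plus the
# `val` and `avg` attributes. The class itself lives elsewhere in the repository; a minimal sketch in
# the spirit of the standard PyTorch ImageNet example is shown here purely for reference (the
# repository's own implementation may differ in detail).
class _AverageMeterSketch:
    """Tracks the most recent value and the running (weighted) average of a metric."""

    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0.0

    def update(self, val, n=1):
        self.val = val                      # most recent value
        self.sum += val * n                 # weighted sum of all observed values
        self.count += n                     # total number of observations
        self.avg = self.sum / self.count    # running average reported in the logs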
def validate(dataset, model, criterion, epoch, device, args):
    """
    Evaluates/validates the model

    Parameters:
        dataset (torch.utils.data.TensorDataset): The dataset
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int).

    Returns:
        float: top1 accuracy
    """

    cl_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (input, target) in enumerate(dataset.val_loader):
            input, target = input.to(device), target.to(device)

            # compute output
            output = model(input)

            # make targets one-hot for using BCE loss
            target_temp = target
            one_hot = torch.zeros(target.size(0), output.size(1)).to(device)
            one_hot.scatter_(1, target.long().view(target.size(0), -1), 1)
            target = one_hot

            # compute loss and accuracy
            loss = criterion(output, target)
            prec1, prec5 = accuracy(output, target_temp, (1, 5))

            # measure accuracy and record loss
            cl_losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

    print(' * Validation Task:\n'
          'Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(loss=cl_losses, top1=top1, top5=top5))
    print('=' * 80)

    return top1.avg
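# The conversion of integer class labels into one-hot targets for the BCE loss (used in the
# validate/train pair in this file) can be illustrated in isolation. This demo is not part of the
# repository; it only shows what the scatter_ call produces for a hypothetical 4-class problem.
def _demo_one_hot_targets():
    import torch

    target = torch.tensor([2, 0, 3])  # integer class labels for a batch of 3 samples
    num_classes = 4

    one_hot = torch.zeros(target.size(0), num_classes)
    # write a 1 into column target[b] of every row b
    one_hot.scatter_(1, target.long().view(target.size(0), -1), 1)
    print(one_hot)
    # tensor([[0., 0., 1., 0.],
    #         [1., 0., 0., 0.],
    #         [0., 0., 0., 1.]])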
def validate(Dataset, model, criterion, epoch, writer, device, save_path, args):
    """
    Evaluates/validates the model

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        save_path (str): path to save data to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int), epochs (int), incremental_data (bool), autoregression (bool),
            visualization_epoch (int), cross_dataset (bool), num_base_tasks (int), num_increment_tasks (int)
            and patch_size (int).

    Returns:
        float: top1 precision/accuracy
        float: average loss
    """

    # initialize average meters to accumulate values
    losses = AverageMeter()
    class_losses = AverageMeter()
    inos_losses = AverageMeter()
    batch_time = AverageMeter()
    top1 = AverageMeter()

    # confusion matrix
    confusion = ConfusionMeter(model.module.num_classes, normalized=True)

    # switch to evaluate mode
    model.eval()

    end = time.time()

    # evaluate the entire validation dataset
    with torch.no_grad():
        for i, (inp, target) in enumerate(Dataset.val_loader):
            inp = inp.to(device)
            target = target.to(device)

            recon_target = inp
            class_target = target[0]

            # compute output
            output, score = model(inp)

            # compute loss
            cl, rl = criterion(output, target, score, device, args)
            loss = cl + rl

            # measure accuracy, record loss, fill confusion matrix
            prec1 = accuracy(output, class_target)[0]
            top1.update(prec1.item(), inp.size(0))
            class_losses.update(cl.item(), inp.size(0))
            inos_losses.update(rl.item(), inp.size(0))
            confusion.add(output.data, target)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            losses.update(loss.item(), inp.size(0))

            # print progress
            if i % args.print_freq == 0:
                print('Validate: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                       epoch + 1, i, len(Dataset.val_loader), batch_time=batch_time, loss=losses, top1=top1))

    # TensorBoard summary logging
    writer.add_scalar('validation/precision@1', top1.avg, epoch)
    writer.add_scalar('validation/average_loss', losses.avg, epoch)
    writer.add_scalar('validation/class_loss', class_losses.avg, epoch)
    writer.add_scalar('validation/inos_loss', inos_losses.avg, epoch)

    print(' * Validation: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))

    # At the end of isolated training, or at the end of every task, visualize the confusion matrix
    if (epoch + 1) % args.epochs == 0 and epoch > 0:
        # visualize the confusion matrix
        visualize_confusion(writer, epoch + 1, confusion.value(), Dataset.class_to_idx, save_path)

    return top1.avg, losses.avg
def train(train_loader, model, criterion, epoch, optimizer, lr_scheduler, device, args, split_batch_size):
    """
    Trains the model for one epoch on the train set.

    Parameters:
        train_loader (torch.utils.data.DataLoader): data loader for the train set
        model (lib.Models.network.Net): model of the net to be trained
        criterion (torch.nn.BCELoss): loss criterion to be optimized
        epoch (int): continuous epoch counter
        optimizer (torch.optim.SGD): optimizer instance like SGD or Adam
        lr_scheduler (lib.Training.learning_rate_scheduling.LearningRateScheduler): class implementing
            learning rate schedules
        device (torch.device): computational device (cpu or gpu)
        args (argparse.ArgumentParser): parsed command line arguments
        split_batch_size (int): smaller batch size after splitting the original batch size for
            fitting the device memory
    """

    # performance and computational overhead metrics
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    hard_prec = AverageMeter()
    soft_prec = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    factor = args.batch_size // split_batch_size
    last_batch = int(math.ceil(len(train_loader.dataset) / float(split_batch_size)))

    optimizer.zero_grad()
    print('training')

    for i, (input_, target) in enumerate(train_loader):
        # hacky way to deal with terminal batch-size of 1
        if input_.size(0) == 1:
            print('skip last training batch of size 1')
            continue

        input_, target = input_.to(device), target.to(device)

        data_time.update(time.time() - end)

        # adjust learning rate after every 'factor' times 'batch count'
        # (i.e. after every batch had the batch size not been split)
        if i % factor == 0:
            lr_scheduler.adjust_learning_rate(optimizer, i // factor + 1)

        output = model(input_)

        # scale the loss by the ratio of the split batch size and the original batch size
        loss = criterion(output, target) * input_.size(0) / float(args.batch_size)

        # update the 'losses' meter with the actual measure of the loss
        losses.update(loss.item() * args.batch_size / float(input_.size(0)), input_.size(0))

        # compute performance measures
        output = output >= 0.5  # binarize sigmoid output by thresholding at 0.5
        equality_matrix = (output.float() == target).float()
        hard = torch.mean(torch.prod(equality_matrix, dim=1)) * 100.
        soft = torch.mean(equality_matrix) * 100.

        # update performance meters
        hard_prec.update(hard.item(), input_.size(0))
        soft_prec.update(soft.item(), input_.size(0))

        loss.backward()

        # update the weights after every 'factor' times 'batch count'
        # (i.e. after every batch had the batch size not been split)
        if (i + 1) % factor == 0 or i == (last_batch - 1):
            optimizer.step()
            optimizer.zero_grad()

        del output, input_, target

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print performance and computational overhead measures after every 'factor' times 'print_freq' batches
        if i % (args.print_freq * factor) == 0:
            print('epoch: [{0}][{1}/{2}]\t'
                  'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {losses.val:.3f} ({losses.avg:.3f})\t'
                  'hard prec {hard_prec.val:.3f} ({hard_prec.avg:.3f})\t'
                  'soft prec {soft_prec.val:.3f} ({soft_prec.avg:.3f})\t'.format(
                   epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time,
                   losses=losses, hard_prec=hard_prec, soft_prec=soft_prec))

    lr_scheduler.scheduler_epoch += 1

    print(' * train: loss {losses.avg:.3f} hard prec {hard_prec.avg:.3f} soft prec {soft_prec.avg:.3f}'
          .format(losses=losses, hard_prec=hard_prec, soft_prec=soft_prec))
    print('*' * 80)
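# The split_batch_size/factor bookkeeping above implements gradient accumulation: the gradients of
# `factor` small batches are summed before a single optimizer step, emulating the original batch size
# that would not fit into device memory. A stripped-down sketch of the same pattern with hypothetical
# names, shown only to make the scaling of the loss explicit:
def _accumulate_gradients(model, loader, criterion, optimizer, batch_size, split_batch_size):
    factor = batch_size // split_batch_size
    optimizer.zero_grad()
    for i, (inp, target) in enumerate(loader):
        # scale each partial loss so the accumulated gradient matches the full-batch gradient
        loss = criterion(model(inp), target) * inp.size(0) / float(batch_size)
        loss.backward()
        if (i + 1) % factor == 0:
            optimizer.step()
            optimizer.zero_grad()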
def validate(Dataset, model, criterion, epoch, writer, device, save_path, args):
    """
    Evaluates/validates the model

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        save_path (str): path to save data to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int), epochs (int), incremental_data (bool), autoregression (bool),
            visualization_epoch (int), num_base_tasks (int), num_increment_tasks (int) and patch_size (int).

    Returns:
        float: top1 precision/accuracy
        float: average loss
    """

    # initialize average meters to accumulate values
    class_losses = AverageMeter()
    recon_losses_nat = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()

    # for autoregressive models add an additional instance for reconstruction loss in bits per dimension
    if args.autoregression:
        recon_losses_bits_per_dim = AverageMeter()

    # for continual learning settings also add instances for base and new reconstruction metrics
    # corresponding accuracy values are calculated directly from the confusion matrix below
    if args.incremental_data and ((epoch + 1) % args.epochs == 0 and epoch > 0):
        recon_losses_new_nat = AverageMeter()
        recon_losses_base_nat = AverageMeter()
        if args.autoregression:
            recon_losses_new_bits_per_dim = AverageMeter()
            recon_losses_base_bits_per_dim = AverageMeter()

    batch_time = AverageMeter()
    top1 = AverageMeter()

    # confusion matrix
    confusion = ConfusionMeter(model.module.num_classes, normalized=True)

    # switch to evaluate mode
    model.eval()

    end = time.time()

    # evaluate the entire validation dataset
    with torch.no_grad():
        for i, (inp, target) in enumerate(Dataset.val_loader):
            inp = inp.to(device)
            target = target.to(device)

            recon_target = inp
            class_target = target

            # compute output
            class_samples, recon_samples, mu, std = model(inp)

            # for autoregressive models convert the target to 0-255 integers and compute the
            # autoregressive decoder for each sample
            if args.autoregression:
                recon_target = (recon_target * 255).long()
                recon_samples_autoregression = torch.zeros(recon_samples.size(0), inp.size(0), 256, inp.size(1),
                                                           inp.size(2), inp.size(3)).to(device)
                for j in range(model.module.num_samples):
                    recon_samples_autoregression[j] = model.module.pixelcnn(
                        inp, torch.sigmoid(recon_samples[j])).contiguous()
                recon_samples = recon_samples_autoregression

            # compute loss
            class_loss, recon_loss, kld_loss = criterion(class_samples, class_target, recon_samples, recon_target,
                                                         mu, std, device, args)

            # for autoregressive models also update the bits per dimension value, converted from the obtained nats
            if args.autoregression:
                recon_losses_bits_per_dim.update(recon_loss.item() * math.log2(math.e), inp.size(0))

            # take mean to compute accuracy
            # (does nothing if there isn't more than 1 sample per input other than removing the dummy dimension)
            class_output = torch.mean(class_samples, dim=0)
            recon_output = torch.mean(recon_samples, dim=0)

            # measure accuracy, record loss, fill confusion matrix
            prec1 = accuracy(class_output, target)[0]
            top1.update(prec1.item(), inp.size(0))
            confusion.add(class_output.data, target)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # for autoregressive models generate reconstructions by sequentially sampling from the
            # multinomial distribution (reminder: the original output is a 256-way Softmax as PixelVAEs are
            # posed as a classification problem). This serves two purposes: visualization of reconstructions
            # and computation of a reconstruction loss in nats using a BCE loss, comparable to that of a
            # regular VAE.
            recon_target = inp
            if args.autoregression:
                recon = torch.zeros((inp.size(0), inp.size(1), inp.size(2), inp.size(3))).to(device)
                for h in range(inp.size(2)):
                    for w in range(inp.size(3)):
                        for c in range(inp.size(1)):
                            probs = torch.softmax(recon_output[:, :, c, h, w], dim=1).data
                            pixel_sample = torch.multinomial(probs, 1).float() / 255.
                            recon[:, c, h, w] = pixel_sample.squeeze()

                if (epoch % args.visualization_epoch == 0) and (i == (len(Dataset.val_loader) - 1)) and (epoch > 0):
                    visualize_image_grid(recon, writer, epoch + 1, 'reconstruction_snapshot', save_path)

                recon_loss = F.binary_cross_entropy(recon, recon_target)
            else:
                # if not autoregressive simply apply the Sigmoid and visualize
                recon = torch.sigmoid(recon_output)
                if (i == (len(Dataset.val_loader) - 1)) and (epoch % args.visualization_epoch == 0) and (epoch > 0):
                    visualize_image_grid(recon, writer, epoch + 1, 'reconstruction_snapshot', save_path)

            # update the respective loss values. To be consistent with values reported in the literature we scale
            # our normalized losses back to un-normalized values.
            # For the KLD this also means the reported loss is not scaled by beta, to allow for a fair comparison
            # across potential weighting terms.
            class_losses.update(class_loss.item() * model.module.num_classes, inp.size(0))
            kld_losses.update(kld_loss.item() * model.module.latent_dim, inp.size(0))
            recon_losses_nat.update(recon_loss.item() * inp.size()[1:].numel(), inp.size(0))
            losses.update((class_loss + recon_loss + kld_loss).item(), inp.size(0))

            # if we are learning continually, we need to calculate the base and new reconstruction losses
            # at the end of each task increment.
            if args.incremental_data and ((epoch + 1) % args.epochs == 0 and epoch > 0):
                for j in range(inp.size(0)):
                    # get the number of classes for class incremental scenarios.
                    base_classes = model.module.seen_tasks[:args.num_base_tasks + 1]
                    new_classes = model.module.seen_tasks[-args.num_increment_tasks:]

                    if args.autoregression:
                        rec = recon_output[j].view(1, recon_output.size(1), recon_output.size(2),
                                                   recon_output.size(3), recon_output.size(4))
                        rec_tar = recon_target[j].view(1, recon_target.size(1), recon_target.size(2),
                                                       recon_target.size(3))

                    # if the input belongs to one of the base classes also update base metrics
                    if class_target[j].item() in base_classes:
                        if args.autoregression:
                            recon_losses_base_bits_per_dim.update(
                                F.cross_entropy(rec, (rec_tar * 255).long()) * math.log2(math.e), 1)
                        recon_losses_base_nat.update(F.binary_cross_entropy(recon[j], recon_target[j]), 1)
                    # if the input belongs to one of the new classes also update new metrics
                    elif class_target[j].item() in new_classes:
                        if args.autoregression:
                            recon_losses_new_bits_per_dim.update(
                                F.cross_entropy(rec, (rec_tar * 255).long()) * math.log2(math.e), 1)
                        recon_losses_new_nat.update(F.binary_cross_entropy(recon[j], recon_target[j]), 1)

            # If we are at the end of validation, create one mini-batch of example generations. Only do this
            # every epoch specified by visualization_epoch to avoid generating lots of images and running the
            # computationally expensive generation of the autoregressive model too often.
            if i == (len(Dataset.val_loader) - 1) and epoch % args.visualization_epoch == 0 and (epoch > 0):
                # generation
                gen = model.module.generate()

                if args.autoregression:
                    gen = model.module.pixelcnn.generate(gen)
                visualize_image_grid(gen, writer, epoch + 1, 'generation_snapshot', save_path)

            # print progress
            if i % args.print_freq == 0:
                print('Validate: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                      'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                       epoch + 1, i, len(Dataset.val_loader), batch_time=batch_time, loss=losses,
                       cl_loss=class_losses, top1=top1, recon_loss=recon_losses_nat, KLD_loss=kld_losses))

    # TensorBoard summary logging
    writer.add_scalar('validation/val_precision@1', top1.avg, epoch)
    writer.add_scalar('validation/val_average_loss', losses.avg, epoch)
    writer.add_scalar('validation/val_class_loss', class_losses.avg, epoch)
    writer.add_scalar('validation/val_recon_loss_nat', recon_losses_nat.avg, epoch)
    writer.add_scalar('validation/val_KLD', kld_losses.avg, epoch)
    if args.autoregression:
        writer.add_scalar('validation/val_recon_loss_bits_per_dim', recon_losses_bits_per_dim.avg, epoch)

    print(' * Validation: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))

    # At the end of isolated training, or at the end of every task, visualize the confusion matrix
    if (epoch + 1) % args.epochs == 0 and epoch > 0:
        # visualize the confusion matrix
        visualize_confusion(writer, epoch + 1, confusion.value(), Dataset.class_to_idx, save_path)

        # If we are in a continual learning scenario, also use the confusion matrix to extract base and new precision.
        if args.incremental_data:
            prec1_base = 0.0
            prec1_new = 0.0
            # this has to be + 1 because the number of initial tasks is always one less than the amount of classes,
            # i.e. 1 task is 2 classes etc.
            for c in range(args.num_base_tasks + 1):
                prec1_base += confusion.value()[c][c]
            prec1_base = prec1_base / (args.num_base_tasks + 1)

            # for the first task the "new" metrics are equivalent to the "base" metrics
            if (epoch + 1) / args.epochs == 1:
                prec1_new = prec1_base
                recon_losses_new_nat.avg = recon_losses_base_nat.avg
                if args.autoregression:
                    recon_losses_new_bits_per_dim.avg = recon_losses_base_bits_per_dim.avg
            else:
                for c in range(args.num_increment_tasks):
                    prec1_new += confusion.value()[-c - 1][-c - 1]
                prec1_new = prec1_new / args.num_increment_tasks

            # add the continual learning metrics to TensorBoard
            writer.add_scalar('validation/base_precision@1', prec1_base, len(model.module.seen_tasks) - 1)
            writer.add_scalar('validation/new_precision@1', prec1_new, len(model.module.seen_tasks) - 1)
            writer.add_scalar('validation/base_rec_loss_nats',
                              recon_losses_base_nat.avg * args.patch_size * args.patch_size *
                              model.module.num_colors, len(model.module.seen_tasks) - 1)
            writer.add_scalar('validation/new_rec_loss_nats',
                              recon_losses_new_nat.avg * args.patch_size * args.patch_size *
                              model.module.num_colors, len(model.module.seen_tasks) - 1)
            if args.autoregression:
                writer.add_scalar('validation/base_rec_loss_bits_per_dim', recon_losses_base_bits_per_dim.avg,
                                  len(model.module.seen_tasks) - 1)
                writer.add_scalar('validation/new_rec_loss_bits_per_dim', recon_losses_new_bits_per_dim.avg,
                                  len(model.module.seen_tasks) - 1)

            print(' * Incremental validation: Base Prec@1 {prec1_base:.3f} New Prec@1 {prec1_new:.3f}\t'
                  'Base Recon Loss {recon_losses_base_nat.avg:.3f} New Recon Loss {recon_losses_new_nat.avg:.3f}'
                  .format(prec1_base=100 * prec1_base, prec1_new=100 * prec1_new,
                          recon_losses_base_nat=recon_losses_base_nat, recon_losses_new_nat=recon_losses_new_nat))

    return top1.avg, losses.avg
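# The bits-per-dimension values logged above are obtained from the per-dimension reconstruction loss
# in nats through a change of logarithm base, bits = nats * log2(e). A tiny numeric illustration
# (not part of the repository):
def _demo_nats_to_bits_per_dim():
    import math

    recon_loss_nats_per_dim = math.log(2.0)  # ln(2) nats per dimension
    bits_per_dim = recon_loss_nats_per_dim * math.log2(math.e)
    print(bits_per_dim)  # 1.0 bit per dimension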
def train_var(Dataset, model, criterion, epoch, optimizer, writer, device, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int), denoising_noise_value (float) and var_beta (float).
    """

    # Create instances to accumulate losses etc.
    cl_losses = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    # train
    for i, (inp, target) in enumerate(Dataset.train_loader):
        inp = inp.to(device)
        target = target.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute model forward
        output_samples, mu, std = model(inp)

        # calculate loss
        cl_loss, kld_loss = criterion(output_samples, target, mu, std, device)

        # add the individual loss components together and weight the KL term.
        loss = cl_loss + args.var_beta * kld_loss

        # take mean to compute accuracy. Note: if variational samples are 1 this only gets rid of a dummy dimension.
        output = torch.mean(output_samples, dim=0)

        # record precision/accuracy and losses
        prec1 = accuracy(output, target)[0]
        top1.update(prec1.item(), inp.size(0))
        losses.update((cl_loss + kld_loss).item(), inp.size(0))
        cl_losses.update(cl_loss.item(), inp.size(0))
        kld_losses.update(kld_loss.item(), inp.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print progress
        if i % args.print_freq == 0:
            print('Training: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                   epoch + 1, i, len(Dataset.train_loader), batch_time=batch_time, data_time=data_time,
                   loss=losses, cl_loss=cl_losses, top1=top1, KLD_loss=kld_losses))

    # TensorBoard summary logging
    writer.add_scalar('training/train_precision@1', top1.avg, epoch)
    writer.add_scalar('training/train_class_loss', cl_losses.avg, epoch)
    writer.add_scalar('training/train_average_loss', losses.avg, epoch)
    writer.add_scalar('training/train_KLD', kld_losses.avg, epoch)

    print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))
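# model(inp) above returns output_samples together with mu and std of the approximate posterior.
# In this kind of variational model the samples are typically drawn with the reparameterization
# trick; a sketch of that step (an assumption about the model internals, shown for clarity only):
def _reparameterize(mu, std, num_samples=1):
    import torch

    # z = mu + std * eps with eps ~ N(0, I), so gradients can flow through mu and std
    eps = torch.randn((num_samples,) + std.size(), device=std.device)
    return mu.unsqueeze(0) + eps * std.unsqueeze(0)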
def train(Dataset, model, criterion, epoch, optimizer, writer, device, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and log_weights (bool).
    """

    # Create instances to accumulate losses etc.
    losses = AverageMeter()
    class_losses = AverageMeter()
    inos_losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    # train
    for i, (inp, target) in enumerate(Dataset.train_loader):
        inp = inp.to(device)
        target = target.to(device)
        class_target = target[0]

        # measure data loading time
        data_time.update(time.time() - end)

        # compute model forward
        output, score = model(inp)

        # calculate loss
        cl, rl = criterion(output, target, score, device, args)
        loss = cl + rl

        # record precision/accuracy and losses
        prec1 = accuracy(output, class_target)[0]
        top1.update(prec1.item(), inp.size(0))
        class_losses.update(cl.item(), inp.size(0))
        inos_losses.update(rl.item(), inp.size(0))
        losses.update(loss.item(), inp.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print progress
        if i % args.print_freq == 0:
            print('Training: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                   epoch + 1, i, len(Dataset.train_loader), batch_time=batch_time, data_time=data_time,
                   loss=losses, top1=top1))

    # TensorBoard summary logging
    writer.add_scalar('train/precision@1', top1.avg, epoch)
    writer.add_scalar('train/average_loss', losses.avg, epoch)
    writer.add_scalar('train/class_loss', class_losses.avg, epoch)
    writer.add_scalar('train/inos_loss', inos_losses.avg, epoch)

    # If the log weights argument is specified also add parameter and gradient histograms to TensorBoard.
    if args.log_weights:
        # histograms and distributions of network parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram(tag, value.data.cpu().numpy(), epoch, bins="auto")

            # second check required for buffers that appear in the parameters dict but don't receive gradients
            if value.requires_grad and value.grad is not None:
                writer.add_histogram(tag + '/grad', value.grad.data.cpu().numpy(), epoch, bins="auto")

    print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))
def train(dataset, model, criterion, epoch, optimizer, lr_scheduler, device, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        dataset (torch.utils.data.Dataset): The dataset providing the train_loader
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        lr_scheduler (Training.LearningRateScheduler): class implementing learning rate schedules
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments. Needs to contain learning_rate (float),
            momentum (float), weight_decay (float), nesterov momentum (bool), lr_dropstep (int),
            lr_dropfactor (float), print_freq (int) and expand (bool).
    """

    batch_time = AverageMeter()
    data_time = AverageMeter()
    cl_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    for i, (input, target) in enumerate(dataset.train_loader):
        input, target = input.to(device), target.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # adjust the learning rate if applicable
        lr_scheduler.adjust_learning_rate(optimizer, i + 1)

        # compute output
        output = model(input)

        # make targets one-hot for using BCE loss
        target_temp = target
        one_hot = torch.zeros(target.size(0), output.size(1)).to(device)
        one_hot.scatter_(1, target.long().view(target.size(0), -1), 1)
        target = one_hot

        # compute loss and accuracy
        loss = criterion(output, target)
        prec1, prec5 = accuracy(output, target_temp, (1, 5))

        # measure accuracy and record loss
        cl_losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        del output, input, target

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                   epoch, i, len(dataset.train_loader), batch_time=batch_time, data_time=data_time,
                   loss=cl_losses, top1=top1, top5=top5))

    lr_scheduler.scheduler_epoch += 1

    print(' * Train: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))
    print('=' * 80)
def validate_var(Dataset, model, criterion, epoch, writer, device, args):
    """
    Evaluates/validates the model

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int), epochs (int) and patch_size (int).

    Returns:
        float: top1 precision/accuracy
        float: average loss
    """

    # initialize average meters to accumulate values
    cl_losses = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()
    batch_time = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    # evaluate the entire validation dataset
    with torch.no_grad():
        for i, (inp, target) in enumerate(Dataset.val_loader):
            inp = inp.to(device)
            target = target.to(device)

            # compute output
            output_samples, mu, std = model(inp)

            # compute loss
            cl_loss, kld_loss = criterion(output_samples, target, mu, std, device)

            # take mean to compute accuracy
            # (does nothing if there isn't more than 1 sample per input other than removing the dummy dimension)
            output = torch.mean(output_samples, dim=0)

            # measure and update accuracy
            prec1 = accuracy(output, target)[0]
            top1.update(prec1.item(), inp.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # update the respective loss values. To be consistent with values reported in the literature we scale
            # our normalized losses back to un-normalized values.
            # For the KLD this also means the reported loss is not scaled by beta, to allow for a fair comparison
            # across potential weighting terms.
            cl_losses.update(cl_loss.item() * model.module.num_classes, inp.size(0))
            kld_losses.update(kld_loss.item() * model.module.latent_dim, inp.size(0))
            losses.update((cl_loss + kld_loss).item(), inp.size(0))

            # print progress
            if i % args.print_freq == 0:
                print('Validate: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                       epoch + 1, i, len(Dataset.val_loader), batch_time=batch_time, loss=losses,
                       cl_loss=cl_losses, top1=top1, KLD_loss=kld_losses))

    # TensorBoard summary logging
    writer.add_scalar('validation/val_precision@1', top1.avg, epoch)
    writer.add_scalar('validation/val_class_loss', cl_losses.avg, epoch)
    writer.add_scalar('validation/val_average_loss', losses.avg, epoch)
    writer.add_scalar('validation/val_KLD', kld_losses.avg, epoch)

    print(' * Validation: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))

    return top1.avg, losses.avg
def train(dataset, gen_model, disc_model, criterion, epoch, gen_optimizer, disc_optimizer, lr_scheduler, device, args):
    """
    Trains/updates the generator and discriminator for one epoch on the training dataset.

    Parameters:
        dataset (torch.utils.data.Dataset): The dataset
        gen_model (torch.nn.module): Generator to be trained
        disc_model (torch.nn.module): Discriminator to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        gen_optimizer (torch.optim.optimizer): optimizer instance for the generator
        disc_optimizer (torch.optim.optimizer): optimizer instance for the discriminator
        lr_scheduler (Training.LearningRateScheduler): class implementing learning rate schedules
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and no_gpus (int).

    Returns:
        float: average discriminator loss
        float: average generator loss
    """

    gen_losses = AverageMeter()
    disc_losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # switch to train mode
    gen_model.train()
    disc_model.train()

    end = time.time()

    for i, (input, target) in enumerate(dataset.train_loader):
        input, target = input.to(device), target.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # adjust the learning rate if applicable
        lr_scheduler.adjust_learning_rate(gen_optimizer, i + 1)
        lr_scheduler.adjust_learning_rate(disc_optimizer, i + 1)

        if args.no_gpus > 1:
            input_size = gen_model.module.input_size
        else:
            input_size = gen_model.input_size

        # set inputs and targets
        z = torch.randn((input.size(0), input_size)).to(device)
        y_real = torch.ones(input.size(0), 1).to(device)
        y_fake = torch.zeros(input.size(0), 1).to(device)

        # discriminator update on real and generated images
        disc_real = disc_model(input)
        gen_out = gen_model(z)
        disc_fake = disc_model(gen_out)

        disc_real_loss = criterion(disc_real, y_real)
        disc_fake_loss = criterion(disc_fake, y_fake)
        disc_loss = disc_real_loss + disc_fake_loss
        disc_losses.update(disc_loss.item(), input.size(0))

        disc_optimizer.zero_grad()
        disc_loss.backward()
        disc_optimizer.step()

        # generator update: fool the discriminator into predicting "real"
        gen_out = gen_model(z)
        disc_fake = disc_model(gen_out)
        gen_loss = criterion(disc_fake, y_real)
        gen_losses.update(gen_loss.item(), input.size(0))

        gen_optimizer.zero_grad()
        gen_loss.backward()
        gen_optimizer.step()

        del input, target, z, y_real, y_fake, disc_real, gen_out, disc_fake

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Generator Loss {gen_losses.val:.4f} ({gen_losses.avg:.4f})\t'
                  'Discriminator Loss {disc_losses.val:.3f} ({disc_losses.avg:.3f})\t'.format(
                   epoch, i, len(dataset.train_loader), batch_time=batch_time, data_time=data_time,
                   gen_losses=gen_losses, disc_losses=disc_losses))

    print(' * Train: Generator Loss {gen_losses.avg:.3f} Discriminator Loss {disc_losses.avg:.3f}'
          .format(gen_losses=gen_losses, disc_losses=disc_losses))
    print('-' * 80)

    return disc_losses.avg, gen_losses.avg
def validate(val_loader, model, criterion, epoch, device, args):
    """
    Evaluates/validates the model

    Parameters:
        val_loader (torch.utils.data.DataLoader): The validation or testset dataloader
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Epoch counter
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int).

    Returns:
        float: top1 accuracy
    """

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    with torch.no_grad():
        for i, (inp, target) in enumerate(val_loader):
            inp, target = inp.to(device), target.to(device)

            # compute output
            output = model(inp)

            # compute loss
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), inp.size(0))
            top1.update(prec1.item(), inp.size(0))
            top5.update(prec5.item(), inp.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Validate: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch + 1, i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5))

    print(' * Validation: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg
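# The accuracy helper used throughout follows the usual top-k precision computation of the PyTorch
# ImageNet example (percentages for each requested k). A minimal sketch under that assumption; the
# repository's own implementation may differ in detail:
def _topk_accuracy(output, target, topk=(1,)):
    import torch

    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the k highest scoring classes per sample, arranged as (maxk, batch)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res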
def val(val_loader, model, criterion, device, is_val=True):
    """
    Validates the model for one epoch on the validation (or test) set.

    Parameters:
        val_loader (torch.utils.data.DataLoader): data loader for the validation set
        model (lib.Models.network.Net): model of a net that has to be validated
        criterion (torch.nn.BCELoss): loss criterion
        device (torch.device): device name where data is transferred to
        is_val (bool): validation or testing mode

    Returns:
        losses.avg (float): average of the validation losses over the batches
        hard_prec.avg (float): average of the validation hard precision over the batches for all the defects
        soft_prec.avg (float): average of the validation soft precision over the batches for all the defects
        hard_prec_background.avg (float): average of the validation precision over the batches for background
        hard_prec_crack.avg (float): average of the validation precision over the batches for crack
        hard_prec_spallation.avg (float): average of the validation precision over the batches for spallation
        hard_prec_exposed_bars.avg (float): average of the validation precision over the batches for exposed bars
        hard_prec_efflorescence.avg (float): average of the validation precision over the batches for efflorescence
        hard_prec_corrosion_stain.avg (float): average of the validation precision over the batches for corrosion stain
    """

    # performance metrics
    losses = AverageMeter()
    hard_prec = AverageMeter()
    soft_prec = AverageMeter()
    hard_prec_background = AverageMeter()
    hard_prec_crack = AverageMeter()
    hard_prec_spallation = AverageMeter()
    hard_prec_exposed_bars = AverageMeter()
    hard_prec_efflorescence = AverageMeter()
    hard_prec_corrosion_stain = AverageMeter()

    # for computing the correct prediction percentage for multi-label samples and single-label samples
    number_multi_label_samples = number_correct_multi_label_predictions = number_single_label_samples = \
        number_correct_single_label_predictions = 0

    # switch to evaluate mode
    model.eval()

    if is_val:
        print('validating')
    else:
        print('testing')

    # to ensure no buffering for gradient updates
    with torch.no_grad():
        for i, (input_, target) in enumerate(val_loader):
            # hacky way to deal with terminal batch-size of 1
            if input_.size(0) == 1:
                print('skip last val/test batch of size 1')
                continue

            input_, target = input_.to(device), target.to(device)

            output = model(input_)
            loss = criterion(output, target)

            # update the 'losses' meter
            losses.update(loss.item(), input_.size(0))

            # compute performance measures
            output = output >= 0.5  # binarize sigmoid output by thresholding at 0.5

            # offset all-zero prediction rows by 0.5 so that a sample with no predicted label
            # can never be counted as a correct prediction
            temp_output = output
            temp_output = temp_output.float() + \
                (((torch.sum(temp_output.float(), dim=1, keepdim=True) == 0).float()) * 0.5)

            # compute the correct prediction percentage for multi-label/single-label samples
            sum_target_along_label_dimension = torch.sum(target, dim=1, keepdim=True)

            multi_label_samples = (sum_target_along_label_dimension > 1).float()
            target_multi_label_samples = (target.float()) * multi_label_samples
            number_multi_label_samples += torch.sum(multi_label_samples).item()
            number_correct_multi_label_predictions += \
                torch.sum(torch.prod((temp_output.float() == target_multi_label_samples).float(), dim=1)).item()

            single_label_samples = (sum_target_along_label_dimension == 1).float()
            target_single_label_samples = (target.float()) * single_label_samples
            number_single_label_samples += torch.sum(single_label_samples).item()
            number_correct_single_label_predictions += \
                torch.sum(torch.prod((temp_output.float() == target_single_label_samples).float(), dim=1)).item()

            equality_matrix = (output.float() == target).float()
            hard = torch.mean(torch.prod(equality_matrix, dim=1)) * 100.
            soft = torch.mean(equality_matrix) * 100.
            hard_per_defect = torch.mean(equality_matrix, dim=0) * 100.

            # update performance meters
            hard_prec.update(hard.item(), input_.size(0))
            soft_prec.update(soft.item(), input_.size(0))
            hard_prec_background.update(hard_per_defect[0].item(), input_.size(0))
            hard_prec_crack.update(hard_per_defect[1].item(), input_.size(0))
            hard_prec_spallation.update(hard_per_defect[2].item(), input_.size(0))
            hard_prec_exposed_bars.update(hard_per_defect[3].item(), input_.size(0))
            hard_prec_efflorescence.update(hard_per_defect[4].item(), input_.size(0))
            hard_prec_corrosion_stain.update(hard_per_defect[5].item(), input_.size(0))

    percentage_single_labels = (100. * number_correct_single_label_predictions) / number_single_label_samples
    percentage_multi_labels = (100. * number_correct_multi_label_predictions) / number_multi_label_samples

    mode = 'val' if is_val else 'test'
    print(' * {mode}: loss {losses.avg:.3f}, hard prec {hard_prec.avg:.3f}, soft prec {soft_prec.avg:.3f},\t'
          '% correct single-label predictions {percentage_single_labels:.3f}, '
          '% correct multi-label predictions {percentage_multi_labels:.3f},\t'
          'hard prec background {hard_prec_background.avg:.3f}, hard prec crack {hard_prec_crack.avg:.3f},\t'
          'hard prec spallation {hard_prec_spallation.avg:.3f}, '
          'hard prec exposed bars {hard_prec_exposed_bars.avg:.3f},\t'
          'hard prec efflorescence {hard_prec_efflorescence.avg:.3f}, '
          'hard prec corrosion stain {hard_prec_corrosion_stain.avg:.3f}\t'.format(
           mode=mode, losses=losses, hard_prec=hard_prec, soft_prec=soft_prec,
           percentage_single_labels=percentage_single_labels, percentage_multi_labels=percentage_multi_labels,
           hard_prec_background=hard_prec_background, hard_prec_crack=hard_prec_crack,
           hard_prec_spallation=hard_prec_spallation, hard_prec_exposed_bars=hard_prec_exposed_bars,
           hard_prec_efflorescence=hard_prec_efflorescence, hard_prec_corrosion_stain=hard_prec_corrosion_stain))
    print('*' * 80)

    if is_val:
        return losses.avg, hard_prec.avg, soft_prec.avg, hard_prec_background.avg, hard_prec_crack.avg, \
            hard_prec_spallation.avg, hard_prec_exposed_bars.avg, hard_prec_efflorescence.avg, \
            hard_prec_corrosion_stain.avg
    else:
        return None
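# "Hard" precision above counts a sample as correct only if all of its labels match (product over
# the label dimension), whereas "soft" precision averages over the individual labels. A tiny numeric
# illustration with two samples and three labels (not part of the repository):
def _demo_hard_vs_soft_precision():
    import torch

    prediction = torch.tensor([[1., 0., 1.],
                               [1., 1., 0.]])
    target = torch.tensor([[1., 0., 1.],
                           [1., 0., 0.]])

    equality_matrix = (prediction == target).float()
    hard = torch.mean(torch.prod(equality_matrix, dim=1)) * 100.  # 50.0: only the first sample is fully correct
    soft = torch.mean(equality_matrix) * 100.                     # 83.3: five of the six labels are correct
    print(hard.item(), soft.item())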
def train(Dataset, model, criterion, epoch, optimizer, writer, device, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and log_weights (bool).
    """

    # Create instances to accumulate losses etc.
    class_losses = AverageMeter()
    recon_losses = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    # train
    for i, (inp, target) in enumerate(Dataset.train_loader):
        inp = inp.to(device)
        target = target.to(device)

        recon_target = inp
        class_target = target

        # this needs to be below the line where the reconstruction target is set
        # sample and add noise to the input (but not to the target!).
        if args.denoising_noise_value > 0.0:
            noise = torch.randn(inp.size()).to(device) * args.denoising_noise_value
            inp = inp + noise

        # measure data loading time
        data_time.update(time.time() - end)

        # compute model forward
        class_samples, recon_samples, mu, std = model(inp)

        # if we have an autoregressive model variant, further calculate the corresponding layers.
        if args.autoregression:
            recon_samples_autoregression = torch.zeros(recon_samples.size(0), inp.size(0), 256, inp.size(1),
                                                       inp.size(2), inp.size(3)).to(device)
            for j in range(model.module.num_samples):
                recon_samples_autoregression[j] = model.module.pixelcnn(
                    recon_target, torch.sigmoid(recon_samples[j])).contiguous()
            recon_samples = recon_samples_autoregression
            # set the target to work with the 256-way Softmax
            recon_target = (recon_target * 255).long()

        # calculate loss
        class_loss, recon_loss, kld_loss = criterion(class_samples, class_target, recon_samples, recon_target,
                                                     mu, std, device, args)

        # add the individual loss components together and weight the KL term.
        loss = class_loss + recon_loss + args.var_beta * kld_loss

        # take mean to compute accuracy. Note: if variational samples are 1 this only gets rid of a dummy dimension.
        output = torch.mean(class_samples, dim=0)

        # record precision/accuracy and losses
        prec1 = accuracy(output, target)[0]
        top1.update(prec1.item(), inp.size(0))
        losses.update((class_loss + recon_loss + kld_loss).item(), inp.size(0))
        class_losses.update(class_loss.item(), inp.size(0))
        recon_losses.update(recon_loss.item(), inp.size(0))
        kld_losses.update(kld_loss.item(), inp.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print progress
        if i % args.print_freq == 0:
            print('Training: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                  'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                   epoch + 1, i, len(Dataset.train_loader), batch_time=batch_time, data_time=data_time,
                   loss=losses, cl_loss=class_losses, top1=top1, recon_loss=recon_losses, KLD_loss=kld_losses))

    # TensorBoard summary logging
    writer.add_scalar('training/train_precision@1', top1.avg, epoch)
    writer.add_scalar('training/train_average_loss', losses.avg, epoch)
    writer.add_scalar('training/train_KLD', kld_losses.avg, epoch)
    writer.add_scalar('training/train_class_loss', class_losses.avg, epoch)
    writer.add_scalar('training/train_recon_loss', recon_losses.avg, epoch)

    # If the log weights argument is specified also add parameter and gradient histograms to TensorBoard.
    if args.log_weights:
        # histograms and distributions of network parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram(tag, value.data.cpu().numpy(), epoch, bins="auto")

            # second check required for buffers that appear in the parameters dict but don't receive gradients
            if value.requires_grad and value.grad is not None:
                writer.add_histogram(tag + '/grad', value.grad.data.cpu().numpy(), epoch, bins="auto")

    print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))
def train(dataset, model, criterion, epoch, optimizer, lr_scheduler, device, args):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        dataset (torch.utils.data.Dataset): The dataset providing the train_loader
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        lr_scheduler (Training.LearningRateScheduler): class implementing learning rate schedules
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments. Needs to contain learning_rate (float),
            momentum (float), weight_decay (float), nesterov momentum (bool), lr_dropstep (int),
            lr_dropfactor (float), print_freq (int) and expand (bool).
    """

    batch_time = AverageMeter()
    data_time = AverageMeter()
    elbo_losses = AverageMeter()
    rec_losses = AverageMeter()
    kl_losses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    for i, (input, target) in enumerate(dataset.train_loader):
        # for the autoencoder the reconstruction target is the input itself
        input, target = input.to(device), input.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # adjust the learning rate if applicable
        lr_scheduler.adjust_learning_rate(optimizer, i + 1)

        # compute output, mean and std
        output, mean, std = model(input)

        # compute loss: closed-form KL divergence to the standard normal prior plus reconstruction loss
        kl_loss = -0.5 * torch.mean(1 + torch.log(1e-8 + std ** 2) - (mean ** 2) - (std ** 2))
        rec_loss = criterion(output, target)
        elbo_loss = rec_loss + kl_loss

        # record loss
        rec_losses.update(rec_loss.item(), input.size(0))
        kl_losses.update(kl_loss.item(), input.size(0))
        elbo_losses.update(elbo_loss.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        elbo_loss.backward()
        optimizer.step()

        del output, input, target

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'ELBO Loss {elbo_losses.val:.4f} ({elbo_losses.avg:.4f})\t'
                  'Reconstruction Loss {rec_losses.val:.3f} ({rec_losses.avg:.3f})\t'
                  'KL Divergence {kl_losses.val:.3f} ({kl_losses.avg:.3f})\t'.format(
                   epoch, i, len(dataset.train_loader), batch_time=batch_time, data_time=data_time,
                   elbo_losses=elbo_losses, rec_losses=rec_losses, kl_losses=kl_losses))

    lr_scheduler.scheduler_epoch += 1

    print(' * Train: ELBO Loss {elbo_losses.avg:.3f} Reconstruction Loss {rec_losses.avg:.3f} '
          'KL Divergence {kl_losses.avg:.3f}'.format(
           elbo_losses=elbo_losses, rec_losses=rec_losses, kl_losses=kl_losses))
    print('-' * 80)
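# The KL term used in the VAE train/validate pair here is the closed-form KL divergence between the
# diagonal Gaussian posterior N(mu, std^2) and the standard normal prior, averaged over all elements
# (the 1e-8 only stabilizes the logarithm). A small illustrative check against torch.distributions,
# not part of the repository:
def _check_gaussian_kl():
    import torch
    from torch.distributions import Normal, kl_divergence

    mu = torch.randn(4, 8)
    std = torch.rand(4, 8) + 0.1

    kl_manual = -0.5 * torch.mean(1 + torch.log(1e-8 + std ** 2) - mu ** 2 - std ** 2)
    kl_reference = kl_divergence(Normal(mu, std), Normal(0.0, 1.0)).mean()
    print(torch.allclose(kl_manual, kl_reference, atol=1e-4))  # True up to the 1e-8 stabilizer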
def validate(dataset, model, criterion, epoch, device, args, save_path_pictures):
    """
    Evaluates/validates the model and saves input, reconstruction and sample images.

    Parameters:
        dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be evaluated/validated
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments.
            Needs to contain print_freq (int) and no_gpus (int).
        save_path_pictures (str): path to save the visualization images to

    Returns:
        float: reciprocal of the average ELBO loss
    """

    elbo_losses = AverageMeter()
    rec_losses = AverageMeter()
    kl_losses = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for i, (input, target) in enumerate(dataset.train_loader):
        with torch.no_grad():
            # the reconstruction target is the input itself
            input, target = input.to(device), input.to(device)

            # compute output, mean and std
            output, mean, std = model(input)

            # compute loss
            kl_loss = -0.5 * torch.mean(1 + torch.log(1e-8 + std ** 2) - (mean ** 2) - (std ** 2))
            rec_loss = criterion(output, target)
            elbo_loss = rec_loss + kl_loss

            # record loss
            rec_losses.update(rec_loss.item(), input.size(0))
            kl_losses.update(kl_loss.item(), input.size(0))
            elbo_losses.update(elbo_loss.item(), input.size(0))

            if i % args.print_freq == 0:
                save_image((input.data).view(-1, input.size(1), input.size(2), input.size(3)),
                           os.path.join(save_path_pictures,
                                        'input_epoch_' + str(epoch) + '_ite_' + str(i + 1) + '.png'))
                save_image((output.data).view(-1, output.size(1), output.size(2), output.size(3)),
                           os.path.join(save_path_pictures,
                                        'epoch_' + str(epoch) + '_ite_' + str(i + 1) + '.png'))

                if args.no_gpus > 1:
                    sample = torch.randn(input.size(0), model.module.latent_size).to(device)
                    sample = model.module.sample(sample).cpu()
                else:
                    sample = torch.randn(input.size(0), model.latent_size).to(device)
                    sample = model.sample(sample).cpu()

                save_image(sample.view(-1, input.size(1), input.size(2), input.size(3)),
                           os.path.join(save_path_pictures,
                                        'sample_epoch_' + str(epoch) + '_ite_' + str(i + 1) + '.png'))

            del output, input, target

    print(' * Validate: ELBO Loss {elbo_losses.avg:.3f} Reconstruction Loss {rec_losses.avg:.3f} '
          'KL Divergence {kl_losses.avg:.3f}'.format(
           elbo_losses=elbo_losses, rec_losses=rec_losses, kl_losses=kl_losses))
    print('-' * 80)

    return 1. / elbo_losses.avg
def train(Dataset, model, criterion, epoch, optimizer, writer, device, save_path, args): """ Trains/updates the model for one epoch on the training dataset. Parameters: Dataset (torch.utils.data.Dataset): The dataset model (torch.nn.module): Model to be trained criterion (torch.nn.criterion): Loss function epoch (int): Continuous epoch counter optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam writer (tensorboard.SummaryWriter): TensorBoard writer instance device (str): device name where data is transferred to args (dict): Dictionary of (command line) arguments. Needs to contain print_freq (int) and log_weights (bool). """ # Create instances to accumulate losses etc. class_losses = AverageMeter() recon_losses = AverageMeter() kld_losses = AverageMeter() losses = AverageMeter() batch_time = AverageMeter() data_time = AverageMeter() G_losses = AverageMeter() gp_losses = AverageMeter() D_losses = AverageMeter() D_losses_real = AverageMeter() D_losses_fake = AverageMeter() G_losses_fake = AverageMeter() fake_class_losses = AverageMeter() top1 = AverageMeter() # switch to train mode model.train() GAN_criterion = torch.nn.BCEWithLogitsLoss() end = time.time() # train for i, (inp, target) in enumerate(Dataset.train_loader): inp = inp.to(device) target = target.to(device) recon_target = inp class_target = target # this needs to be below the line where the reconstruction target is set # sample and add noise to the input (but not to the target!). if args.denoising_noise_value > 0.0: noise = torch.randn( inp.size()).to(device) * args.denoising_noise_value inp = inp + noise if args.blur: inp = blur_data(inp, args.patch_size, device) # measure data loading time data_time.update(time.time() - end) # Model explanation: Conventionally GAN architecutre update D first and G #### D Update#### class_samples, recon_samples, mu, std = model(inp) mu_label = None if args.proj_gan: # pred_label = torch.argmax(class_samples, dim=-1).squeeze() # mu_label = pred_label.to(device) mu_label = target real_z = model.module.forward_D(recon_target, mu_label) D_loss_real = torch.mean(torch.nn.functional.relu(1. - real_z)) #hinge # D_loss_real = - torch.mean(real_z) #WGAN-GP D_losses_real.update(D_loss_real.item(), 1) n, b, c, x, y = recon_samples.shape fake_z = model.module.forward_D( (recon_samples.view(n * b, c, x, y)).detach(), mu_label) D_loss_fake = torch.mean(torch.nn.functional.relu(1. 
+ fake_z)) #hinge # D_loss_fake = torch.mean(fake_z) #WGAN-GP D_losses_fake.update(D_loss_fake.item(), 1) GAN_D_loss = (D_loss_real + D_loss_fake) D_losses.update(GAN_D_loss.item(), inp.size(0)) # Compute loss for gradient penalty alpha = torch.rand(recon_target.size(0), 1, 1, 1).to(device) x_hat = (alpha * recon_target + (1 - alpha) * recon_samples.view(n * b, c, x, y)).requires_grad_(True) out_x_hat = model.module.forward_D(x_hat, mu_label) D_loss_gp = model.module.discriminator.gradient_penalty( out_x_hat, x_hat) gp_losses.update(D_loss_gp, inp.size(0)) GAN_D_loss += args.lambda_gp * D_loss_gp # compute gradient and do SGD step optimizer['enc'].zero_grad() optimizer['dec'].zero_grad() optimizer['disc'].zero_grad() GAN_D_loss.backward() optimizer['disc'].step() #### G Update#### if i % 1 == 0: class_samples, recon_samples, mu, std = model(inp) mu_label = None if args.proj_gan: # pred_label = torch.argmax(class_samples, dim=-1).squeeze() # mu_label = pred_label.to(device) mu_label = target # OCDVAE calculate loss class_loss, recon_loss, kld_loss = criterion( class_samples, class_target, recon_samples, recon_target, mu, std, device, args) # add the individual loss components together and weight the KL term. loss = class_loss + args.l1_weight * recon_loss + args.var_beta * kld_loss output = torch.mean(class_samples, dim=0) # record precision/accuracy and losses prec1 = accuracy(output, target)[0] top1.update(prec1.item(), inp.size(0)) losses.update(loss.item(), inp.size(0)) class_losses.update(class_loss.item(), inp.size(0)) recon_losses.update(recon_loss.item(), inp.size(0)) kld_losses.update(kld_loss.item(), inp.size(0)) # Needed to add GAN_criterion on KL n, b, c, x, y = recon_samples.shape fake_z = model.module.forward_D( (recon_samples.view(n * b, c, x, y)), mu_label) GAN_G_loss = -torch.mean(fake_z) G_losses_fake.update(GAN_G_loss.item(), inp.size(0)) GAN_G_loss = -args.var_gan_weight * torch.mean(fake_z) G_losses.update(GAN_G_loss.item(), inp.size(0)) GAN_G_loss += loss optimizer['enc'].zero_grad() optimizer['dec'].zero_grad() optimizer['disc'].zero_grad() GAN_G_loss.backward() optimizer['enc'].step() optimizer['dec'].step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() # print progress if i % args.print_freq == 0: print("OCD_VAELoss: ") print('Training: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t' 'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format( epoch + 1, i, len(Dataset.train_loader), batch_time=batch_time, data_time=data_time, loss=losses, cl_loss=class_losses, top1=top1, recon_loss=recon_losses, KLD_loss=kld_losses)) print("GANLoss: ") print('G Loss {G_loss.val:.4f} ({G_loss.avg:.4f})\t' 'D Loss {D_loss.val:.4f} ({D_loss.avg:.4f})'.format( G_loss=G_losses, D_loss=D_losses)) if (i == (len(Dataset.train_loader) - 2)) and (epoch % args.visualization_epoch == 0): visualize_image_grid(inp, writer, epoch + 1, 'train_input_snapshot', save_path) visualize_image_grid(recon_samples.view(n * b, c, x, y), writer, epoch + 1, 'train_reconstruction_snapshot', save_path) # TensorBoard summary logging writer.add_scalar('training/train_precision@1', top1.avg, epoch) writer.add_scalar('training/train_average_loss', losses.avg, epoch) writer.add_scalar('training/train_KLD', kld_losses.avg, epoch) 
writer.add_scalar('training/train_class_loss', class_losses.avg, epoch) writer.add_scalar('training/train_recon_loss', recon_losses.avg, epoch) writer.add_scalar('training/train_G_loss', G_losses.avg, epoch) writer.add_scalar('training/train_D_loss', D_losses.avg, epoch) writer.add_scalar('training/train_D_loss_real', D_losses_real.avg, epoch) writer.add_scalar('training/train_D_loss_fake', D_losses_fake.avg, epoch) writer.add_scalar('training/train_G_loss_fake', G_losses_fake.avg, epoch) # If the log weights argument is specified also add parameter and gradient histograms to TensorBoard. if args.log_weights: # Histograms and distributions of network parameters for tag, value in model.named_parameters(): tag = tag.replace('.', '/') writer.add_histogram(tag, value.data.cpu().numpy(), epoch, bins="auto") # second check required for buffers that appear in the parameters dict but don't receive gradients if value.requires_grad and value.grad is not None: writer.add_histogram(tag + '/grad', value.grad.data.cpu().numpy(), epoch, bins="auto") print( ' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}\t' ' GAN: Generator {G_losses.avg:.5f} Discriminator {D_losses.avg:.4f}'. format(loss=losses, top1=top1, G_losses=G_losses, D_losses=D_losses))
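# The discriminator update above adds a WGAN-GP style penalty through
# model.module.discriminator.gradient_penalty(out_x_hat, x_hat). That in-repo
# implementation is not shown in this file; the following is a minimal,
# self-contained sketch of the standard penalty (function name and signature
# are illustrative only): it drives the gradient norm of the critic output
# with respect to the interpolated samples x_hat towards 1.
import torch


def gradient_penalty_sketch(out_x_hat, x_hat):
    # gradients of the critic scores w.r.t. the interpolated inputs;
    # create_graph=True keeps the graph so the penalty itself is differentiable
    grads = torch.autograd.grad(outputs=out_x_hat, inputs=x_hat,
                                grad_outputs=torch.ones_like(out_x_hat),
                                create_graph=True, retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    # penalize deviation of the per-sample gradient norm from 1
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()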
def train(Dataset, model, criterion, epoch, optimizer, writer, device, args): """ Trains/updates the model for one epoch on the training dataset. Parameters: Dataset (torch.utils.data.Dataset): The dataset model (torch.nn.module): Model to be trained criterion (torch.nn.criterion): Loss function epoch (int): Continuous epoch counter optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam writer (tensorboard.SummaryWriter): TensorBoard writer instance device (str): device name where data is transferred to args (dict): Dictionary of (command line) arguments. Needs to contain print_freq (int) and log_weights (bool). """ # Create instances to accumulate losses etc. class_losses = AverageMeter() recon_losses = AverageMeter() if args.introspection: kld_real_losses = AverageMeter() kld_fake_losses = AverageMeter() kld_rec_losses = AverageMeter() criterion_enc = criterion[0] criterion_dec = criterion[1] optimizer_enc = optimizer[0] optimizer_dec = optimizer[1] else: kld_losses = AverageMeter() losses = AverageMeter() batch_time = AverageMeter() data_time = AverageMeter() top1 = AverageMeter() # switch to train mode model.train() end = time.time() # train for i, (inp, target) in enumerate(Dataset.train_loader): if args.data_augmentation: inp = augmentation.augment_data(inp, args) inp = inp.to(device) target = target.to(device) recon_target = inp class_target = target # this needs to be below the line where the reconstruction target is set # sample and add noise to the input (but not to the target!). if args.denoising_noise_value > 0.0: noise = torch.randn(inp.size()).to(device) * args.denoising_noise_value inp = inp + noise if args.blur: inp = blur_data(inp, args.patch_size, device) # measure data loading time data_time.update(time.time() - end) if args.introspection: # Update encoder real_mu, real_std = model.module.encode(inp) z = model.module.reparameterize(real_mu, real_std) class_output = model.module.classifier(z) recon = torch.sigmoid(model.module.decode(z)) model.eval() recon_mu, recon_std = model.module.encode(recon.detach()) model.train() z_p = torch.randn(inp.size(0), model.module.latent_dim).to(model.module.device) recon_p = torch.sigmoid(model.module.decode(z_p)) model.eval() mu_p, std_p = model.module.encode(recon_p.detach()) model.train() cl, rl, kld_real, kld_rec, kld_fake = criterion_enc(class_output, class_target, recon, recon_target, real_mu, real_std, recon_mu, recon_std, mu_p, std_p, device, args) kld_real_losses.update(kld_real.item(), inp.size(0)) alpha = 1 if not args.gray_scale: alpha = 3 loss_encoder = cl + alpha * rl + args.var_beta * (kld_real + 0.5 * (kld_rec + kld_fake) * args.gamma) optimizer_enc.zero_grad() loss_encoder.backward() optimizer_enc.step() # update decoder recon = torch.sigmoid(model.module.decode(z.detach())) model.eval() recon_mu, recon_std = model.module.encode(recon.detach()) model.train() recon_p = torch.sigmoid(model.module.decode(z_p)) model.eval() fake_mu, fake_std = model.module.encode(recon_p.detach()) model.train() rl, kld_rec, kld_fake = criterion_dec(recon, recon_target, recon_mu, recon_std, fake_mu, fake_std, args) loss_decoder = 0.5 * (kld_rec + kld_fake) * args.gamma + rl * alpha optimizer_dec.zero_grad() loss_decoder.backward() optimizer_dec.step() losses.update((loss_encoder + loss_decoder).item(), inp.size(0)) class_losses.update(cl.item(), inp.size(0)) recon_losses.update(rl.item(), inp.size(0)) kld_rec_losses.update(kld_rec.item(), inp.size(0)) kld_fake_losses.update(kld_fake.item(), inp.size(0)) # record 
precision/accuracy and losses prec1 = accuracy(class_output, target)[0] top1.update(prec1.item(), inp.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() # print progress if i % args.print_freq == 0: print('Training: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t' 'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format( epoch + 1, i, len(Dataset.train_loader), batch_time=batch_time, data_time=data_time, loss=losses, cl_loss=class_losses, top1=top1, recon_loss=recon_losses, KLD_loss=kld_real_losses)) else: # compute model forward class_samples, recon_samples, mu, std = model(inp) # if we have an autoregressive model variant, further calculate the corresponding layers. if args.autoregression: recon_samples_autoregression = torch.zeros(recon_samples.size(0), inp.size(0), 256, inp.size(1), inp.size(2), inp.size(3)).to(device) for j in range(model.module.num_samples): recon_samples_autoregression[j] = model.module.pixelcnn(recon_target, torch.sigmoid(recon_samples[j])).contiguous() recon_samples = recon_samples_autoregression # set the target to work with the 256-way Softmax recon_target = (recon_target * 255).long() # calculate loss class_loss, recon_loss, kld_loss = criterion(class_samples, class_target, recon_samples, recon_target, mu, std, device, args) # add the individual loss components together and weight the KL term. loss = class_loss + recon_loss + args.var_beta * kld_loss # take mean to compute accuracy. Note if variational samples are 1 this only gets rid of a dummy dimension. 
class_output = torch.mean(class_samples, dim=0) # record precision/accuracy and losses losses.update((class_loss + recon_loss + kld_loss).item(), inp.size(0)) class_losses.update(class_loss.item(), inp.size(0)) recon_losses.update(recon_loss.item(), inp.size(0)) kld_losses.update(kld_loss.item(), inp.size(0)) prec1 = accuracy(class_output, target)[0] top1.update(prec1.item(), inp.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() # print progress if i % args.print_freq == 0: print('Training: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t' 'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format( epoch+1, i, len(Dataset.train_loader), batch_time=batch_time, data_time=data_time, loss=losses, cl_loss=class_losses, top1=top1, recon_loss=recon_losses, KLD_loss=kld_losses)) # TensorBoard summary logging if args.introspection: writer.add_scalar('training/train_precision@1', top1.avg, epoch) writer.add_scalar('training/train_average_loss', losses.avg, epoch) writer.add_scalar('training/train_KLD_real', kld_real_losses.avg, epoch) writer.add_scalar('training/train_class_loss', class_losses.avg, epoch) writer.add_scalar('training/train_KLD_rec', kld_rec_losses.avg, epoch) writer.add_scalar('training/train_KLD_fake', kld_fake_losses.avg, epoch) else: writer.add_scalar('training/train_precision@1', top1.avg, epoch) writer.add_scalar('training/train_average_loss', losses.avg, epoch) writer.add_scalar('training/train_KLD', kld_losses.avg, epoch) writer.add_scalar('training/train_class_loss', class_losses.avg, epoch) writer.add_scalar('training/train_recon_loss', recon_losses.avg, epoch) print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format(loss=losses, top1=top1))
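# Both branches above rely on reparameterized latent samples and a KL term
# computed inside the criterion (or model.module.reparameterize). As a
# reference, here is a minimal sketch of the two standard building blocks,
# assuming the encoder returns a mean and a standard deviation (if it returns
# a log-variance, use std = torch.exp(0.5 * logvar) first); the function
# names are illustrative and not part of the repo.
import torch


def reparameterize_sketch(mu, std):
    # z = mu + std * eps with eps ~ N(0, I); keeps sampling differentiable w.r.t. mu and std
    eps = torch.randn_like(std)
    return mu + eps * std


def kld_to_unit_gaussian_sketch(mu, std):
    # closed-form KL(N(mu, std^2) || N(0, I)), summed over latent dimensions
    # and averaged over the batch
    return torch.mean(torch.sum(0.5 * (mu.pow(2) + std.pow(2) - 2.0 * torch.log(std) - 1.0), dim=1))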
def train(train_loader, model, criterion, epoch, optimizer, lr_scheduler, device, batch_split_size, args): """ Trains/updates the model for one epoch on the training dataset. Parameters: train_loader (torch.utils.data.DataLoader): The trainset dataloader model (torch.nn.module): Model to be trained criterion (torch.nn.criterion): Loss function epoch (int): Continuous epoch counter optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam lr_scheduler (Training.LearningRateScheduler): class implementing learning rate schedules device (str): device name where data is transferred to batch_split_size (int): size of smaller split of batch to calculate sequentially if too little memory is available args (dict): Dictionary of (command line) arguments. Needs to contain print_freq (int) and batch_size (int). """ batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to train mode model.train() end = time.time() factor = args.batch_size // batch_split_size last_batch = int( math.ceil(len(train_loader.dataset) / float(batch_split_size))) optimizer.zero_grad() for i, (inp, target) in enumerate(train_loader): inp, target = inp.to(device), target.to(device) # measure data loading time data_time.update(time.time() - end) # adjust the learning rate if i % factor == 0: lr_scheduler.adjust_learning_rate(optimizer, i // factor + 1) # compute output output = model(inp) loss = criterion(output, target) * inp.size(0) / float(args.batch_size) # measure accuracy and record loss prec1, prec5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item() * float(args.batch_size) / inp.size(0), inp.size(0)) top1.update(prec1.item(), inp.size(0)) top5.update(prec5.item(), inp.size(0)) # compute gradient and do SGD step loss.backward() if (i + 1) % factor == 0 or i == (last_batch - 1): optimizer.step() optimizer.zero_grad() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % (args.print_freq * factor) == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( epoch + 1, i, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)) print(' * Train: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format( top1=top1, top5=top5)) return top1.avg
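# The function above accumulates gradients over batch_size // batch_split_size
# sequential sub-batches before each optimizer step, so an effective batch of
# args.batch_size fits into limited memory. A stripped-down sketch of the same
# pattern (loader, model, criterion and optimizer names are placeholders; the
# loader is assumed to yield sub-batches of size batch_split_size):
def accumulate_and_step_sketch(loader, model, criterion, optimizer,
                               batch_size, batch_split_size, device):
    factor = batch_size // batch_split_size
    optimizer.zero_grad()
    for i, (inp, target) in enumerate(loader):
        inp, target = inp.to(device), target.to(device)
        output = model(inp)
        # scale so the accumulated gradient matches one pass over the full batch
        loss = criterion(output, target) * inp.size(0) / float(batch_size)
        loss.backward()
        if (i + 1) % factor == 0:
            optimizer.step()
            optimizer.zero_grad()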
def train(Dataset, model, criterion, epoch, iteration, optimizer, writer, device, args, save_path):
    """
    Trains/updates the model for one epoch on the training dataset.

    Parameters:
        Dataset (torch.utils.data.Dataset): The dataset
        model (torch.nn.module): Model to be trained
        criterion (torch.nn.criterion): Loss function
        epoch (int): Continuous epoch counter
        iteration (list): Single-element, mutable iteration counter that is incremented in place
        optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam
        writer (tensorboard.SummaryWriter): TensorBoard writer instance
        device (str): device name where data is transferred to
        args (dict): Dictionary of (command line) arguments. Needs to contain print_freq (int),
            epochs (int), batch_size (int), log_weights (bool) and the continual-learning flags
            used below (is_multiheaded, incremental_instance, train_incremental_upper_bound,
            use_lwf, use_si, no_vae, autoregression, is_segmentation).
        save_path (str): path to save data to
    """

    # Create instances to accumulate losses etc.
    class_losses = AverageMeter()
    recon_losses = AverageMeter()
    kld_losses = AverageMeter()
    losses = AverageMeter()
    lwf_losses = AverageMeter()
    si_losses = AverageMeter()

    batch_time = AverageMeter()
    data_time = AverageMeter()

    top1 = AverageMeter()

    # switch to train mode
    model.train()

    if args.use_si and not args.is_multiheaded:
        if model.module.temp_classifier_weights is not None:
            # load unconsolidated classifier weights
            un_consolidate_classifier(model.module)
            #print("SI: loaded unconsolidated classifier weights")
            #print("requires grad check: ", model.module.classifier[-1].weight)

    end = time.time()

    # train
    if args.is_multiheaded and args.train_incremental_upper_bound:
        for trainset_index in range(len(Dataset.mh_trainsets)):
            # get temporary trainset loader according to trainset_index
            train_loader = torch.utils.data.DataLoader(
                Dataset.mh_trainsets[trainset_index], batch_size=args.batch_size,
                shuffle=True, num_workers=args.workers,
                pin_memory=torch.cuda.is_available(), drop_last=True)

            for i, (inp, target) in enumerate(train_loader):
                if args.is_multiheaded and not args.incremental_instance:
                    # multiheaded incremental classes: move targets to head space
                    # (use a separate index so the batch counter i is not shadowed)
                    target = target.clone()
                    for k in range(target.size(0)):
                        target[k] = Dataset.maps_target_head[trainset_index][target.numpy()[k]]

                # move data to device
                inp = inp.to(device)
                target = target.to(device)

                if epoch % args.epochs == 0 and i == 0:
                    visualize_image_grid(inp, writer, epoch + 1, 'train_inp_snapshot', save_path)

                recon_target = inp
                class_target = target

                # this needs to be below the line where the reconstruction target is set
                # sample and add noise to the input (but not to the target!).
                if args.denoising_noise_value > 0.0 and not args.no_vae:
                    noise = torch.randn(inp.size()).to(device) * args.denoising_noise_value
                    inp = inp + noise

                # measure data loading time
                data_time.update(time.time() - end)

                # compute model forward
                class_samples, recon_samples, mu, std = model(inp)

                # if we have an autoregressive model variant, further calculate the corresponding layers.
if args.autoregression: recon_samples_autoregression = torch.zeros( recon_samples.size(0), inp.size(0), 256, inp.size(1), inp.size(2), inp.size(3)).to(device) for j in range(model.module.num_samples): recon_samples_autoregression[ j] = model.module.pixelcnn( recon_target, torch.sigmoid(recon_samples[j])).contiguous() recon_samples = recon_samples_autoregression # set the target to work with the 256-way Softmax recon_target = (recon_target * 255).long() # computer loss for respective head if not args.incremental_instance: head_start = (trainset_index) * args.num_increment_tasks head_end = (trainset_index + 1) * args.num_increment_tasks else: head_start = (trainset_index) * Dataset.num_classes head_end = (trainset_index + 1) * Dataset.num_classes if args.is_multiheaded: if not args.incremental_instance: class_loss, recon_loss, kld_loss = criterion( class_samples[:, :, head_start:head_end], class_target, recon_samples, recon_target, mu, std, device, args) else: class_loss, recon_loss, kld_loss = criterion( class_samples[:, :, head_start:head_end], class_target, recon_samples, recon_target, mu, std, device, args) else: if not args.is_segmentation: class_loss, recon_loss, kld_loss = criterion( class_samples, class_target, recon_samples, recon_target, mu, std, device, args) else: class_loss, recon_loss, kld_loss = criterion( class_samples, class_target, recon_samples, recon_target, mu, std, device, args, weight=Dataset.class_pixel_weight) # add the individual loss components together and weight the KL term. if args.no_vae: loss = class_loss else: loss = class_loss + recon_loss + args.var_beta * kld_loss # take mean to compute accuracy. Note if variational samples are 1 this only gets rid of a dummy dimension. if args.is_multiheaded: if not args.incremental_instance: output = torch.mean(class_samples[:, :, head_start:head_end], dim=0) else: output = torch.mean(class_samples[:, :, head_start:head_end], dim=0) else: output = torch.mean(class_samples, dim=0) # record precision/accuracy and losses prec1 = accuracy(output, target)[0] top1.update(prec1.item(), inp.size(0)) if args.no_vae: losses.update(class_loss.item(), inp.size(0)) class_losses.update(class_loss.item(), inp.size(0)) else: losses.update((class_loss + recon_loss + kld_loss).item(), inp.size(0)) class_losses.update(class_loss.item(), inp.size(0)) recon_losses.update(recon_loss.item(), inp.size(0)) kld_losses.update(kld_loss.item(), inp.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() # print progress if i % args.print_freq == 0: if args.use_lwf and model.module.prev_model: print( 'Training: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t' 'LwF Loss {lwf_loss:.4f}\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t' 'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'. 
format(epoch, i, len(train_loader), batch_time=batch_time,
                                   data_time=data_time, loss=losses, cl_loss=class_losses,
                                   top1=top1, recon_loss=recon_losses, KLD_loss=kld_losses,
                                   lwf_loss=cl_lwf.item()))
                    elif args.use_si and model.module.si_storage.is_initialized:
                        print('Training: [{0}][{1}/{2}]\t'
                              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                              'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                              'SI Loss {si_loss:.4f}\t'
                              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                              'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                              'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                                  epoch, i, len(train_loader), batch_time=batch_time,
                                  data_time=data_time, loss=losses, cl_loss=class_losses,
                                  top1=top1, recon_loss=recon_losses, KLD_loss=kld_losses,
                                  si_loss=loss_si.item()))
                    else:
                        print('Training: [{0}][{1}/{2}]\t'
                              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                              'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t'
                              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                              'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t'
                              'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format(
                                  epoch, i, len(train_loader), batch_time=batch_time,
                                  data_time=data_time, loss=losses, cl_loss=class_losses,
                                  top1=top1, recon_loss=recon_losses, KLD_loss=kld_losses))

                # increase iteration
                iteration[0] += 1
    else:
        for i, (inp, target) in enumerate(Dataset.train_loader):
            if args.is_multiheaded and not args.incremental_instance:
                # multiheaded incremental classes: move targets to head space
                # (use a separate index so the batch counter i is not shadowed)
                target = target.clone()
                for k in range(target.size(0)):
                    target[k] = Dataset.maps_target_head[-1][target.numpy()[k]]

            # move data to device
            inp = inp.to(device)
            target = target.to(device)
            #print("inp:", inp.shape)
            #print("target:", target.shape)

            if epoch % args.epochs == 0 and i == 0:
                visualize_image_grid(inp, writer, epoch + 1, 'train_inp_snapshot', save_path)

            if not args.is_segmentation:
                recon_target = inp
            else:
                # convert the target to a one-hot encoding
                target_one_hot = to_one_hot(target.clone(), model.module.num_classes)
                # concat input and one_hot_target
                recon_target = torch.cat([inp, target_one_hot], dim=1)
            class_target = target

            # this needs to be below the line where the reconstruction target is set
            # sample and add noise to the input (but not to the target!).
            if args.denoising_noise_value > 0.0 and not args.no_vae:
                noise = torch.randn(inp.size()).to(device) * args.denoising_noise_value
                inp = inp + noise

            # measure data loading time
            data_time.update(time.time() - end)

            # compute model forward
            class_samples, recon_samples, mu, std = model(inp)

            # if we have an autoregressive model variant, further calculate the corresponding layers.
if args.autoregression: recon_samples_autoregression = torch.zeros( recon_samples.size(0), inp.size(0), 256, inp.size(1), inp.size(2), inp.size(3)).to(device) for j in range(model.module.num_samples): recon_samples_autoregression[j] = model.module.pixelcnn( recon_target, torch.sigmoid(recon_samples[j])).contiguous() recon_samples = recon_samples_autoregression # set the target to work with the 256-way Softmax recon_target = (recon_target * 255).long() if args.is_multiheaded: if not args.incremental_instance: class_loss, recon_loss, kld_loss = criterion( class_samples[:, :, -args.num_increment_tasks:], class_target, recon_samples, recon_target, mu, std, device, args) else: class_loss, recon_loss, kld_loss = criterion( class_samples[:, :, -Dataset.num_classes:], class_target, recon_samples, recon_target, mu, std, device, args) else: if not args.is_segmentation: class_loss, recon_loss, kld_loss = criterion( class_samples, class_target, recon_samples, recon_target, mu, std, device, args) else: class_loss, recon_loss, kld_loss = criterion( class_samples, class_target, recon_samples, recon_target, mu, std, device, args, weight=Dataset.class_pixel_weight) # add the individual loss components together and weight the KL term. if args.no_vae: loss = class_loss else: loss = class_loss + recon_loss + args.var_beta * kld_loss # calculate lwf loss (if there is a previous model stored) if args.use_lwf and model.module.prev_model: # get prediction from previous model with torch.no_grad(): prev_pred_class_samples, _, _, _ = model.module.prev_model( inp) prev_cl_losses = torch.zeros( prev_pred_class_samples.size(0)).to(device) # loop through each sample for each input and calculate the correspond loss. Normalize the losses. for s in range(prev_pred_class_samples.size(0)): if args.is_multiheaded: if not args.incremental_instance: prev_cl_losses[s] = loss_fn_kd_multihead( class_samples[s] [:, :-args.num_increment_tasks], prev_pred_class_samples[s], task_sizes=args.num_increment_tasks) else: prev_cl_losses[s] = loss_fn_kd_multihead( class_samples[s][:, :-Dataset.num_classes], prev_pred_class_samples[s], task_sizes=Dataset.num_classes) else: if not args.is_segmentation: prev_cl_losses[s] = loss_fn_kd( class_samples[s], prev_pred_class_samples[s] ) #/ torch.numel(target) else: prev_cl_losses[s] = loss_fn_kd_2d( class_samples[s], prev_pred_class_samples[s] ) #/ torch.numel(target) # average the loss over all samples per input cl_lwf = torch.mean(prev_cl_losses, dim=0) # add lwf loss to overall loss loss += args.lmda * cl_lwf # record lwf losses lwf_losses.update(cl_lwf.item(), inp.size(0)) # calculate SI loss (if SI is initialized) if args.use_si and model.module.si_storage.is_initialized: if not args.is_segmentation: loss_si = args.lmda * ( SI.surrogate_loss(model.module.encoder, model.module.si_storage) + SI.surrogate_loss(model.module.latent_mu, model.module.si_storage_mu) + SI.surrogate_loss(model.module.latent_std, model.module.si_storage_std)) else: loss_si = args.lmda * ( SI.surrogate_loss(model.module.encoder, model.module.si_storage) + SI.surrogate_loss(model.module.bottleneck, model.module.si_storage_btn) + SI.surrogate_loss(model.module.decoder, model.module.si_storage_dec)) #print("SI_Loss:", loss_si) loss += loss_si si_losses.update(loss_si.item(), inp.size(0)) # take mean to compute accuracy. Note if variational samples are 1 this only gets rid of a dummy dimension. 
if args.is_multiheaded: if not args.incremental_instance: output = torch.mean( class_samples[:, :, -args.num_increment_tasks:], dim=0) else: output = torch.mean(class_samples[:, :, -Dataset.num_classes:], dim=0) else: output = torch.mean(class_samples, dim=0) # record precision/accuracy and losses if not args.is_segmentation: prec1 = accuracy(output, target)[0] top1.update(prec1.item(), inp.size(0)) else: ious_cc = iou_class_condtitional(pred=output.clone(), target=target.clone()) #print("iou", ious_cc) prec1 = iou_to_accuracy(ious_cc) #print("prec1", prec1) top1.update(prec1.item(), inp.size(0)) if args.no_vae: losses.update(class_loss.item(), inp.size(0)) class_losses.update(class_loss.item(), inp.size(0)) else: losses.update((class_loss + recon_loss + kld_loss).item(), inp.size(0)) class_losses.update(class_loss.item(), inp.size(0)) recon_losses.update(recon_loss.item(), inp.size(0)) kld_losses.update(kld_loss.item(), inp.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # SI: update running si paramters if args.use_si: if not args.is_segmentation: SI.update_si_parameters(model.module.encoder, model.module.si_storage) SI.update_si_parameters(model.module.latent_mu, model.module.si_storage_mu) SI.update_si_parameters(model.module.latent_std, model.module.si_storage_std) else: SI.update_si_parameters(model.module.encoder, model.module.si_storage) SI.update_si_parameters(model.module.bottleneck, model.module.si_storage_btn) SI.update_si_parameters(model.module.decoder, model.module.si_storage_dec) #print("SI: Updated running parameters") # measure elapsed time batch_time.update(time.time() - end) end = time.time() # print progress if i % args.print_freq == 0: if args.use_lwf and model.module.prev_model: print( 'Training: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t' 'LwF Loss {lwf_loss:.4f}\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t' 'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format( epoch, i, len(Dataset.train_loader), batch_time=batch_time, data_time=data_time, loss=losses, cl_loss=class_losses, top1=top1, recon_loss=recon_losses, KLD_loss=kld_losses, lwf_loss=cl_lwf.item())) elif args.use_si and model.module.si_storage.is_initialized: print( 'Training: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t' 'SI Loss {si_loss:.4f}\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t' 'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format( epoch, i, len(Dataset.train_loader), batch_time=batch_time, data_time=data_time, loss=losses, cl_loss=class_losses, top1=top1, recon_loss=recon_losses, KLD_loss=kld_losses, si_loss=loss_si.item())) else: print( 'Training: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Class Loss {cl_loss.val:.4f} ({cl_loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Recon Loss {recon_loss.val:.4f} ({recon_loss.avg:.4f})\t' 'KL {KLD_loss.val:.4f} ({KLD_loss.avg:.4f})'.format( epoch, i, len(Dataset.train_loader), batch_time=batch_time, data_time=data_time, 
loss=losses, cl_loss=class_losses, top1=top1, recon_loss=recon_losses, KLD_loss=kld_losses)) # increase iteration iteration[0] += 1 # TensorBoard summary logging writer.add_scalar('training/train_precision@1', top1.avg, epoch) writer.add_scalar('training/train_average_loss', losses.avg, epoch) writer.add_scalar('training/train_KLD', kld_losses.avg, epoch) writer.add_scalar('training/train_class_loss', class_losses.avg, epoch) writer.add_scalar('training/train_recon_loss', recon_losses.avg, epoch) writer.add_scalar('training/train_precision_itr@1', top1.avg, iteration[0]) writer.add_scalar('training/train_average_loss_itr', losses.avg, iteration[0]) writer.add_scalar('training/train_KLD_itr', kld_losses.avg, iteration[0]) writer.add_scalar('training/train_class_loss_itr', class_losses.avg, iteration[0]) writer.add_scalar('training/train_recon_loss_itr', recon_losses.avg, iteration[0]) if args.use_lwf: writer.add_scalar('training/train_lwf_loss', lwf_losses.avg, iteration[0]) if args.use_si: writer.add_scalar('training/train_si_loss', si_losses.avg, iteration[0]) # If the log weights argument is specified also add parameter and gradient histograms to TensorBoard. if args.log_weights: # Histograms and distributions of network parameters for tag, value in model.named_parameters(): tag = tag.replace('.', '/') writer.add_histogram(tag, value.data.cpu().numpy(), epoch, bins="auto") # second check required for buffers that appear in the parameters dict but don't receive gradients if value.requires_grad and value.grad is not None: writer.add_histogram(tag + '/grad', value.grad.data.cpu().numpy(), epoch, bins="auto") print(' * Train: Loss {loss.avg:.5f} Prec@1 {top1.avg:.3f}'.format( loss=losses, top1=top1))
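# The LwF term above is computed by loss_fn_kd / loss_fn_kd_2d /
# loss_fn_kd_multihead, whose implementations are not shown in this file.
# As a rough reference, a common knowledge-distillation formulation is
# sketched below (temperature value, reduction and any multihead slicing in
# the actual repo functions may differ): soften the logits of the previous
# and the current model with a temperature T and match them with a KL divergence.
import torch.nn.functional as F


def loss_fn_kd_sketch(new_logits, old_logits, temperature=2.0):
    log_p_new = F.log_softmax(new_logits / temperature, dim=1)
    p_old = F.softmax(old_logits / temperature, dim=1)
    # the T^2 factor keeps gradient magnitudes comparable across temperature choices
    return F.kl_div(log_p_new, p_old, reduction='batchmean') * temperature ** 2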