Example #1: reinit_weights
def reinit_weights(results, net, net_init, test_loader, criterion, epoch, device):
  """For each convolutional layer in net, reinitialize its weights
     with the values of the corresponding weights in net_init.
     
     Store the test accuracy of the resulting model in results.
     
     results = {
                 "epoch_id": [acc1, ..., a_l, ..., acc_L],
               } where a_l is the test accuracy of net with layer l replaced by the
                 corresponding layer in net_init.
  """
  logger = logging.getLogger('reinit_weights')
  if results is None:
    results = {"init_from_epoch": {}}
    
  test_accs = []
  net.eval()
  net_init.eval()
  
  block_id, stage_id = 0, 0
  conv_counter = 0
  init_type = "random initialization." if str(epoch) == "rand" else "weights of epoch {}.".format(epoch)
  for layer, layer_init in zip(net.features, net_init.features):
    if isinstance(layer, nn.Conv2d):
      # skip 1x1 convolutions; only spatial convolutions are reinitialized
      if layer.kernel_size == (1, 1):
        continue
      
      conv_counter += 1
      # swap in the initialization parameters, backing up the trained ones
      logger.info("Reinitializing parameters of layer {} from {}".format(conv_counter, init_type))
      weight_copy = layer.weight.detach().clone()
      if layer.bias is not None:
        bias_copy = layer.bias.detach().clone()
        layer.bias.data.copy_(layer_init.bias.data)
      layer.weight.data.copy_(layer_init.weight.data)
      
      _, top1_acc, _ = scores.evaluate(net, test_loader, criterion, device, topk=(1,5))
      test_accs.append(top1_acc)
      
      # restore the original trained parameters
      layer.weight.data.copy_(weight_copy)
      if layer.bias is not None:
        layer.bias.data.copy_(bias_copy)
            
      block_id += 1
    elif isinstance(layer, nn.MaxPool2d):
      stage_id += 1
      block_id = 0
  # make sure the results dict has the expected top-level key
  results.setdefault("init_from_epoch", {})
    
  results["init_from_epoch"][str(epoch)] = test_accs
  return results
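
# A minimal, self-contained sketch of the swap-evaluate-restore pattern that
# reinit_weights relies on. Plain PyTorch; the toy networks and shapes below
# are illustrative assumptions, not part of the project code:
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
net_init = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))

for layer, layer_init in zip(net, net_init):
    if not isinstance(layer, nn.Conv2d):
        continue
    backup = layer.weight.detach().clone()           # back up trained weights
    layer.weight.data.copy_(layer_init.weight.data)  # swap in the init weights
    # ... evaluate the hybrid model here ...
    layer.weight.data.copy_(backup)                  # restore trained weights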
Example #2: train
def train(model,
          end_epoch,
          train_loader,
          optimizer,
          criterion,
          scheduler,
          device,
          snapshot_dirname,
          start_epoch=0,
          snapshot_every=0,
          val_loader=None,
          kill_plateaus=False,
          best_acc1=0,
          writer=None,
          snapshot_all_until=0,
          filename='net',
          train_acc=False):
    """Train the specified model according to user options.
  
    Args:
    
    model (nn.Module) -- the model to be trained
    end_epoch (int) -- maximum number of epochs
    train_loader (nn.DataLoader) -- train set loader
    optimizer (torch.optim optimizer) -- the optimizer to use
    criterion -- loss function to use
    scheduler -- learning rate scheduler
    device (torch.device) -- device to use
    start_epoch (int) -- starting epoch (useful for resuming training)
    snapshot_every (int) -- frequency of snapshots (in epochs)
    test_loader (optional, nn.DataLoader) -- test set loader
    train_acc (bool) -- whether to report accuracy on the train set
    
  """
    logger = logging.getLogger('train')  # used below when kill_plateaus fires
    converged = True  # used to kill models that plateau
    top1_prec = 0.  # top-1 accuracy at the previous validation check
    if snapshot_every < 1:
        snapshot_every = end_epoch

    for epoch in range(start_epoch, end_epoch):
        # training loss
        avg_loss = 0.
        epoch_loss = 0.
        for batch_idx, (x, target) in enumerate(train_loader):
            optimizer.zero_grad()
            x, target = x.to(device,
                             non_blocking=True), target.to(device,
                                                           non_blocking=True)
            out = model(x)
            loss = criterion(out, target)
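            # exponential moving average of the batch loss (decay 0.99) for smoother logging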
            avg_loss = avg_loss * 0.99 + loss.item() * 0.01
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()

            # report training loss
            if ((batch_idx + 1) % 100 == 0) or ((batch_idx + 1)
                                                == len(train_loader)):
                utils.print_train_loss(epoch, avg_loss, batch_idx,
                                       len(train_loader))
        # report training loss over epoch
        epoch_loss /= len(train_loader)
        utils.print_train_loss_epoch(epoch, epoch_loss)
        if writer is not None:
            writer.add_scalar('Loss/train', avg_loss, epoch)

        if scheduler is not None:
            scheduler.step()
            if writer is not None:
                # note: newer PyTorch versions prefer scheduler.get_last_lr()
                writer.add_scalar('Lr', scheduler.get_lr()[0], epoch)

        if (epoch < snapshot_all_until) or (
            (epoch + 1) % snapshot_every == 0) or ((epoch + 1) == end_epoch):
            top1_acc, top5_acc = 0, 0
            if val_loader is not None:
                val_loss, top1_acc, top5_acc = scores.evaluate(model,
                                                               val_loader,
                                                               criterion,
                                                               device,
                                                               topk=(1, 5))
                utils.print_val_loss(epoch, val_loss, top1_acc, top5_acc)
                model = model.train()
                if writer is not None:
                    writer.add_scalar('Loss/val', val_loss, epoch)
                    writer.add_scalar('Accuracy/val/top1', top1_acc, epoch)
                    writer.add_scalar('Accuracy/val/top5', top5_acc, epoch)

                # check whether training is stalling
                if kill_plateaus:
                    if top1_prec == top1_acc:
                        logger.debug(
                            "Previous val accuracy: {}, current val accuracy: {}. "
                            "Model unlikely to converge. Quitting.".format(
                                top1_prec, top1_acc))
                        converged = False
                        return model, converged
                    else:
                        top1_prec = top1_acc

            if train_acc:
                train_loss, top1_train, top5_train = scores.evaluate(
                    model, train_loader, criterion, device, topk=(1, 5))
                utils.print_train_loss_epoch(epoch, train_loss, top1_train,
                                             top5_train)
                model = model.train()
                if writer is not None:
                    writer.add_scalar('Accuracy/train/top1', top1_train, epoch)
                    writer.add_scalar('Accuracy/train/top5', top5_train, epoch)
            # save snapshot
            snapshot.save_snapshot(model, optimizer, scheduler, epoch,
                                   top1_acc, top5_acc, filename,
                                   snapshot_dirname)
    return model, converged
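
# To make the evaluation/snapshot cadence above concrete, here is a small
# self-contained sketch (the helper name is hypothetical, not project code)
# listing the epochs at which train() evaluates and saves a snapshot:
def snapshot_epochs(end_epoch, snapshot_every, snapshot_all_until=0):
    if snapshot_every < 1:
        snapshot_every = end_epoch
    return [e for e in range(end_epoch)
            if e < snapshot_all_until
            or (e + 1) % snapshot_every == 0
            or (e + 1) == end_epoch]

# e.g. snapshot_epochs(10, 3, snapshot_all_until=2) == [0, 1, 2, 5, 8, 9]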
Example #3: main (reinit_weights script)
def main(args):
  """
    Main
  """
  if args.arch is None:
    print("Available architectures:")
    print(models.__all__)
    sys.exit(0)
    
  if args.dataset is None:
    print("Available datasets:")
    print(data_loader.__all__)
    sys.exit(0)
    
  # set manual seed if required
  if args.seed is not None:
    torch.manual_seed(args.seed)
    
  results_path, plots_path = prepare_dirs(args)
  logger = logging.getLogger('reinit_weights')
    
  if torch.cuda.is_available() and args.cuda:
    device = torch.device("cuda:0")
    cudnn.benchmark = True
  else:
    device = torch.device("cpu")
  
  classes, in_channels = data_loader.num_classes(args.dataset)
  if args.subsample_classes > 0:
    classes = args.subsample_classes
 
  if os.path.exists(args.load_from):
    logger.info("Loading {} from {}.".format(args.arch, args.load_from))
    net = snapshot.load_model(args.arch, classes, args.load_from, device, in_channels)
  else:
    logger.info("Cannot load trained model from {}: no such file.".format(args.load_from))
    sys.exit(1)
      
  net = net.to(device)
  criterion = nn.CrossEntropyLoss().to(device)
  
  # load test set
  _, test_loader, _ = data_loader.load_dataset(
      args.dataset, args.data_path, args.batch_size, shuffle=True,
      augmentation=False, num_workers=args.workers,
      nclasses=args.subsample_classes,
      class_sample_seed=args.class_sample_seed,
      upscale=args.upscale, upscale_padding=args.upscale_padding)
  # evaluate model
  logger.info("Evaluating trained model on test set.")
  test_loss, top1_acc, top5_acc = scores.evaluate(net, test_loader, criterion, device, topk=(1,5))
  utils.print_val_loss(0, test_loss, top1_acc, top5_acc, 'reinit_weights')
  
  results = {}
  if os.path.exists(args.inits_from):
    logger.info("Loading network weights initializations from {}.".format(args.inits_from))
    # get generator
    with open(args.inits_from, 'r') as fp:
      for init_file in init_filenames(fp):
        if os.path.exists(init_file):
          logger.info("Loading network weights from {}".format(init_file))
          net_init = snapshot.load_model(args.arch, classes, init_file, device, in_channels)
        else:
          logger.warning("Warning. File not found: {}. Skipping.".format(init_file))
          continue
      
        # checkpoint names look like '<arch>_init_0.pt' (initialization)
        # or '<arch>_<epoch>.pt' (saved after that epoch finished)
        splits = os.path.splitext(init_file)[0].split('_')
        if 'init' in splits:
          epoch = 0
        else:
          epoch = int(splits[-1]) + 1
        results = reinit_weights(results, net, net_init, test_loader, criterion, epoch, device)
  
  if args.rand:
    # load random initialization
    logger.info("Loading random initialization.")
    random_init = models.load_model(args.arch, classes, pretrained=False, in_channels=in_channels)
    # randomize weights and compute accuracy
    results = reinit_weights(results, net, random_init, test_loader, criterion, "rand", device)
  
  # args.load_from is guaranteed to exist here (checked at startup)
  filename = 'reinit_' + os.path.splitext(os.path.basename(args.load_from))[0]
  
  results = write_metadata(results, args.load_from, args, top1_acc, top5_acc)
  
  # save results to json
  logger.info("Saving results to file.")
  snapshot.save_results(results, filename, results_path)
  
  # plot results
  logger.info("Plotting results.")
  plot_heatmaps(results, filename, plots_path)
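
# A hedged sketch of the checkpoint-name parsing used above. The naming
# convention matches the snapshots written by the training script in
# Example #4 ('<arch>_init_0.pt' for the initialization, '<arch>_<epoch>.pt'
# afterwards); the helper name is hypothetical:
import os

def epoch_from_filename(init_file):
    splits = os.path.splitext(init_file)[0].split('_')
    # checkpoints are written after an epoch completes, hence the +1
    return 0 if 'init' in splits else int(splits[-1]) + 1

assert epoch_from_filename('vgg16_init_0.pt') == 0
assert epoch_from_filename('vgg16_42.pt') == 43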
Example #4: main (training script)
def main(args):
    # set up project directories
    tb_logdir, snapshot_dir = prepare_dirs(args)
    # get logger
    logger = logging.getLogger('train')
    # tensorboard writer
    writer = get_writer(args, tb_logdir)

    use_cuda = torch.cuda.is_available() and args.cuda

    # set manual seed if required
    if args.seed is not None:
        torch.manual_seed(args.seed)
        if use_cuda:
            torch.cuda.manual_seed_all(args.seed)

    # check for cuda supports
    if use_cuda:
        device = torch.device("cuda:0")
        cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    # snapshot frequency
    if args.snapshot_every > 0 and not args.evaluate:
        logger.info('Saving snapshots to {}'.format(snapshot_dir))

    # load model
    classes, in_channels = data_loader.num_classes(args.dataset)
    if args.subsample_classes > 0:
        classes = args.subsample_classes
    net = models.load_model(args.arch,
                            classes=classes,
                            pretrained=args.pretrained,
                            in_channels=in_channels)

    if args.pretrained and args.resume_from == '':
        logger.info('Loading pretrained {} on ImageNet.'.format(args.arch))
    else:
        logger.info('Creating model {}.'.format(args.arch))

    if torch.cuda.device_count() > 1:
        logger.info("Running on {} GPUs".format(torch.cuda.device_count()))
        # only the convolutional trunk is replicated; the classifier head
        # runs on the default device
        net.features = torch.nn.DataParallel(net.features)

    # move net to device
    net = net.to(device=device)

    # get data loader for the specified dataset
    train_loader, test_loader, val_loader = data_loader.load_dataset(
        args.dataset,
        args.data_path,
        args.batch_size,
        shuffle=args.shuffle,
        augmentation=args.augmentation,
        noise=args.noise,
        split=args.split,
        num_workers=args.workers,
        split_seed=args.split_seed,
        noise_seed=args.noise_seed,
        stratified=args.stratified,
        nclasses=args.subsample_classes,
        class_sample_seed=args.class_sample_seed,
        no_normalization=args.unnormalize,
        upscale=args.upscale,
        upscale_padding=args.upscale_padding)
    # define loss
    criterion = nn.CrossEntropyLoss().to(device)

    start_epoch = args.start_epoch
    best_acc1, best_acc5 = 0, 0
    # load model from file
    if os.path.isfile(args.resume_from):
        # resume training given state dictionary
        optimizer, scheduler = load_optimizer(args, net)
        try:
            net, optimizer, scheduler, start_epoch, best_acc1, best_acc5 = snapshot.load_snapshot(
                net, optimizer, scheduler, args.resume_from, device)
            if args.override:
                override_hyperparams(args, optimizer, scheduler)
        except KeyError:
            # file holds a bare model state_dict, not a full training snapshot
            classes, in_channels = data_loader.num_classes(args.dataset)
            if args.subsample_classes > 0:
                classes = args.subsample_classes
            net = snapshot.load_model(args.arch, classes, args.resume_from,
                                      device, in_channels)

    else:
        # define optimizer
        optimizer, scheduler = load_optimizer(args, net)

    # evaluate model
    if args.evaluate:
        val_loss, top1_acc, top5_acc = scores.evaluate(net,
                                                       test_loader,
                                                       criterion,
                                                       device,
                                                       topk=(1, 5))
        utils.print_val_loss(args.epochs, val_loss, top1_acc, top5_acc)
        writer.add_scalar('Loss/test', val_loss, args.epochs)
        writer.add_scalar('Accuracy/test/top1', top1_acc, args.epochs)
        writer.add_scalar('Accuracy/test/top5', top5_acc, args.epochs)
        writer.close()
        return

    if args.evaluate_train:
        train_loss, top1_acc, top5_acc = scores.evaluate(net,
                                                         train_loader,
                                                         criterion,
                                                         device,
                                                         topk=(1, 5))
        utils.print_train_loss_epoch(args.epochs, train_loss, top1_acc,
                                     top5_acc)
        if best_acc1 * best_acc5 > 0:
            # if nonzero, print best val accuracy
            utils.print_val_loss(args.epochs, -1., best_acc1, best_acc5)
        writer.add_scalar('Loss/train', train_loss, args.epochs)
        writer.add_scalar('Accuracy/train/top1', top1_acc, args.epochs)
        writer.add_scalar('Accuracy/train/top5', top5_acc, args.epochs)
        writer.close()
        return

    if args.eval_regularization_loss:
        regularization_loss = scores.compute_regularization_loss(
            net, args.weight_decay)
        utils.print_regularization_loss_epoch(args.epochs, regularization_loss)
        writer.add_scalar('Regularization loss', regularization_loss,
                          args.epochs)
        writer.close()
        return

    utils.print_model_config(args)

    if start_epoch == 0:
        pretrained = 'pretrained_' if args.pretrained else 'init_'
        filename = args.arch + '_' + pretrained + str(start_epoch) + '.pt'
        logger.info("Saving model initialization to {}".format(filename))
        snapshot.save_model(net, filename, snapshot_dir)

    # train the model
    net.train()
    if val_loader is None and test_loader is not None:
        val_loader = test_loader
        logger.warning("Using TEST set to validate model during training!")
    net, converged = train(net,
                           args.epochs,
                           train_loader,
                           optimizer,
                           criterion,
                           scheduler,
                           device,
                           snapshot_dirname=snapshot_dir,
                           start_epoch=start_epoch,
                           snapshot_every=args.snapshot_every,
                           val_loader=val_loader,
                           kill_plateaus=args.kill_plateaus,
                           best_acc1=best_acc1,
                           writer=writer,
                           snapshot_all_until=args.snapshot_all_until,
                           filename=args.arch,
                           train_acc=args.train_acc)
    if test_loader is not None:
        val_loss, top1_acc, top5_acc = scores.evaluate(net,
                                                       test_loader,
                                                       criterion,
                                                       device,
                                                       topk=(1, 5))
        utils.print_val_loss(args.epochs, val_loss, top1_acc, top5_acc)
        net = net.train()
        writer.add_scalar('Loss/test', val_loss, args.epochs)
        writer.add_scalar('Accuracy/test/top1', top1_acc, args.epochs)
        writer.add_scalar('Accuracy/test/top5', top5_acc, args.epochs)

    # save final model
    if converged:
        pretrained = 'pretrained_' if args.pretrained else ''
        filename = args.arch + '_' + pretrained + str(args.epochs) + '.pt'
        snapshot.save_model(net, filename, snapshot_dir)

    writer.close()
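
# A minimal sketch of the snapshot round-trip that main() relies on. The real
# `snapshot` module is project-local and not shown here; the dict layout below
# is an assumption chosen to match the values load_snapshot() returns:
import torch

def save_snapshot_sketch(model, optimizer, scheduler, epoch, acc1, acc5, path):
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
        'epoch': epoch,  # becomes start_epoch on resume
        'best_acc1': acc1,
        'best_acc5': acc5,
    }, path)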