Example #1
0
def main(args):
    """Compute projection statistics of a network and save them to JSON.

    If ``args.arch`` or ``args.dataset`` is missing, the available choices are
    printed and the process exits with status 0. Otherwise the network is
    loaded from the snapshot at ``args.load_from`` when that path exists, or
    built by ``models.load_model`` (optionally pretrained) when it does not.
    ``positive_orthant`` is computed on it and, if ``args.init_from`` exists,
    ``distance_from_init`` is computed against the network stored there. The
    results are annotated with ``write_metadata`` and written with
    ``snapshot.save_results`` to the directory returned by ``prepare_dirs``.

    Args:
        args: parsed command-line namespace. Attributes read here: ``arch``,
            ``dataset``, ``seed``, ``cuda``, ``subsample_classes``,
            ``load_from``, ``pretrained``, ``init_from``, ``normalize`` and
            ``with_linear`` (plus whatever ``prepare_dirs`` and
            ``write_metadata`` read).

    Raises:
        SystemExit: status 0 after listing architectures or datasets, and
            status 1 when ``args.arch`` is not a supported architecture.
    """
    if args.arch is None:
        print("Available architectures:")
        print(models.__all__)
        sys.exit(0)

    if args.dataset is None:
        print("Available datasets:")
        print(data_loader.__all__)
        sys.exit(0)

    # set manual seed if required
    if args.seed is not None:
        torch.manual_seed(args.seed)

    results_path = prepare_dirs(args)
    logger = logging.getLogger('projections')

    if torch.cuda.is_available() and args.cuda:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    classes, in_channels = data_loader.num_classes(args.dataset)
    if args.subsample_classes > 0:
        classes = args.subsample_classes

    # Checked once so that model loading and output naming below always
    # agree on whether a snapshot was used.
    load_from_snapshot = os.path.exists(args.load_from)
    if load_from_snapshot:
        logger.info("Loading {} from {}.".format(args.arch, args.load_from))
        net = snapshot.load_model(args.arch,
                                  classes,
                                  args.load_from,
                                  device,
                                  in_channels=in_channels)
        if args.pretrained:
            logger.warning(
                "Warning: --pretrained should only be used when loading a pretrained from the model zoo. Ignoring."
            )
    else:
        try:
            net = models.load_model(args.arch,
                                    classes,
                                    pretrained=args.pretrained,
                                    in_channels=in_channels)
        except ValueError:
            print("Unsupported architecture: {}".format(args.arch))
            print("Supported architectures:")
            print(models.__all__)
            # An unsupported architecture is an error, not a listing request:
            # exit non-zero so calling scripts can detect the failure.
            sys.exit(1)

        if args.pretrained:
            logger.info(
                "Loading pretrained {} from torchvision model zoo".format(
                    args.arch))
        else:
            logger.info(
                "Could not find snapshot. Initializing network weights from scratch."
            )

    net = net.to(device)

    # Optional reference network used to measure distance from initialization.
    net_init = None
    if os.path.exists(args.init_from):
        logger.info(
            "Loading network weights at initialization from {}. They will be used to compute the network's distance from initialization."
            .format(args.init_from))
        net_init = snapshot.load_model(args.arch, classes, args.init_from,
                                       device, in_channels)

    # compute statistics per network
    logger.info("Computing projections statistics...")
    results = positive_orthant(net, args.normalize, args.with_linear)
    if net_init is not None:
        logger.info("Computing distance from initialization...")
        results = distance_from_init(net, net_init, results, args.with_linear)
    logger.info("done")

    # Name the output after the loaded snapshot if there was one, otherwise
    # after the kind of initialization that was used.
    if load_from_snapshot:
        filename = os.path.splitext(os.path.basename(args.load_from))[0]
    elif args.pretrained:
        filename = 'pretrained_' + args.arch
    else:
        filename = 'init_' + args.arch

    results = write_metadata(results, filename, args)

    # save results to json
    logger.info("Saving results to file.")
    snapshot.save_results(results, filename, results_path)
Example #2
0
def main(args):
    """Evaluate a trained network, then run ``reinit_weights`` against saved initializations.

    If ``args.arch`` or ``args.dataset`` is missing, the available choices are
    printed and the process exits with status 0. The trained network must
    exist at ``args.load_from``; otherwise the process exits with status 1.
    The network is evaluated on the test set. Then ``reinit_weights`` is run
    once per initialization file listed in ``args.inits_from`` (when that
    file exists), and once more against a freshly built model when
    ``args.rand`` is set. The collected results are annotated with
    ``write_metadata``, saved with ``snapshot.save_results`` and plotted with
    ``plot_heatmaps``.

    Args:
        args: parsed command-line namespace holding the model, dataset and
            evaluation configuration.

    Raises:
        SystemExit: status 0 after listing architectures or datasets, and
            status 1 when ``args.load_from`` does not exist.
    """
    if args.arch is None:
        print("Available architectures:")
        print(models.__all__)
        sys.exit(0)

    if args.dataset is None:
        print("Available datasets:")
        print(data_loader.__all__)
        sys.exit(0)

    # set manual seed if required
    if args.seed is not None:
        torch.manual_seed(args.seed)

    results_path, plots_path = prepare_dirs(args)
    logger = logging.getLogger('reinit_weights')

    if torch.cuda.is_available() and args.cuda:
        device = torch.device("cuda:0")
        cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    classes, in_channels = data_loader.num_classes(args.dataset)
    if args.subsample_classes > 0:
        classes = args.subsample_classes

    if not os.path.exists(args.load_from):
        # A missing trained model is fatal, so report it at error level.
        logger.error("Cannot load trained model from {}: no such file.".format(args.load_from))
        sys.exit(1)
    logger.info("Loading {} from {}.".format(args.arch, args.load_from))
    net = snapshot.load_model(args.arch, classes, args.load_from, device, in_channels)

    net = net.to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    # load test set
    _, test_loader, _ = data_loader.load_dataset(
        args.dataset, args.data_path, args.batch_size, shuffle=True,
        augmentation=False, num_workers=args.workers,
        nclasses=args.subsample_classes,
        class_sample_seed=args.class_sample_seed, upscale=args.upscale,
        upscale_padding=args.upscale_padding)
    # evaluate model
    logger.info("Evaluating trained model on test set.")
    test_loss, top1_acc, top5_acc = scores.evaluate(net, test_loader, criterion, device, topk=(1, 5))
    utils.print_val_loss(0, test_loss, top1_acc, top5_acc, 'reinit_weights')

    results = {}
    if os.path.exists(args.inits_from):
        logger.info("Loading network weights initializations from {}.".format(args.inits_from))
        # get generator
        with open(args.inits_from, 'r') as fp:
            for init_file in init_filenames(fp):
                if not os.path.exists(init_file):
                    logger.warning("Warning. File not found: {}. Skipping.".format(init_file))
                    continue
                logger.info("Loading network weights from {}".format(init_file))
                net_init = snapshot.load_model(args.arch, classes, init_file, device, in_channels)

                # Parse the epoch from the file name only. Splitting the full
                # path would also see '_'-separated directory components, so a
                # directory such as 'x_init_y' would be mistaken for epoch 0.
                splits = os.path.splitext(os.path.basename(init_file))[0].split('_')
                if 'init' in splits:
                    epoch = 0
                else:
                    # NOTE(review): assumes snapshot names end in '_<number>';
                    # the +1 offset is kept unchanged from the original.
                    epoch = int(splits[-1]) + 1
                results = reinit_weights(results, net, net_init, test_loader, criterion, epoch, device)

    if args.rand:
        # load random initialization
        logger.info("Loading random initialization.")
        random_init = models.load_model(args.arch, classes, pretrained=False, in_channels=in_channels)
        # randomize weights and compute accuracy
        results = reinit_weights(results, net, random_init, test_loader, criterion, "rand", device)

    # args.load_from was verified to exist above, so the output name is
    # always defined here.
    filename = 'reinit_' + os.path.splitext(os.path.basename(args.load_from))[0]

    results = write_metadata(results, args.load_from, args, top1_acc, top5_acc)

    # save results to json
    logger.info("Saving results to file.")
    snapshot.save_results(results, filename, results_path)

    # plot results
    logger.info("Plotting results.")
    plot_heatmaps(results, filename, plots_path)
Example #3
0
def main(args):
    """Train a network, or run one of the evaluation-only modes.

    Sets up the project directories, logger and tensorboard writer, then
    builds the model and data loaders. When ``args.resume_from`` is an
    existing file, the model is restored from it. If ``args.evaluate``,
    ``args.evaluate_train`` or ``args.eval_regularization_loss`` is set, the
    corresponding metric is computed, logged and the function returns.
    Otherwise the network is trained for ``args.epochs`` epochs, evaluated on
    the test set (if there is one), and saved when training converged.

    The tensorboard writer is closed on every exit path, including when an
    exception propagates.

    Args:
        args: parsed command-line namespace holding the training
            configuration.
    """
    # set up project directories
    tb_logdir, snapshot_dir = prepare_dirs(args)
    # get logger
    logger = logging.getLogger('train')
    # tensorboard writer
    writer = get_writer(args, tb_logdir)
    # Everything below runs under try/finally so the writer is closed (and
    # its pending events flushed) even if an exception escapes.
    try:
        use_cuda = torch.cuda.is_available() and args.cuda

        # set manual seed if required
        if args.seed is not None:
            torch.manual_seed(args.seed)
            if use_cuda:
                torch.cuda.manual_seed_all(args.seed)

        # check for cuda supports
        if use_cuda:
            device = torch.device("cuda:0")
            cudnn.benchmark = True
        else:
            device = torch.device("cpu")

        # snapshot frequency
        if args.snapshot_every > 0 and not args.evaluate:
            logger.info('Saving snapshots to {}'.format(snapshot_dir))

        # load model
        classes, in_channels = data_loader.num_classes(args.dataset)
        if args.subsample_classes > 0:
            classes = args.subsample_classes
        net = models.load_model(args.arch,
                                classes=classes,
                                pretrained=args.pretrained,
                                in_channels=in_channels)

        if args.pretrained and args.resume_from == '':
            logger.info('Loading pretrained {} on ImageNet.'.format(args.arch))
        else:
            logger.info('Creating model {}.'.format(args.arch))

        # Only use DataParallel when CUDA is actually in use. With --cuda
        # off, the model stays on the CPU, but DataParallel expects its
        # module's parameters on the first GPU.
        if use_cuda and torch.cuda.device_count() > 1:
            logger.info("Running on {} GPUs".format(torch.cuda.device_count()))
            net.features = torch.nn.DataParallel(net.features)

        # move net to device
        net = net.to(device=device)

        # get data loader for the specified dataset
        train_loader, test_loader, val_loader = data_loader.load_dataset(
            args.dataset,
            args.data_path,
            args.batch_size,
            shuffle=args.shuffle,
            augmentation=args.augmentation,
            noise=args.noise,
            split=args.split,
            num_workers=args.workers,
            split_seed=args.split_seed,
            noise_seed=args.noise_seed,
            stratified=args.stratified,
            nclasses=args.subsample_classes,
            class_sample_seed=args.class_sample_seed,
            no_normalization=args.unnormalize,
            upscale=args.upscale,
            upscale_padding=args.upscale_padding)
        # define loss
        criterion = nn.CrossEntropyLoss().to(device)

        start_epoch = args.start_epoch
        best_acc1, best_acc5 = 0, 0
        # load model from file
        if os.path.isfile(args.resume_from):
            # resume training given state dictionary
            optimizer, scheduler = load_optimizer(args, net)
            try:
                (net, optimizer, scheduler, start_epoch, best_acc1,
                 best_acc5) = snapshot.load_snapshot(net, optimizer, scheduler,
                                                     args.resume_from, device)
            except KeyError:
                # NOTE(review): presumably a KeyError means the file holds
                # only model weights rather than a full training snapshot.
                net = snapshot.load_model(args.arch, classes, args.resume_from,
                                          device, in_channels)
                # The optimizer/scheduler above were created from the network
                # that has just been replaced. Rebuild them from the loaded
                # network so training updates the loaded weights, not stale
                # ones.
                optimizer, scheduler = load_optimizer(args, net)
                # NOTE(review): unlike the freshly built model, this network
                # is not wrapped in DataParallel on multi-GPU runs — confirm
                # whether that is intended.
            else:
                # Kept outside the try so a KeyError raised here is not
                # mistaken for a weights-only snapshot file.
                if args.override:
                    override_hyperparams(args, optimizer, scheduler)
        else:
            # define optimizer
            optimizer, scheduler = load_optimizer(args, net)

        # evaluate model
        if args.evaluate:
            val_loss, top1_acc, top5_acc = scores.evaluate(
                net, test_loader, criterion, device, topk=(1, 5))
            utils.print_val_loss(args.epochs, val_loss, top1_acc, top5_acc)
            writer.add_scalar('Loss/test', val_loss, args.epochs)
            writer.add_scalar('Accuracy/test/top1', top1_acc, args.epochs)
            writer.add_scalar('Accuracy/test/top5', top5_acc, args.epochs)
            return

        if args.evaluate_train:
            train_loss, top1_acc, top5_acc = scores.evaluate(
                net, train_loader, criterion, device, topk=(1, 5))
            utils.print_train_loss_epoch(args.epochs, train_loss, top1_acc,
                                         top5_acc)
            if best_acc1 * best_acc5 > 0:
                # if nonzero, print best val accuracy
                utils.print_val_loss(args.epochs, -1., best_acc1, best_acc5)
            writer.add_scalar('Loss/train', train_loss, args.epochs)
            writer.add_scalar('Accuracy/train/top1', top1_acc, args.epochs)
            writer.add_scalar('Accuracy/train/top5', top5_acc, args.epochs)
            return

        if args.eval_regularization_loss:
            regularization_loss = scores.compute_regularization_loss(
                net, args.weight_decay)
            utils.print_regularization_loss_epoch(args.epochs,
                                                  regularization_loss)
            writer.add_scalar('Regularization loss', regularization_loss,
                              args.epochs)
            return

        utils.print_model_config(args)

        # Save the starting weights when training begins from scratch.
        if start_epoch == 0:
            pretrained = 'pretrained_' if args.pretrained else 'init_'
            filename = args.arch + '_' + pretrained + str(start_epoch) + '.pt'
            logger.info("Saving model initialization to {}".format(filename))
            snapshot.save_model(net, filename, snapshot_dir)

        # train the model
        net.train()
        if val_loader is None and test_loader is not None:
            val_loader = test_loader
            logger.warning("Using TEST set to validate model during training!")
        net, converged = train(net,
                               args.epochs,
                               train_loader,
                               optimizer,
                               criterion,
                               scheduler,
                               device,
                               snapshot_dirname=snapshot_dir,
                               start_epoch=start_epoch,
                               snapshot_every=args.snapshot_every,
                               val_loader=val_loader,
                               kill_plateaus=args.kill_plateaus,
                               best_acc1=best_acc1,
                               writer=writer,
                               snapshot_all_until=args.snapshot_all_until,
                               filename=args.arch,
                               train_acc=args.train_acc)
        if test_loader is not None:
            val_loss, top1_acc, top5_acc = scores.evaluate(
                net, test_loader, criterion, device, topk=(1, 5))
            utils.print_val_loss(args.epochs, val_loss, top1_acc, top5_acc)
            net = net.train()
            writer.add_scalar('Loss/test', val_loss, args.epochs)
            writer.add_scalar('Accuracy/test/top1', top1_acc, args.epochs)
            writer.add_scalar('Accuracy/test/top5', top5_acc, args.epochs)

        # save final model
        if converged:
            pretrained = 'pretrained_' if args.pretrained else ''
            filename = args.arch + '_' + pretrained + str(args.epochs) + '.pt'
            snapshot.save_model(net, filename, snapshot_dir)
    finally:
        writer.close()