Example 1
def train_test_net(run, user_options):
  """Train and save a network accoring to user options
  
  Args
  run (int): the current independent run (used in filenames)
  user_options (argparser) : user specified options
  """

  # get logger
  logger = logging.getLogger('train')

  # initialize model
  net = models.model_factory(user_options.arch, dataset=user_options.dataset, init=user_options.init)
  
  if torch.cuda.device_count() > 1:
    logger.info("Running on {} GPUs".format(torch.cuda.device_count()))
    net = NamedDataParallel(net)
  
  # move net to device
  net = net.to(device=device)
  
  # get data loader for the specified dataset
  train_loader, test_loader = data_loaders.load_dataset(
      user_options.dataset, user_options.dataset_path,
      user_options.noisy, user_options.batch_size)

  # define loss
  criterion = load_criterion(user_options)
  criterion = criterion.to(device)
  
  # resume training from snapshot if specified
  start_epoch = 0
  if os.path.isfile(user_options.resume_from):
    # resume training given state dictionary
    optimizer, scheduler = load_optimizer(user_options, net)
    net, optimizer, scheduler, start_epoch = snapshot.load_snapshot(net, optimizer, scheduler, user_options.resume_from, device)
    start_epoch = start_epoch + 1
  else:
    # define optimizer
    optimizer, scheduler = load_optimizer(user_options, net)

  # print model configuration
  logger.info("Running trial {} of {}".format(run+1, user_options.runs))
  utils.print_model_config(user_options, start_epoch)
  
  if start_epoch == 0: 
    filename = net.__name__ + '_' + str(start_epoch) + '_' + str(user_options.init) + '.pt'
    logger.info("Saving model initialization to {}".format(filename))
    snapshot.save_model(net, filename, snapshot_dirname)

  # train the model
  net, converged = train(net, user_options.epochs, train_loader, optimizer,
                         criterion, scheduler, device, start_epoch,
                         snapshot_every=user_options.snapshot_every,
                         test_loader=test_loader,
                         kill_plateaus=user_options.kill_plateaus,
                         init_scheme=user_options.init)
  
  if test_loader is not None:
    val_loss, accuracy = scores.test(net, test_loader, criterion, device)
    utils.print_val_loss(user_options.epochs, val_loss, accuracy)
    net = net.train()

  # save final model
  if converged:
    filename = net.__name__ + '_' + str(user_options.epochs) + '_' + user_options.init + '.pt'
    snapshot.save_model(net, filename, snapshot_dirname)
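
The helpers load_criterion and load_optimizer are used above but not shown. A minimal sketch of what they might look like, assuming plain cross-entropy and SGD with step decay; the lr, momentum and weight_decay fields on user_options, and the StepLR settings, are assumptions rather than part of the original listing:

import torch.nn as nn
import torch.optim as optim

def load_criterion(user_options):
  # Plain cross-entropy; the caller moves it to the target device.
  return nn.CrossEntropyLoss()

def load_optimizer(user_options, net):
  # SGD with momentum plus a step decay schedule (assumed hyperparameters).
  optimizer = optim.SGD(net.parameters(),
                        lr=user_options.lr,
                        momentum=user_options.momentum,
                        weight_decay=user_options.weight_decay)
  scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
  return optimizer, scheduler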
Example 2
def train(model, end_epoch, train_loader, optimizer, criterion, scheduler, device, start_epoch = 0, snapshot_every = 0, test_loader = None, kill_plateaus = False, init_scheme=None):
  """Train the specified model according to user options.
  
    Args:
    
    model (nn.Module object) -- the model to be trained
    end_epoch (int) -- maximum number of epochs
    train_loader (object, DataLoader) -- train set loader
    optimizer (torch.optim optimizer) -- the optimizer to use
    criterion -- loss function to use
    scheduler -- learning rate scheduler
    device (torch.device) -- device to use
    start_epoch (int) -- starting epoch (useful for resuming training)
    snapshot_every (int) -- frequency of snapshots (in epochs)
    test_loader (optional, DataLoader) -- test set loader
    
  """
  if snapshot_every < 1:
    snapshot_every = end_epoch
  start_loss = 0.
  converged = True
  for epoch in range(start_epoch, end_epoch):
    # running (exponentially averaged) and per-epoch training losses
    avg_loss = 0.
    epoch_loss = 0.
    for batch_idx, (x, target) in enumerate(train_loader):
      optimizer.zero_grad()
      x, target = x.to(device=device), target.to(device=device)
      out = model(x)
      loss = criterion(out, target)
      avg_loss = avg_loss * 0.99 + loss.item() * 0.01
      epoch_loss += loss.item()
      loss.backward()
      optimizer.step()
      with torch.no_grad():
        if kill_plateaus:
          if epoch == start_epoch and batch_idx == 99:
            start_loss = avg_loss
          if epoch == 19 and batch_idx == 99:
            if scores.loss_plateaus(start_loss, avg_loss):
              logger.debug("Start loss: {}, current loss: {}. Model unlikely to converge. Quitting.".format(start_loss, avg_loss))
              converged = False
              return model, converged
      # report training loss
      if ((batch_idx+1) % 100 == 0) or ((batch_idx+1) == len(train_loader)):
        utils.print_train_loss(epoch, avg_loss, batch_idx, len(train_loader))
    # report training loss over epoch
    epoch_loss /= len(train_loader)
    utils.print_train_loss_epoch(epoch, epoch_loss)
    if scheduler is not None:
      scheduler.step()
    if ((epoch +1) % snapshot_every == 0) or ((epoch +1) == end_epoch):
      if test_loader is not None:
        val_loss, accuracy = scores.test(model, test_loader, criterion, device)
        utils.print_val_loss(epoch, val_loss, accuracy)
        model = model.train()
      # save snapshot
      snapshot.save_snapshot(model, optimizer, scheduler, epoch, snapshot_dirname, init_scheme)
  return model, converged
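
The early-abort test scores.loss_plateaus is referenced above but not defined. One plausible implementation checks whether the smoothed loss has dropped by some relative margin since the start of training; the 5% threshold below is an assumption:

def loss_plateaus(start_loss, current_loss, min_rel_drop=0.05):
  """Return True if the smoothed loss has not dropped by at least
  min_rel_drop (relative) since the measurement at the start of training."""
  if start_loss <= 0:
    return False
  return (start_loss - current_loss) / start_loss < min_rel_drop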
Example 3
def main(args):
  """
    Main
  """
  if args.arch is None:
    print("Available architectures:")
    print(models.__all__)
    sys.exit(0)
    
  if args.dataset is None:
    print("Available datasets:")
    print(data_loader.__all__)
    sys.exit(0)
    
  # set manual seed if required
  if args.seed is not None:
    torch.manual_seed(args.seed)
    
  results_path, plots_path = prepare_dirs(args)
  logger = logging.getLogger('reinit_weights')
    
  if torch.cuda.is_available() and args.cuda:
    device = torch.device("cuda:0")
    cudnn.benchmark = True
  else:
    device = torch.device("cpu")
  
  classes, in_channels = data_loader.num_classes(args.dataset)
  if args.subsample_classes > 0:
    classes = args.subsample_classes
 
  if os.path.exists(args.load_from):
    logger.info("Loading {} from {}.".format(args.arch, args.load_from))
    net = snapshot.load_model(args.arch, classes, args.load_from, device, in_channels)
  else:
    logger.info("Cannot load trained model from {}: no such file.".format(args.load_from))
    sys.exit(1)
      
  net = net.to(device)
  criterion = nn.CrossEntropyLoss().to(device)
  
  # load test set
  _, test_loader, _ = data_loader.load_dataset(args.dataset, args.data_path, args.batch_size, shuffle=True,
                                  augmentation=False, num_workers=args.workers, nclasses=args.subsample_classes,
                                  class_sample_seed=args.class_sample_seed, upscale=args.upscale, upscale_padding=args.upscale_padding)
  # evaluate model
  logger.info("Evaluating trained model on test set.")
  test_loss, top1_acc, top5_acc = scores.evaluate(net, test_loader, criterion, device, topk=(1,5))
  utils.print_val_loss(0, test_loss, top1_acc, top5_acc, 'reinit_weights')
  
  results = {}
  if os.path.exists(args.inits_from):
    logger.info("Loading network weights initializations from {}.".format(args.inits_from))
    # get generator
    with open(args.inits_from, 'r') as fp:
      for init_file in init_filenames(fp):
        if os.path.exists(init_file):
          logger.info("Loading network weights from {}".format(init_file))
          net_init = snapshot.load_model(args.arch, classes, init_file, device, in_channels)
        else:
          logger.warning("File not found: {}. Skipping.".format(init_file))
          continue
      
        splits = os.path.splitext(init_file)[0].split('_')
        if 'init' in splits:
          epoch = 0
        else:
          epoch = int(splits[-1]) + 1
        results = reinit_weights(results, net, net_init, test_loader, criterion, epoch, device)
  
  if args.rand:
    # load random initialization
    logger.info("Loading random initialization.")
    random_init = models.load_model(args.arch, classes, pretrained=False, in_channels=in_channels)
    # randomize weights and compute accuracy
    results = reinit_weights(results, net, random_init, test_loader, criterion, "rand", device)
  
  # args.load_from is guaranteed to exist here (checked above), so
  # filename is always defined before it is used below
  filename = os.path.splitext(os.path.basename(args.load_from))[0]
  filename = 'reinit_' + filename
  
  results = write_metadata(results, args.load_from, args, top1_acc, top5_acc)
  
  # save results to json
  logger.info("Saving results to file.")
  snapshot.save_results(results, filename, results_path)
  
  # plot results
  logger.info("Plotting results.")
  plot_heatmaps(results, filename, plots_path)
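
The init_filenames generator consumed above is not part of this listing. Given how it is called, a minimal sketch, assuming the file passed via --inits-from lists one checkpoint path per line:

def init_filenames(fp):
  # Yield one checkpoint path per non-empty line of the index file.
  for line in fp:
    line = line.strip()
    if line:
      yield line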
Example 4
def student_train_test(user_options):
  """Train student network by knowledge distillation
  
  Args
  run (int): the current independent run (used filenames)
  user_options (argparser) : user specified options
  """
  # get logger
  logger = logging.getLogger('train')

  # load teacher model
  teacher = models.model_factory(user_options.arch, dataset=user_options.dataset, init=user_options.init)
  
  if torch.cuda.device_count() > 1:
    logger.info("Running teacher network on {} GPUs".format(torch.cuda.device_count()))
    teacher = NamedDataParallel(teacher)
    tdevice = device
  else:
    tdevice = torch.device('cpu')
  
  # move net to device
  teacher = teacher.to(device=tdevice)
  
  # load teacher network from file
  if os.path.isfile(user_options.resume_from):
    teacher, _, _, _ = snapshot.load_snapshot(teacher, None, None, user_options.resume_from, tdevice)
    teacher = teacher.eval()
  else:
    raise ValueError('Missing teacher model definition. Specify it with --resume-from [FILENAME]')
  
  # get data loader for the specified dataset
  train_loader, test_loader = data_loaders.load_dataset(
      user_options.dataset, user_options.dataset_path,
      user_options.noisy, user_options.batch_size)
  
  # load student
  student = models.student_factory(user_options.arch, user_options.dataset, init=user_options.init)
  
  if torch.cuda.device_count() > 1:
    logger.info("Running student network on {} GPUs".format(torch.cuda.device_count()))
    student = NamedDataParallel(student)
    
  student = student.to(device=device)

  # load optimizer, scheduler
  optimizer, scheduler = load_optimizer(user_options, student)

  # define loss
  criterion = load_criterion(user_options)

  # print model configuration
  start_epoch = 0
  utils.print_student_config(user_options)
  
  # save model at initialization
  teacher_name = os.path.basename(user_options.resume_from)
  teacher_name = os.path.splitext(teacher_name)[0] # remove file extension
  teacher_name = teacher_name.split('_')[0]
  filename = 'Student_' + teacher_name + '_' + str(start_epoch) + '.pt'
  snapshot.save_model(student, filename, snapshot_dirname)

  # train the model
  student, converged = distill(student, teacher, user_options.epochs,
                               train_loader, optimizer, criterion, scheduler,
                               tdevice, device, start_epoch,
                               snapshot_every=user_options.epochs,
                               kill_plateaus=user_options.kill_plateaus)
  
  if test_loader is not None:
    test_criterion = nn.CrossEntropyLoss()
    val_loss, accuracy = scores.test(student, test_loader, test_criterion, device)
    utils.print_val_loss(user_options.epochs, val_loss, accuracy)

  # save final model
  if converged:
    teacher_name = os.path.basename(user_options.resume_from)
    teacher_name = os.path.splitext(teacher_name)[0] # remove file extension
    filename = 'Student_' + teacher_name + '.pt'
    snapshot.save_model(student, filename, snapshot_dirname)
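
The distill routine called above is not shown. Its inner loop is typically a vanilla knowledge-distillation step (Hinton et al.): the student matches the teacher's temperature-softened outputs via KL divergence. The sketch below covers one epoch only and omits snapshots, scheduling and plateau checks; the temperature T=4.0 is an assumed hyperparameter:

import torch
import torch.nn.functional as F

def distill_epoch(student, teacher, train_loader, optimizer, tdevice, device, T=4.0):
  student.train()
  for x, _ in train_loader:
    optimizer.zero_grad()
    # the teacher may live on a different device (see tdevice above)
    with torch.no_grad():
      teacher_logits = teacher(x.to(tdevice)).to(device)
    student_logits = student(x.to(device))
    # temperature-scaled KL; the T*T factor keeps gradient magnitudes
    # comparable across temperatures
    loss = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction='batchmean') * (T * T)
    loss.backward()
    optimizer.step()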
Example 5
def train(model,
          end_epoch,
          train_loader,
          optimizer,
          criterion,
          scheduler,
          device,
          snapshot_dirname,
          start_epoch=0,
          snapshot_every=0,
          val_loader=None,
          kill_plateaus=False,
          best_acc1=0,
          writer=None,
          snapshot_all_until=0,
          filename='net',
          train_acc=False):
    """Train the specified model according to user options.
  
    Args:
    
    model (nn.Module) -- the model to be trained
    end_epoch (int) -- maximum number of epochs
    train_loader (nn.DataLoader) -- train set loader
    optimizer (torch.optim optimizer) -- the optimizer to use
    criterion -- loss function to use
    scheduler -- learning rate scheduler
    device (torch.device) -- device to use
    start_epoch (int) -- starting epoch (useful for resuming training)
    snapshot_every (int) -- frequency of snapshots (in epochs)
    test_loader (optional, nn.DataLoader) -- test set loader
    train_acc (bool) -- whether to report accuracy on the train set
    
  """
    converged = True  # used to kill models that plateau
    top1_prec = 0.
    if snapshot_every < 1:
        snapshot_every = end_epoch

    start_loss = 0.
    for epoch in range(start_epoch, end_epoch):
        # training loss
        avg_loss = 0.
        epoch_loss = 0.
        for batch_idx, (x, target) in enumerate(train_loader):
            optimizer.zero_grad()
            x, target = x.to(device,
                             non_blocking=True), target.to(device,
                                                           non_blocking=True)
            out = model(x)
            loss = criterion(out, target)
            avg_loss = avg_loss * 0.99 + loss.item() * 0.01
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()

            # report training loss
            if ((batch_idx + 1) % 100 == 0) or ((batch_idx + 1)
                                                == len(train_loader)):
                utils.print_train_loss(epoch, avg_loss, batch_idx,
                                       len(train_loader))
        # report training loss over epoch
        epoch_loss /= len(train_loader)
        utils.print_train_loss_epoch(epoch, epoch_loss)
        writer.add_scalar('Loss/train', avg_loss, epoch)

        if scheduler is not None:
            scheduler.step()
            writer.add_scalar('Lr', scheduler.get_last_lr()[0], epoch)

        if (epoch < snapshot_all_until) or (
            (epoch + 1) % snapshot_every == 0) or ((epoch + 1) == end_epoch):
            top1_acc, top5_acc = 0, 0
            if val_loader is not None:
                val_loss, top1_acc, top5_acc = scores.evaluate(model,
                                                               val_loader,
                                                               criterion,
                                                               device,
                                                               topk=(1, 5))
                utils.print_val_loss(epoch, val_loss, top1_acc, top5_acc)
                model = model.train()
                writer.add_scalar('Loss/val', val_loss, epoch)
                writer.add_scalar('Accuracy/val/top1', top1_acc, epoch)
                writer.add_scalar('Accuracy/val/top5', top5_acc, epoch)

                # check whether training is stalling
                if kill_plateaus:
                    if top1_prec == top1_acc:
                        logger.debug(
                            "Previous val accuracy: {}, current val accuracy: {}. Model unlikely to converge. Quitting."
                            .format(top1_prec, top1_acc))
                        converged = False
                        return model, converged
                    else:
                        top1_prec = top1_acc

            if train_acc:
                train_loss, top1_train, top5_train = scores.evaluate(
                    model, train_loader, criterion, device, topk=(1, 5))
                utils.print_train_loss_epoch(epoch, train_loss, top1_train,
                                             top5_train)
                model = model.train()
                writer.add_scalar('Accuracy/train/top1', top1_train, epoch)
                writer.add_scalar('Accuracy/train/top5', top5_train, epoch)
            # save snapshot
            snapshot.save_snapshot(model, optimizer, scheduler, epoch,
                                   top1_acc, top5_acc, filename,
                                   snapshot_dirname)
    return model, converged
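
Note that this version of train() writes to TensorBoard unconditionally, so callers must pass a live SummaryWriter even though the parameter defaults to None. A minimal invocation sketch; model, the loaders, optimizer, criterion, scheduler and device are assumed to be built elsewhere (main() in Example 6 shows the real wiring):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/example')
model, converged = train(model, 90, train_loader, optimizer, criterion,
                         scheduler, device, snapshot_dirname='snapshots',
                         snapshot_every=10, val_loader=val_loader,
                         writer=writer, filename='resnet18')
writer.close()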
Example 6
def main(args):
    # set up project directories
    tb_logdir, snapshot_dir = prepare_dirs(args)
    # get logger
    logger = logging.getLogger('train')
    # tensorboard writer
    writer = get_writer(args, tb_logdir)

    use_cuda = torch.cuda.is_available() and args.cuda

    # set manual seed if required
    if args.seed is not None:
        torch.manual_seed(args.seed)
        if use_cuda:
            torch.cuda.manual_seed_all(args.seed)

    # check for cuda supports
    if use_cuda:
        device = torch.device("cuda:0")
        cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    # snapshot frequency
    if args.snapshot_every > 0 and not args.evaluate:
        logger.info('Saving snapshots to {}'.format(snapshot_dir))

    # load model
    classes, in_channels = data_loader.num_classes(args.dataset)
    if args.subsample_classes > 0:
        classes = args.subsample_classes
    net = models.load_model(args.arch,
                            classes=classes,
                            pretrained=args.pretrained,
                            in_channels=in_channels)

    if args.pretrained and args.resume_from == '':
        logger.info('Loading pretrained {} on ImageNet.'.format(args.arch))
    else:
        logger.info('Creating model {}.'.format(args.arch))

    if torch.cuda.device_count() > 1:
        logger.info("Running on {} GPUs".format(torch.cuda.device_count()))
        net.features = torch.nn.DataParallel(net.features)

    # move net to device
    net = net.to(device=device)

    # get data loader for the specified dataset
    train_loader, test_loader, val_loader = data_loader.load_dataset(
        args.dataset,
        args.data_path,
        args.batch_size,
        shuffle=args.shuffle,
        augmentation=args.augmentation,
        noise=args.noise,
        split=args.split,
        num_workers=args.workers,
        split_seed=args.split_seed,
        noise_seed=args.noise_seed,
        stratified=args.stratified,
        nclasses=args.subsample_classes,
        class_sample_seed=args.class_sample_seed,
        no_normalization=args.unnormalize,
        upscale=args.upscale,
        upscale_padding=args.upscale_padding)
    # define loss
    criterion = nn.CrossEntropyLoss().to(device)

    start_epoch = args.start_epoch
    best_acc1, best_acc5 = 0, 0
    # load model from file
    if os.path.isfile(args.resume_from):
        # resume training given state dictionary
        optimizer, scheduler = load_optimizer(args, net)
        try:
            net, optimizer, scheduler, start_epoch, best_acc1, best_acc5 = snapshot.load_snapshot(
                net, optimizer, scheduler, args.resume_from, device)
            if args.override:
                override_hyperparams(args, optimizer, scheduler)
        except KeyError:
            classes, in_channels = data_loader.num_classes(args.dataset)
            if args.subsample_classes > 0:
                classes = args.subsample_classes
            net = snapshot.load_model(args.arch, classes, args.resume_from,
                                      device, in_channels)

    else:
        # define optimizer
        optimizer, scheduler = load_optimizer(args, net)

    # evaluate model
    if args.evaluate:
        val_loss, top1_acc, top5_acc = scores.evaluate(net,
                                                       test_loader,
                                                       criterion,
                                                       device,
                                                       topk=(1, 5))
        utils.print_val_loss(args.epochs, val_loss, top1_acc, top5_acc)
        writer.add_scalar('Loss/test', val_loss, args.epochs)
        writer.add_scalar('Accuracy/test/top1', top1_acc, args.epochs)
        writer.add_scalar('Accuracy/test/top5', top5_acc, args.epochs)
        writer.close()
        return

    if args.evaluate_train:
        train_loss, top1_acc, top5_acc = scores.evaluate(net,
                                                         train_loader,
                                                         criterion,
                                                         device,
                                                         topk=(1, 5))
        utils.print_train_loss_epoch(args.epochs, train_loss, top1_acc,
                                     top5_acc)
        if best_acc1 * best_acc5 > 0:
            # if nonzero, print best val accuracy
            utils.print_val_loss(args.epochs, -1., best_acc1, best_acc5)
        writer.add_scalar('Loss/train', train_loss, args.epochs)
        writer.add_scalar('Accuracy/train/top1', top1_acc, args.epochs)
        writer.add_scalar('Accuracy/train/top5', top5_acc, args.epochs)
        writer.close()
        return

    if args.eval_regularization_loss:
        regularization_loss = scores.compute_regularization_loss(
            net, args.weight_decay)
        utils.print_regularization_loss_epoch(args.epochs, regularization_loss)
        writer.add_scalar('Regularization loss', regularization_loss,
                          args.epochs)
        writer.close()
        return

    utils.print_model_config(args)

    if start_epoch == 0:
        pretrained = 'pretrained_' if args.pretrained else 'init_'
        filename = args.arch + '_' + pretrained + str(start_epoch) + '.pt'
        logger.info("Saving model initialization to {}".format(filename))
        snapshot.save_model(net, filename, snapshot_dir)

    # train the model
    net.train()
    if val_loader is None and test_loader is not None:
        val_loader = test_loader
        logger.warning("Using TEST set to validate model during training!")
    net, converged = train(net,
                           args.epochs,
                           train_loader,
                           optimizer,
                           criterion,
                           scheduler,
                           device,
                           snapshot_dirname=snapshot_dir,
                           start_epoch=start_epoch,
                           snapshot_every=args.snapshot_every,
                           val_loader=val_loader,
                           kill_plateaus=args.kill_plateaus,
                           best_acc1=best_acc1,
                           writer=writer,
                           snapshot_all_until=args.snapshot_all_until,
                           filename=args.arch,
                           train_acc=args.train_acc)
    if test_loader is not None:
        val_loss, top1_acc, top5_acc = scores.evaluate(net,
                                                       test_loader,
                                                       criterion,
                                                       device,
                                                       topk=(1, 5))
        utils.print_val_loss(args.epochs, val_loss, top1_acc, top5_acc)
        net = net.train()
        writer.add_scalar('Loss/test', val_loss, args.epochs)
        writer.add_scalar('Accuracy/test/top1', top1_acc, args.epochs)
        writer.add_scalar('Accuracy/test/top5', top5_acc, args.epochs)

    # save final model
    if converged:
        pretrained = 'pretrained_' if args.pretrained else ''
        filename = args.arch + '_' + pretrained + str(args.epochs) + '.pt'
        snapshot.save_model(net, filename, snapshot_dir)

    writer.close()
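
The snapshot module appears throughout these listings but is never defined. A plausible shape for its save/load pair, matching the call sites in Examples 5 and 6 (the checkpoint dictionary keys are assumptions; Examples 1 and 2 use older signatures with fewer fields):

import os
import torch

def save_snapshot(model, optimizer, scheduler, epoch, top1_acc, top5_acc,
                  filename, snapshot_dirname):
    # Bundle everything needed to resume training into one checkpoint file.
    state = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
        'best_acc1': top1_acc,
        'best_acc5': top5_acc,
    }
    torch.save(state, os.path.join(snapshot_dirname,
                                   '{}_{}.pt'.format(filename, epoch)))

def load_snapshot(model, optimizer, scheduler, path, device):
    # Restore model/optimizer/scheduler state and return bookkeeping values.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    if scheduler is not None and state['scheduler'] is not None:
        scheduler.load_state_dict(state['scheduler'])
    return (model, optimizer, scheduler, state['epoch'],
            state['best_acc1'], state['best_acc5'])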