Exemplo n.º 1
0
def main():
    """Run a differentiable architecture search on ``args.dataset``.

    Uses the module-level ``args``. Builds the search ``Network`` and an
    ``Architect``, then for ``args.epochs`` epochs logs the current genotype
    and softmaxed architecture weights, runs one search epoch via ``train``,
    evaluates on the validation queue via ``infer`` and saves the weights to
    ``<args.save>/weights.pt``. Exits with status 1 when no CUDA device is
    available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss().cuda()
    number_of_classes = class_dict[args.dataset]
    in_channels = inp_channel_dict[args.dataset]
    model = Network(args.init_channels, number_of_classes, args.layers,
                    criterion, in_channels)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Only the training transform is needed: get_training_queues builds both
    # the search-train and search-validation queues from it.
    train_transform, _ = utils.get_data_transforms(args)
    train_queue, valid_queue = get_training_queues(args, train_transform)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    for epoch in range(args.epochs):
        # Read the learning rate the optimizer will actually use this epoch.
        # scheduler.get_lr() is deprecated and not valid outside of step().
        lr = optimizer.param_groups[0]['lr']
        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        logging.info('alphas_normal = %s',
                     F.softmax(model.alphas_normal, dim=-1))
        logging.info('alphas_reduce = %s',
                     F.softmax(model.alphas_reduce, dim=-1))

        # training (weights and architecture parameters)
        train_acc, train_obj = train(train_queue, valid_queue, model,
                                     architect, criterion, optimizer, lr)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))

        # Advance the schedule after this epoch's optimizer updates. Stepping
        # before them (PyTorch >= 1.1) skips the first value of the schedule.
        scheduler.step()
Exemplo n.º 2
0
def main():
  """Train the fixed architecture ``genotypes.<args.arch>`` from scratch.

  Uses the module-level ``args``. Builds the network for ``args.dataset``,
  trains it for ``args.epochs`` epochs with SGD plus cosine annealing and a
  linearly increasing drop-path probability, logs train/valid accuracy, and
  saves the weights to ``<args.save>/weights.pt`` after every epoch. Exits
  with status 1 when no CUDA device is available.
  """
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  torch.cuda.set_device(args.gpu)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled = True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  logging.info("args = %s", args)

  # Attribute lookup instead of eval() on a command line string.
  genotype = getattr(genotypes, args.arch)
  number_of_classes = class_dict[args.dataset]
  in_channels = inp_channel_dict[args.dataset]
  logging.info('classes = %s, input channels = %s', number_of_classes, in_channels)
  model = Network(args.init_channels, number_of_classes, args.layers, args.auxiliary, genotype, in_channels)
  model = model.cuda()

  logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

  criterion = nn.CrossEntropyLoss().cuda()
  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay
      )

  train_transform, valid_transform = utils.get_data_transforms(args)
  train_queue, valid_queue = get_train_test_queues(args, train_transform, valid_transform)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

  for epoch in range(args.epochs):
    # Log the learning rate the optimizer actually uses this epoch;
    # scheduler.get_lr() is deprecated and not valid outside of step().
    logging.info('epoch %d lr %e', epoch, optimizer.param_groups[0]['lr'])
    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

    train_acc, train_obj = train(train_queue, model, criterion, optimizer)
    logging.info('train_acc %f', train_acc)

    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    logging.info('valid_acc %f', valid_acc)

    utils.save(model, os.path.join(args.save, 'weights.pt'))

    # Advance the schedule after this epoch's optimizer updates. Stepping
    # before them (PyTorch >= 1.1) skips the first value of the schedule.
    scheduler.step()
Exemplo n.º 3
0
def predict_all(files, idx_to_class, model_path):
    """Classify image files with a fine-tuned ResNet-18 checkpoint.

    Args:
        files: iterable of image file paths.
        idx_to_class: indexable mapping from predicted class index to class
            name.
        model_path: path of a checkpoint dict holding a 'model_state_dict'.

    Returns:
        A list with one dict per file, in input order, with keys 'fname'
        (the file's base name) and 'prediction' (the predicted class name).
    """
    device = get_device()
    model = models.resnet18()
    # num_classes is a module-level name defined outside this function.
    model.fc = nn.Linear(512, num_classes)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    # Inference mode: BatchNorm uses its running statistics (instead of
    # normalizing each one-image batch by itself and updating the running
    # averages) and dropout is disabled.
    model.eval()

    data_transforms = get_data_transforms()['val']

    predictions = []
    # No gradients are needed for prediction; skip building autograd graphs.
    with torch.no_grad():
        for f in files:
            # Close the file handle once the transform has read the pixels.
            with Image.open(f) as img:
                image = data_transforms(img)
            image = image.unsqueeze(0).to(device)  # add batch dimension
            output = model(image)
            target_idx = torch.argmax(output).item()
            predictions.append({
                'fname': os.path.basename(f),
                'prediction': idx_to_class[target_idx],
            })
    return predictions
Exemplo n.º 4
0
def main():
  """Train, evaluate, or count the FLOPs of a fixed-architecture network.

  Parses the command line, builds the network selected by --multi_channel,
  --dataset and --arch, then:
    * with --flops, logs the parameter size and FLOPs and returns;
    * with --evaluate, evaluates the weights from --load, writes the args
      merged with the eval stats to args.stats_file and returns;
    * otherwise trains with a cosine power annealing learning rate schedule,
      saving the weights of the epoch with the best validation accuracy to
      <args.save>/weights.pt, writing per-epoch stats (JSON and CSV), and
      finally evaluating those best weights.
  Exits with status 1 when no CUDA device is available.
  """
  parser = argparse.ArgumentParser("Common Argument Parser")
  parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
  parser.add_argument('--dataset', type=str, default='cifar10', help='which dataset:\
                      cifar10, mnist, emnist, fashion, svhn, stl10, devanagari')
  parser.add_argument('--batch_size', type=int, default=64, help='batch size')
  parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
  parser.add_argument('--learning_rate_min', type=float, default=1e-8, help='min learning rate')
  parser.add_argument('--lr_power_annealing_exponent_order', type=float, default=2,
                      help='Cosine Power Annealing Schedule Base, larger numbers make '
                           'the exponential more dominant, smaller make cosine more dominant, '
                           '1 returns to standard cosine annealing.')
  parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
  parser.add_argument('--weight_decay', '--wd', dest='weight_decay', type=float, default=3e-4, help='weight decay')
  parser.add_argument('--partial', default=1/8, type=float, help='partially adaptive parameter p in Padam')
  parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
  parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
  parser.add_argument('--epochs', type=int, default=2000, help='num of training epochs')
  parser.add_argument('--start_epoch', default=1, type=int, metavar='N',
                      help='manual epoch number (useful for restarts)')
  parser.add_argument('--warmup_epochs', type=int, default=5, help='num of warmup training epochs')
  parser.add_argument('--warm_restarts', type=int, default=20, help='warm restarts of cosine annealing')
  parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
  parser.add_argument('--mid_channels', type=int, default=32, help='C_mid channels in choke SharpSepConv')
  parser.add_argument('--layers', type=int, default=20, help='total number of layers')
  parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
  parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
  parser.add_argument('--mixed_auxiliary', action='store_true', default=False, help='Learn weights for auxiliary networks during training. Overrides auxiliary flag')
  parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
  parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
  parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
  parser.add_argument('--autoaugment', action='store_true', default=False, help='use cifar10 autoaugment https://arxiv.org/abs/1805.09501')
  parser.add_argument('--random_eraser', action='store_true', default=False, help='use random eraser')
  parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
  parser.add_argument('--save', type=str, default='EXP', help='experiment name')
  parser.add_argument('--seed', type=int, default=0, help='random seed')
  parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
  parser.add_argument('--ops', type=str, default='OPS', help='which operations to use, options are OPS and DARTS_OPS')
  parser.add_argument('--primitives', type=str, default='PRIMITIVES',
                      help='which primitive layers to use inside a cell search space,'
                           ' options are PRIMITIVES, SHARPER_PRIMITIVES, and DARTS_PRIMITIVES')
  parser.add_argument('--optimizer', type=str, default='sgd', help='which optimizer to use, options are padam and sgd')
  parser.add_argument('--load', type=str, default='',  metavar='PATH', help='load weights at specified location')
  parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
  parser.add_argument('--flops', action='store_true', default=False, help='count flops and exit, aka floating point operations.')
  parser.add_argument('-e', '--evaluate', dest='evaluate', type=str, metavar='PATH', default='',
                      help='evaluate model at specified path on training, test, and validation datasets')
  parser.add_argument('--multi_channel', action='store_true', default=False, help='perform multi channel search, a completely separate search space')
  parser.add_argument('--load_args', type=str, default='',  metavar='PATH',
                      help='load command line args from a json file, this will override '
                           'all currently set args except for --evaluate, and arguments '
                           'that did not exist when the json file was originally saved out.')
  parser.add_argument('--layers_of_cells', type=int, default=8, help='total number of cells in the whole network, default is 8 cells')
  parser.add_argument('--layers_in_cells', type=int, default=4,
                      help='Total number of nodes in each cell, aka number of steps,'
                           ' default is 4 nodes, which implies 8 ops')
  parser.add_argument('--weighting_algorithm', type=str, default='scalar',
                      help='how to weight the output of each operation, options are '
                           '"max_w" (1. - max_w + w) * op, and scalar (w * op)')
  # TODO(ahundt) remove final path and switch back to genotype
  parser.add_argument('--load_genotype', type=str, default=None, help='Name of genotype to be used')
  parser.add_argument('--simple_path', default=True, action='store_false', help='Final model is a simple path (MultiChannelNetworkModel)')
  args = parser.parse_args()

  args = utils.initialize_files_and_args(args)

  logger = utils.logging_setup(args.log_file_path)

  if not torch.cuda.is_available():
    logger.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  torch.cuda.set_device(args.gpu)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled = True
  torch.cuda.manual_seed(args.seed)
  logger.info('gpu device = %d' % args.gpu)
  logger.info("args = %s", args)

  DATASET_CLASSES = dataset.class_dict[args.dataset]
  logger.info('output channels: ' + str(DATASET_CLASSES))

  # load the correct ops dictionary (attribute lookup instead of eval())
  op_dict_to_load = "operations.%s" % args.ops
  logger.info('loading op dict: ' + str(op_dict_to_load))
  op_dict = getattr(operations, args.ops)

  # load the correct primitives list (it is only logged here)
  primitives_to_load = "genotypes.%s" % args.primitives
  logger.info('loading primitives:' + primitives_to_load)
  primitives = getattr(genotypes, args.primitives)
  logger.info('primitives: ' + str(primitives))

  genotype = getattr(genotypes, args.arch)

  # create the neural network
  criterion = nn.CrossEntropyLoss().cuda()
  if args.multi_channel:
    if args.load_genotype is not None:
      genotype = getattr(genotypes, args.load_genotype)
      logger.info('loaded genotype: %s', genotype)
      if isinstance(genotype[0], str):
        logger.info('Path :%s', genotype)
    # TODO(ahundt) remove final path and switch back to genotype
    cnn_model = MultiChannelNetwork(
      args.init_channels, DATASET_CLASSES, layers=args.layers_of_cells, criterion=criterion, steps=args.layers_in_cells,
      weighting_algorithm=args.weighting_algorithm, genotype=genotype)
    flops_shape = [1, 3, 32, 32]
  elif args.dataset == 'imagenet':
    cnn_model = NetworkImageNet(args.init_channels, DATASET_CLASSES, args.layers, args.auxiliary, genotype, op_dict=op_dict, C_mid=args.mid_channels)
    flops_shape = [1, 3, 224, 224]
  else:
    cnn_model = NetworkCIFAR(args.init_channels, DATASET_CLASSES, args.layers, args.auxiliary, genotype, op_dict=op_dict, C_mid=args.mid_channels)
    flops_shape = [1, 3, 32, 32]
  cnn_model = cnn_model.cuda()

  logger.info("param size = %fMB", utils.count_parameters_in_MB(cnn_model))
  if args.flops:
    logger.info('flops_shape = ' + str(flops_shape))
    logger.info("flops = " + utils.count_model_flops(cnn_model, data_shape=flops_shape))
    return

  optimizer = torch.optim.SGD(
      cnn_model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay
      )

  # Get preprocessing functions (i.e. transforms) to apply on data
  train_transform, valid_transform = utils.get_data_transforms(args)
  if args.evaluate:
    # evaluate the train dataset without augmentation
    train_transform = valid_transform

  # Get the training queue, use full training and test set
  train_queue, valid_queue = dataset.get_training_queues(
    args.dataset, train_transform, valid_transform, args.data, args.batch_size, train_proportion=1.0, search_architecture=False)

  test_queue = None
  if args.dataset == 'cifar10':
    # evaluate best model weights on cifar 10.1
    # https://github.com/modestyachts/CIFAR-10.1
    test_data = cifar10_1.CIFAR10_1(root=args.data, download=True, transform=valid_transform)
    test_queue = torch.utils.data.DataLoader(
      test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8)

  if args.evaluate:
    # evaluate the loaded model, print the result, and return
    logger.info("Evaluating inference with weights file: " + args.load)
    eval_stats = evaluate(
      args, cnn_model, criterion, args.load,
      train_queue=train_queue, valid_queue=valid_queue, test_queue=test_queue)
    with open(args.stats_file, 'w') as f:
      # Copy so the stats are not written back into args as attributes.
      arg_dict = dict(vars(args))
      arg_dict.update(eval_stats)
      # Same encoder as the training path, so numpy values (if any) serialize.
      json.dump(arg_dict, f, cls=utils.NumpyEncoder)
    # Use the dataset-specific input shape, as the --flops path does.
    logger.info("flops = " + utils.count_model_flops(cnn_model, data_shape=flops_shape))
    logger.info(utils.dict_to_log_string(eval_stats))
    logger.info('\nEvaluation of Loaded Model Complete! Save dir: ' + str(args.save))
    return

  lr_schedule = cosine_power_annealing(
    epochs=args.epochs, max_lr=args.learning_rate, min_lr=args.learning_rate_min,
    warmup_epochs=args.warmup_epochs, exponent_order=args.lr_power_annealing_exponent_order)
  epochs = np.arange(args.epochs) + args.start_epoch

  stats_csv = args.epoch_stats_file.replace('.json', '.csv')
  with tqdm(epochs, dynamic_ncols=True) as prog_epoch:
    best_valid_acc = 0.0
    best_epoch = 0
    stats = {}
    epoch_stats = []
    weights_file = os.path.join(args.save, 'weights.pt')
    for epoch, learning_rate in zip(prog_epoch, lr_schedule):
      # update the drop_path_prob augmentation
      cnn_model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
      # update the learning rate from the precomputed schedule
      for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate

      train_acc, train_obj = train(args, train_queue, cnn_model, criterion, optimizer)

      val_stats = infer(args, valid_queue, cnn_model, criterion)
      stats.update(val_stats)
      stats['train_acc'] = train_acc
      stats['train_loss'] = train_obj
      stats['lr'] = learning_rate
      stats['epoch'] = epoch

      if stats['valid_acc'] > best_valid_acc:
        # new best epoch, save weights
        utils.save(cnn_model, weights_file)
        best_epoch = epoch
        best_valid_acc = stats['valid_acc']
      logger.info('epoch, %d, train_acc, %f, valid_acc, %f, train_loss, %f, valid_loss, %f, lr, %e, best_epoch, %d, best_valid_acc, %f, ' + utils.dict_to_log_string(stats),
                  epoch, train_acc, stats['valid_acc'], train_obj, stats['valid_loss'], learning_rate, best_epoch, best_valid_acc)
      epoch_stats.append(copy.deepcopy(stats))
      with open(args.epoch_stats_file, 'w') as f:
        json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
      utils.list_of_dicts_to_csv(stats_csv, epoch_stats)

    # get stats from best epoch including cifar10.1
    eval_stats = evaluate(args, cnn_model, criterion, weights_file, train_queue, valid_queue, test_queue)
    with open(args.stats_file, 'w') as f:
      # Copy so the stats are not written back into args as attributes.
      arg_dict = dict(vars(args))
      arg_dict.update(eval_stats)
      json.dump(arg_dict, f, cls=utils.NumpyEncoder)
    with open(args.epoch_stats_file, 'w') as f:
      json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
    logger.info(utils.dict_to_log_string(eval_stats))
    logger.info('Training of Final Model Complete! Save dir: ' + str(args.save))
Exemplo n.º 5
0
def main():
    """Train or evaluate a fixed-architecture network, optionally distributed.

    Reads the module-level ``args`` and rebinds the module-level ``args``,
    ``logger`` and ``best_top1``. Distributed mode is enabled when the
    WORLD_SIZE environment variable is greater than 1 (NCCL backend, GPU
    chosen from ``args.local_rank``).

    Depending on the flags this:
      * logs the parameter size and FLOPs, then returns (--flops);
      * loads a checkpoint (--resume, or --evaluate which overrides it);
      * evaluates the loaded model and returns (--evaluate); for cifar10 the
        stats include the CIFAR-10.1 test set and are written, merged with
        the args, to ``args.stats_file``;
      * otherwise trains with a cosine power annealing learning rate
        schedule; local rank 0 saves a checkpoint every epoch, and per-epoch
        stats (JSON + CSV) and the best epoch's stats merged with the args
        are written out.
    """
    global best_top1, args, logger

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    # args.gpu comes from argparse; in distributed mode it is derived from
    # the local rank below.
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    # note the gpu is used for directory creation and log files
    # which is needed when run as multiple processes
    args = utils.initialize_files_and_args(args)
    logger = utils.logging_setup(args.log_file_path)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0 and not args.fp16:
        logger.info(
            "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
        )

    # load the correct ops dictionary (attribute lookup instead of eval())
    op_dict_to_load = "operations.%s" % args.ops
    logger.info('loading op dict: ' + str(op_dict_to_load))
    op_dict = getattr(operations, args.ops)

    # load the correct primitives list (it is only logged here)
    primitives_to_load = "genotypes.%s" % args.primitives
    logger.info('loading primitives:' + primitives_to_load)
    primitives = getattr(genotypes, args.primitives)
    logger.info('primitives: ' + str(primitives))
    # create model
    genotype = getattr(genotypes, args.arch)
    # get the number of output channels
    classes = dataset.class_dict[args.dataset]
    # create the neural network
    if args.dataset == 'imagenet':
        model = NetworkImageNet(args.init_channels,
                                classes,
                                args.layers,
                                args.auxiliary,
                                genotype,
                                op_dict=op_dict,
                                C_mid=args.mid_channels)
        flops_shape = [1, 3, 224, 224]
    else:
        model = NetworkCIFAR(args.init_channels,
                             classes,
                             args.layers,
                             args.auxiliary,
                             genotype,
                             op_dict=op_dict,
                             C_mid=args.mid_channels)
        flops_shape = [1, 3, 32, 32]
    model.drop_path_prob = 0.0

    if args.flops:
        model = model.cuda()
        logger.info("param size = %fMB", utils.count_parameters_in_MB(model))
        logger.info("flops_shape = " + str(flops_shape))
        logger.info("flops = " +
                    utils.count_model_flops(model, data_shape=flops_shape))
        return

    if args.sync_bn:
        import apex
        logger.info("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Scale learning rate based on global batch size
    args.learning_rate = args.learning_rate * float(
        args.batch_size * args.world_size) / 256.
    init_lr = args.learning_rate / args.warmup_lr_divisor
    optimizer = torch.optim.SGD(model.parameters(),
                                init_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale)

    # Optionally resume from a checkpoint
    if args.resume or args.evaluate:
        if args.evaluate:
            args.resume = args.evaluate

        # Use a local scope to avoid dangling references
        def resume():
            # best_top1 is module-level state. Without this declaration the
            # assignment below would only bind a name local to resume() and
            # the checkpoint's best_top1 would be silently discarded.
            global best_top1
            if os.path.isfile(args.resume):
                logger.info("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                if 'best_top1' in checkpoint:
                    best_top1 = checkpoint['best_top1']
                model.load_state_dict(checkpoint['state_dict'])
                # An FP16_Optimizer instance's state dict internally stashes the master params.
                optimizer.load_state_dict(checkpoint['optimizer'])
                if 'lr_scheduler' in checkpoint or 'lr_schedule' in checkpoint:
                    # There is no scheduler object here: the per-epoch learning
                    # rate comes from cosine_power_annealing, recomputed from
                    # args below, so a stored schedule is not restored.
                    logger.info("=> checkpoint learning rate schedule ignored, "
                                "it is recomputed from the command line args")
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

        resume()

    # Get preprocessing functions (i.e. transforms) to apply on data
    # normalize_as_tensor = False because we normalize and convert to a
    # tensor in our custom prefetching function, rather than as part of
    # the transform preprocessing list.
    train_transform, valid_transform = utils.get_data_transforms(
        args, normalize_as_tensor=False)
    # Get the training queue, select training and validation from training set
    train_loader, val_loader = dataset.get_training_queues(
        args.dataset,
        train_transform,
        valid_transform,
        args.data,
        args.batch_size,
        train_proportion=1.0,
        collate_fn=fast_collate,
        distributed=args.distributed,
        num_workers=args.workers)

    if args.evaluate:
        if args.dataset == 'cifar10':
            # evaluate best model weights on cifar 10.1
            # https://github.com/modestyachts/CIFAR-10.1
            train_transform, valid_transform = utils.get_data_transforms(args)
            # Get the training queue, use full training and test set
            train_queue, valid_queue = dataset.get_training_queues(
                args.dataset,
                train_transform,
                valid_transform,
                args.data,
                args.batch_size,
                train_proportion=1.0,
                search_architecture=False)
            test_data = cifar10_1.CIFAR10_1(root=args.data,
                                            download=True,
                                            transform=valid_transform)
            test_queue = torch.utils.data.DataLoader(
                test_data,
                batch_size=args.batch_size,
                shuffle=False,
                pin_memory=True,
                num_workers=args.workers)
            eval_stats = evaluate(args,
                                  model,
                                  criterion,
                                  train_queue=train_queue,
                                  valid_queue=valid_queue,
                                  test_queue=test_queue)
            with open(args.stats_file, 'w') as f:
                # Copy so the stats are not written back into args. The
                # default encoder rejected numpy values (e.g. a numpy epoch
                # restored into args.start_epoch); use the same encoder as
                # the training path.
                arg_dict = dict(vars(args))
                arg_dict.update(eval_stats)
                json.dump(arg_dict, f, cls=utils.NumpyEncoder)
            logger.info("flops = " +
                        utils.count_model_flops(model, data_shape=flops_shape))
            logger.info(utils.dict_to_log_string(eval_stats))
            logger.info('\nEvaluation of Loaded Model Complete! Save dir: ' +
                        str(args.save))
        else:
            validate(val_loader, model, criterion, args)
        return

    lr_schedule = cosine_power_annealing(
        epochs=args.epochs,
        max_lr=args.learning_rate,
        min_lr=args.learning_rate_min,
        warmup_epochs=args.warmup_epochs,
        exponent_order=args.lr_power_annealing_exponent_order,
        restart_lr=args.restart_lr)
    epochs = np.arange(args.epochs) + args.start_epoch

    stats_csv = args.epoch_stats_file.replace('.json', '.csv')
    with tqdm(epochs,
              dynamic_ncols=True,
              disable=args.local_rank != 0,
              leave=False) as prog_epoch:
        best_stats = {}
        stats = {}
        epoch_stats = []
        best_epoch = 0
        for epoch, learning_rate in zip(prog_epoch, lr_schedule):
            if args.distributed and train_loader.sampler is not None:
                train_loader.sampler.set_epoch(int(epoch))
            # update the learning rate from the precomputed schedule
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            model.drop_path_prob = args.drop_path_prob * float(epoch) / float(
                args.epochs)
            # train for one epoch
            train_stats = train(train_loader, model, criterion, optimizer,
                                int(epoch), args)
            if args.prof:
                break
            # evaluate on validation set
            top1, val_stats = validate(val_loader, model, criterion, args)
            stats.update(train_stats)
            stats.update(val_stats)
            stats['lr'] = '{0:.5f}'.format(learning_rate)
            stats['epoch'] = epoch

            # remember best top1 and save checkpoint
            if args.local_rank == 0:
                is_best = top1 > best_top1
                best_top1 = max(top1, best_top1)
                stats['best_top1'] = '{0:.3f}'.format(best_top1)
                if is_best:
                    best_epoch = epoch
                    best_stats = copy.deepcopy(stats)
                stats['best_epoch'] = best_epoch

                stats_str = utils.dict_to_log_string(stats)
                logger.info(stats_str)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_top1': best_top1,
                        'optimizer': optimizer.state_dict(),
                        'lr_schedule': lr_schedule,
                        'stats': best_stats
                    },
                    is_best,
                    path=args.save)
                prog_epoch.set_description(
                    'Overview ***** best_epoch: {0} best_valid_top1: {1:.2f} ***** Progress'
                    .format(best_epoch, best_top1))
            epoch_stats.append(copy.deepcopy(stats))
            with open(args.epoch_stats_file, 'w') as f:
                json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
            utils.list_of_dicts_to_csv(stats_csv, epoch_stats)
        stats_str = utils.dict_to_log_string(best_stats, key_prepend='best_')
        logger.info(stats_str)
        with open(args.stats_file, 'w') as f:
            # Copy so the stats are not written back into args as attributes.
            arg_dict = dict(vars(args))
            arg_dict.update(best_stats)
            json.dump(arg_dict, f, cls=utils.NumpyEncoder)
        with open(args.epoch_stats_file, 'w') as f:
            json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
        utils.list_of_dicts_to_csv(stats_csv, epoch_stats)
        logger.info('Training of Final Model Complete! Save dir: ' +
                    str(args.save))
Exemplo n.º 6
0
def main():
    """Run the architecture search and log/save per-epoch results.

    All configuration comes from the module-level ``args`` namespace and all
    logging goes through the module-level ``logger``.

    Steps:
      * builds a ``model_search.MultiChannelNetwork`` when
        ``args.multi_channel`` is set, otherwise a ``model_search.Network``;
      * trains it with SGD, setting the learning rate by hand every epoch from
        a precomputed ``cosine_power_annealing`` schedule;
      * passes an ``Architect`` to ``train`` unless ``args.no_architect`` is
        set (then ``None`` is passed);
      * saves model weights to ``<args.save>/weights.pt`` whenever validation
        accuracy improves;
      * after every epoch, rewrites ``args.epoch_stats_file`` (JSON) and a CSV
        copy with the same name but a ``.csv`` extension.

    Exits the process with status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logger.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logger.info('gpu device = %d' % args.gpu)
    logger.info("args = %s", args)

    # Resolve the ops dictionary and primitives list named on the command line.
    # NOTE(review): eval() executes the argument text as Python; only run this
    # with trusted arguments (getattr on the module would be a safer lookup).
    op_dict_to_load = "operations.%s" % args.ops
    logger.info('loading op dict: ' + str(op_dict_to_load))
    op_dict = eval(op_dict_to_load)

    primitives_to_load = "genotypes.%s" % args.primitives
    logger.info('loading primitives:' + primitives_to_load)
    primitives = eval(primitives_to_load)
    logger.info('primitives: ' + str(primitives))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    if args.multi_channel:
        final_path = None
        if args.final_path is not None:
            # NOTE(review): this loaded array is never used below; only
            # args.final_path is consulted later. Confirm whether it should be
            # handed to the model.
            final_path = np.load(args.final_path)

        genotype = None
        if args.load_genotype is not None:
            genotype = getattr(genotypes, args.load_genotype)
        cnn_model = model_search.MultiChannelNetwork(
            args.init_channels,
            CIFAR_CLASSES,
            layers=args.layers_of_cells,
            criterion=criterion,
            steps=args.layers_in_cells,
            primitives=primitives,
            op_dict=op_dict,
            weighting_algorithm=args.weighting_algorithm,
            genotype=genotype)
        if args.load_genotype is not None:
            # Run one forward pass on a zero batch (still on the CPU, before
            # .cuda()), then log the loaded genotype under each extraction
            # strategy.
            # TODO(ahundt) support other batch shapes
            data_shape = [1, 3, 32, 32]
            batch = torch.zeros(data_shape)
            cnn_model(batch)
            logger.info("loaded genotype_raw_weights = " +
                        str(cnn_model.genotype('raw_weights')))
            logger.info("loaded genotype_longest_path = " +
                        str(cnn_model.genotype('longest_path')))
            logger.info("loaded genotype greedy_path = " +
                        str(gen_greedy_path(cnn_model.G, strategy="top_down")))
            logger.info(
                "loaded genotype greedy_path_bottom_up = " +
                str(gen_greedy_path(cnn_model.G, strategy="bottom_up")))
            # TODO(ahundt) support other layouts
    else:
        cnn_model = model_search.Network(
            args.init_channels,
            CIFAR_CLASSES,
            layers=args.layers_of_cells,
            criterion=criterion,
            steps=args.layers_in_cells,
            primitives=primitives,
            op_dict=op_dict,
            weights_are_parameters=args.no_architect,
            C_mid=args.mid_channels,
            weighting_algorithm=args.weighting_algorithm)
    cnn_model = cnn_model.cuda()
    logger.info("param size = %fMB", utils.count_parameters_in_MB(cnn_model))

    if args.load:
        logger.info('loading weights from: ' + args.load)
        utils.load(cnn_model, args.load)

    optimizer = torch.optim.SGD(cnn_model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Get preprocessing functions (i.e. transforms) to apply on data
    train_transform, valid_transform = utils.get_data_transforms(args)

    # Get the training queue, select training and validation from training set
    train_queue, valid_queue = dataset.get_training_queues(
        args.dataset,
        train_transform,
        valid_transform,
        args.data,
        args.batch_size,
        args.train_portion,
        search_architecture=True)

    # One learning rate per epoch; written into the optimizer by hand below
    # instead of using a torch lr_scheduler.
    lr_schedule = cosine_power_annealing(
        epochs=args.epochs,
        max_lr=args.learning_rate,
        min_lr=args.learning_rate_min,
        warmup_epochs=args.warmup_epochs,
        exponent_order=args.lr_power_annealing_exponent_order)
    epochs = np.arange(args.epochs) + args.start_epoch

    if args.no_architect:
        architect = None
    else:
        architect = Architect(cnn_model, args)

    epoch_stats = []
    stats_csv = args.epoch_stats_file.replace('.json', '.csv')
    weights_file = os.path.join(args.save, 'weights.pt')
    with tqdm(epochs, dynamic_ncols=True) as prog_epoch:
        best_valid_acc = 0.0
        best_epoch = 0
        for epoch, learning_rate in zip(prog_epoch, lr_schedule):
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            genotype = None
            if args.final_path is None:
                genotype = cnn_model.genotype()
                logger.info('genotype = %s', genotype)

            if not args.multi_channel:
                # the genotype is the alphas in the multi-channel case
                # print the alphas in other cases
                logger.info('alphas_normal = %s', cnn_model.arch_weights(0))
                logger.info('alphas_reduce = %s', cnn_model.arch_weights(1))

            # training
            train_acc, train_obj = train(train_queue, valid_queue, cnn_model,
                                         architect, criterion, optimizer,
                                         learning_rate)

            if args.multi_channel and args.final_path is None:
                # TODO(ahundt) remove final path and switch back to genotype, and save out raw weights plus optimal path
                optimal_path = nx.algorithms.dag.dag_longest_path(cnn_model.G)
                optimal_path_filename = os.path.join(
                    args.save, 'longest_path_layer_sequence.npy')
                logger.info('Saving model layer sequence object: ' +
                            str(optimal_path_filename))
                np.save(optimal_path_filename, optimal_path)
                graph_filename = os.path.join(
                    args.save, 'network_graph_' + str(epoch) + '.graph')
                logger.info('Saving updated weight graph: ' +
                            str(graph_filename))
                nx.write_gpickle(cnn_model.G, graph_filename)
                logger.info('optimal_path  : %s', optimal_path)

            # validation
            valid_acc, valid_obj = infer(valid_queue, cnn_model, criterion)

            if valid_acc > best_valid_acc:
                # new best epoch, save weights
                utils.save(cnn_model, weights_file)

                if args.multi_channel:
                    graph_filename = os.path.join(
                        args.save,
                        'network_graph_best_valid' + str(epoch) + '.graph')
                    logger.info('Saving updated weight graph: ' +
                                str(graph_filename))
                    # Previously only the message above was logged and the
                    # graph was never actually written.
                    nx.write_gpickle(cnn_model.G, graph_filename)

                best_epoch = epoch
                best_valid_acc = valid_acc
                prog_epoch.set_description(
                    'Overview ***** best_epoch: {0} best_valid_acc: {1:.2f} ***** Progress'
                    .format(best_epoch, best_valid_acc))

            logger.info(
                'epoch, %d, train_acc, %f, valid_acc, %f, train_loss, %f, valid_loss, %f, lr, %e, best_epoch, %d, best_valid_acc, %f',
                epoch, train_acc, valid_acc, train_obj, valid_obj,
                learning_rate, best_epoch, best_valid_acc)

            if args.multi_channel:
                # NOTE(review): kept as before; confirm what
                # MultiChannelNetwork.arch_weights is before relying on it.
                arch_weights = str(cnn_model.arch_weights)
            else:
                # arch_weights is called with a cell index above (0 / 1), so
                # record those values rather than the bound method's repr.
                arch_weights = str([cnn_model.arch_weights(0),
                                    cnn_model.arch_weights(1)])
            stats = {
                'epoch': epoch,
                'train_acc': train_acc,
                'valid_acc': valid_acc,
                'train_loss': train_obj,
                'valid_loss': valid_obj,
                'lr': learning_rate,
                'best_epoch': best_epoch,
                'best_valid_acc': best_valid_acc,
                'genotype': str(genotype),
                'arch_weights': arch_weights
            }
            epoch_stats += [copy.deepcopy(stats)]
            with open(args.epoch_stats_file, 'w') as f:
                json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
            utils.list_of_dicts_to_csv(stats_csv, epoch_stats)

    # print the final model
    if args.final_path is None:
        genotype = cnn_model.genotype()
        logger.info('genotype = %s', genotype)
    logger.info('Search for Model Complete! Save dir: ' + str(args.save))
Exemplo n.º 7
0
def fetch_entire_dataloader(
        source,
        data_dir,
        val_scale,
        seed,
        batch_size,
        num_workers,
        gaze_task=None,  #either none, data augment, cam reg, cam reg convex
        ood_set=None,
        ood_shift=None,
        subclass=False,
        gan=True,
        label_class=None):
    """Return one shuffled DataLoader over the train, val and test splits combined.

    A ``RoboGazeDataset`` is built for each of the "train", "val" and "test"
    splits, each with that split's transform from ``get_data_transforms``. The
    three datasets are then concatenated.

    Args:
        source, data_dir, val_scale, seed, gaze_task, subclass: forwarded to
            every ``RoboGazeDataset``.
        batch_size: batch size of the returned loader.
        num_workers: worker count for the returned loader and for the label
            scan done when ``label_class`` is given.
        ood_set, ood_shift: when ``ood_set`` is not None, the test split reads
            from ``f"{ood_set}/{source}/{ood_shift}"`` instead of ``source``.
        gan: forwarded to ``get_data_transforms`` and ``RoboGazeDataset``.
        label_class: when not None, keep only samples whose label equals
            this value.

    Returns:
        A ``DataLoader`` with ``shuffle=True``.
    """
    split_transforms = get_data_transforms("cxr",
                                           normalization_type="train_images",
                                           gan=gan)

    split_datasets = []
    for split in ["train", "val", "test"]:
        # "test" is the last split, so rewriting `source` here only affects it.
        if ood_set is not None and split == "test":
            source = f"{ood_set}/{source}/{ood_shift}"

        split_datasets.append(
            RoboGazeDataset(source=source,
                            data_dir=data_dir,
                            split_type=split,
                            gaze_task=gaze_task,
                            transform=split_transforms[split],
                            val_scale=val_scale,
                            seed=seed,
                            subclass=subclass,
                            gan=gan))

    full_dataset = torch.utils.data.ConcatDataset(split_datasets)

    if label_class is not None:
        # Scan in order, one sample per batch, so the batch index equals the
        # sample's index in full_dataset.
        scan_loader = DataLoader(
            dataset=full_dataset,
            shuffle=False,
            batch_size=1,
            num_workers=num_workers,
        )
        class_indices = []
        for idx, batch in enumerate(scan_loader):
            # The label is the second field of each sample. Index it instead of
            # unpacking exactly two values, so samples carrying extra fields
            # such as (img, label, gaze_attr) work too.
            if batch[1].item() == label_class:
                class_indices.append(idx)
        full_dataset = torch.utils.data.Subset(full_dataset, class_indices)

    return DataLoader(
        dataset=full_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
    )
Exemplo n.º 8
0
def _generate_fake_images(generator, noise):
    """Run ``generator`` on each latent vector in ``noise``, one at a time.

    Returns the outputs concatenated along dim 0 as a CPU tensor.
    """
    # The outputs are detached anyway; no_grad avoids building autograd graphs.
    with torch.no_grad():
        images = [
            generator(z.unsqueeze(dim=0)).detach().cpu() for z in noise
        ]
    return torch.cat(images)


def _fake_class_dataset(generator, noise, label):
    """Build a ``GanDataset`` of generated images that all carry ``label``.

    Returns None when ``noise`` holds no latent vectors (torch.cat of an empty
    list would raise), so an empty class contributes no fake samples.
    """
    count = noise.shape[0]
    if count == 0:
        return None
    images = _generate_fake_images(generator, noise)
    return GanDataset(images=images,
                      labels=torch.full((count,), float(label)).numpy(),
                      # One zero gaze attribute per generated image of THIS
                      # class (positive attrs used to be sized by the
                      # negative class count).
                      gaze_attr=torch.zeros(count).numpy())


def _gan_fake_datasets(gan_type, gan_positive, gan_negative, class_amounts):
    """Sample fake training datasets whose class counts match the real data.

    ``class_amounts[0]`` images with label 0 are drawn from the generator
    stored at ``gan_negative + '/generator_best_ckpt.pt'``, and
    ``class_amounts[1]`` images with label 1 are drawn from the one at
    ``gan_positive + '/generator_best_ckpt.pt'``.

    Returns:
        ``[positive_dataset, negative_dataset]``, omitting a class that has
        no samples.

    Raises:
        ValueError: if ``gan_type`` is neither "gan" nor "acgan".
    """
    if gan_type == "gan":
        noise_size = 100

        def make_generator():
            return gan_generator.Generator_Advanced_224().cuda()
    elif gan_type == "acgan":
        noise_size = 110

        def make_generator():
            return acgan_generator.Generator_Advanced_224(1, noise_size).cuda()
    else:
        raise ValueError(
            "gan_type must be 'gan' or 'acgan' when gan_positive is given, "
            "got %r" % (gan_type,))

    # Construction / noise-draw order matches the original code so seeded
    # runs stay reproducible.
    # NOTE(review): the generators are left in train() mode, as before; if
    # they contain BatchNorm/Dropout, eval() may be what is intended.
    pos_generator = make_generator()
    neg_generator = make_generator()
    pos_generator.load_state_dict(
        torch.load(gan_positive + '/generator_best_ckpt.pt'))
    neg_generator.load_state_dict(
        torch.load(gan_negative + '/generator_best_ckpt.pt'))

    neg_noise = torch.randn(class_amounts[0], noise_size, 1, 1).cuda()
    pos_noise = torch.randn(class_amounts[1], noise_size, 1, 1).cuda()

    negative_fake_data = _fake_class_dataset(neg_generator, neg_noise, 0)
    positive_fake_data = _fake_class_dataset(pos_generator, pos_noise, 1)
    return [
        fake for fake in (positive_fake_data, negative_fake_data)
        if fake is not None
    ]


def fetch_dataloaders(
        source,
        data_dir,
        val_scale,
        seed,
        batch_size,
        num_workers,
        gaze_task=None,  #either none, data augment, cam reg, cam reg convex actdiff
        ood_set=None,
        ood_shift=None,
        subclass=False,
        gan_positive=None,
        gan_negative=None,
        gan_type=None,
        args=None):
    """Build the "train", "val" and "test" DataLoaders.

    For each split the dataset is chosen as follows:
      * "train" with ``gan_positive`` set: the real ``RoboGazeDataset`` train
        split concatenated with GAN-generated positive and negative images,
        one fake image per real image of the same class;
      * otherwise, ``gaze_task == "actdiff"`` with ``source == "synth"``: a
        ``SyntheticDataset``;
      * otherwise: a ``RoboGazeDataset``.

    When ``ood_set`` is not None, the test split reads from
    ``f"{args.machine}/{ood_set}/{source}/{ood_shift}"`` (``args`` must then
    be provided).

    Returns:
        dict mapping split name to a ``DataLoader``. Only "train" is shuffled.

    Raises:
        ValueError: if ``gan_positive`` is set and ``gan_type`` is neither
            "gan" nor "acgan".
    """
    split_transforms = get_data_transforms("cxr",
                                           normalization_type="train_images")
    dataloaders = {}

    def robogaze_dataset(split, split_source):
        # Shared construction of the real dataset for a given split.
        return RoboGazeDataset(source=split_source,
                               data_dir=data_dir,
                               split_type=split,
                               gaze_task=gaze_task,
                               transform=split_transforms[split],
                               val_scale=val_scale,
                               seed=seed,
                               subclass=subclass,
                               args=args)

    for split in ["train", "val", "test"]:
        # "test" is the last split, so rewriting `source` here only affects it.
        if ood_set is not None and split == "test":
            source = f"{args.machine}/{ood_set}/{source}/{ood_shift}"
            #source = f"{ood_set}/{source}/{ood_shift}"

        if split == "train" and gan_positive is not None:
            original_dataset = robogaze_dataset(split, source)

            # Count real samples per class (labels 0 / 1) so the generators
            # produce a matching number of fake images.
            counting_loader = DataLoader(
                dataset=original_dataset,
                shuffle=False,
                batch_size=1,
                num_workers=num_workers,
            )
            class_amounts = [0, 0]
            for _img, label, _gaze in counting_loader:
                class_amounts[label.item()] += 1

            fake_datasets = _gan_fake_datasets(gan_type, gan_positive,
                                               gan_negative, class_amounts)
            dataset = torch.utils.data.ConcatDataset([original_dataset] +
                                                     fake_datasets)
        elif gaze_task == "actdiff" and source == "synth":
            dataset = SyntheticDataset(split=split, blur=0.817838)
        else:
            dataset = robogaze_dataset(split, source)

        dataloaders[split] = DataLoader(
            dataset=dataset,
            shuffle=split == "train",
            batch_size=batch_size,
            num_workers=num_workers,
        )

    return dataloaders
Exemplo n.º 9
0
                        type=int,
                        help='Number of images to use for evaluation')
    parser.add_argument('--using-pretrained',
                        default=False,
                        action='store_true',
                        help='Train from scratch or just the final layer?')
    args = parser.parse_args()

    # CPU or GPU device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device {}".format(device))

    # Args for loading data
    data_dir = args.data_dir
    batch_size = args.batch_size
    data_transforms = get_data_transforms()

    # Load in and transform data
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'val', 'test']
    }

    # Maps class_idx to its string label
    classes_to_idx = {
        x: {
            class_idx: class_name
            for class_name, class_idx in
            image_datasets[x].class_to_idx.items()
        }
        for x in ['train', 'val', 'test']