Exemple #1
0
def build_dataset():
    """Build the train, meta-train and test data loaders for CIFAR-10/100.

    Configuration is read from the module-level ``args`` namespace
    (``augment``, ``dataset``, ``num_meta``, ``corruption_prob``,
    ``corruption_type``, ``seed``, ``batch_size`` and ``prefetch``).

    Returns:
        tuple: ``(train_loader, train_meta_loader, test_loader)``, each a
        ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'cifar10'`` nor
            ``'cifar100'``.
    """
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            # Reflect-pad 4 pixels on every side. The tensor gets a temporary
            # leading batch dimension for F.pad; squeeze(0) removes exactly
            # that dimension again and nothing else.
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4),
                                              mode='reflect').squeeze(0)),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Pick the dataset class once; both variants take the same arguments.
    if args.dataset == 'cifar10':
        dataset_cls = CIFAR10
    elif args.dataset == 'cifar100':
        dataset_cls = CIFAR100
    else:
        # Fail here instead of with an unbound-local error further down.
        raise ValueError(
            "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
                args.dataset))

    # Arguments shared by the meta and the regular training split.
    train_kwargs = dict(
        root='../data', train=True, num_meta=args.num_meta,
        corruption_prob=args.corruption_prob,
        corruption_type=args.corruption_type,
        transform=train_transform, download=True)

    train_data_meta = dataset_cls(meta=True, **train_kwargs)
    # Only the non-meta split receives the seed, as in the original code.
    train_data = dataset_cls(meta=False, seed=args.seed, **train_kwargs)
    test_data = dataset_cls(root='../data', train=False,
                            transform=test_transform, download=True)

    loader_kwargs = dict(batch_size=args.batch_size,
                         num_workers=args.prefetch, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(
        train_data, shuffle=True, **loader_kwargs)
    train_meta_loader = torch.utils.data.DataLoader(
        train_data_meta, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_data, shuffle=False, **loader_kwargs)

    return train_loader, train_meta_loader, test_loader
Exemple #2
0
                                            args.gold_fraction,
                                            args.corruption_prob,
                                            args.corruption_type,
                                            transform=test_transform,
                                            download=True)
    test_data = CIFAR10(args.data_path,
                        train=False,
                        transform=test_transform,
                        download=True)
    num_classes = 10

elif args.dataset == 'cifar100':
    train_data_gold = CIFAR100(args.data_path,
                               True,
                               True,
                               args.gold_fraction,
                               args.corruption_prob,
                               args.corruption_type,
                               transform=train_transform,
                               download=True)
    train_data_silver = CIFAR100(args.data_path,
                                 True,
                                 False,
                                 args.gold_fraction,
                                 args.corruption_prob,
                                 args.corruption_type,
                                 transform=train_transform,
                                 download=True)
    train_data_gold_deterministic = CIFAR100(args.data_path,
                                             True,
                                             True,
                                             args.gold_fraction,
Exemple #3
0
def main():
  """Run the epoch loop that trains a ``Network`` together with an ``Architect``.

  All settings come from the module-level ``args`` namespace. The function
  seeds the RNGs, builds the training and validation criteria, the model and
  an SGD optimizer, loads a noisy and a gold copy of the chosen CIFAR training
  set, splits the indices into a train and a validation part according to
  ``args.train_portion``, and then, per epoch, logs the current genotype,
  calls ``train`` and ``infer`` and saves the weights to
  ``<args.save>/weights.pt``.

  Exits the process with status 1 when CUDA is unavailable.

  Raises:
    ValueError: If ``args.loss_func`` is not ``'cce'``/``'rll'`` or
      ``args.dataset`` is not ``'cifar10'``/``'cifar100'``.
  """
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  random.seed(args.seed)
  np.random.seed(args.data_seed)  # cutout and load_corrupted_data use np.random
  torch.cuda.set_device(args.gpu)
  cudnn.benchmark = False
  torch.manual_seed(args.seed)
  cudnn.enabled = True
  cudnn.deterministic = True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  logging.info("args = %s", args)

  if args.loss_func == 'cce':
    criterion = nn.CrossEntropyLoss().cuda()
  elif args.loss_func == 'rll':
    criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
  else:
    # The literal braces must be doubled: un-escaped, str.format() treats
    # {'cce', 'rll'} as a replacement field and raises KeyError instead of
    # producing this message.
    raise ValueError(
        "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
            args.loss_func))

  # Any other value deliberately leaves the validation criterion unset.
  if args.valid_loss_func == 'cce':
    valid_criterion = nn.CrossEntropyLoss().cuda()
  elif args.valid_loss_func == 'rll':
    valid_criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
  else:
    valid_criterion = None

  model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion, valid_criterion)
  model = model.cuda()
  logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

  optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay)

  train_transform, valid_transform = utils._data_transforms_cifar10(args)

  # Load dataset
  if args.dataset == 'cifar10':
    dataset_cls = CIFAR10
  elif args.dataset == 'cifar100':
    dataset_cls = CIFAR100
  else:
    # Previously an unknown name surfaced later as a NameError.
    raise ValueError(
        "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
            args.dataset))

  # Both copies share everything except the gold settings.
  shared_kwargs = dict(
      root=args.data, train=True,
      corruption_prob=args.corruption_prob,
      corruption_type=args.corruption_type,
      transform=train_transform, download=True, seed=args.data_seed)
  noisy_train_data = dataset_cls(gold=False, gold_fraction=0.0, **shared_kwargs)
  gold_train_data = dataset_cls(gold=True, gold_fraction=1.0, **shared_kwargs)

  num_train = len(gold_train_data)
  indices = list(range(num_train))
  split = int(np.floor(args.train_portion * num_train))

  # Train on the gold copy only when the whole set is declared gold.
  train_data = gold_train_data if args.gold_fraction == 1.0 else noisy_train_data
  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
      pin_memory=True, num_workers=0)

  valid_data = gold_train_data if args.clean_valid else noisy_train_data
  valid_queue = torch.utils.data.DataLoader(
      valid_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
      pin_memory=True, num_workers=0)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
      optimizer, float(args.epochs), eta_min=args.learning_rate_min)

  architect = Architect(model, args)

  for epoch in range(args.epochs):
    # NOTE(review): stepping the scheduler at the start of the epoch and
    # reading get_lr() follows the older PyTorch convention; confirm against
    # the installed version before changing the order.
    scheduler.step()
    lr = scheduler.get_lr()[0]
    logging.info('epoch %d lr %e', epoch, lr)

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)

    print(F.softmax(model.alphas_normal, dim=-1))
    print(F.softmax(model.alphas_reduce, dim=-1))

    # training
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr)
    logging.info('train_acc %f', train_acc)

    # validation
    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    logging.info('valid_acc %f', valid_acc)

    utils.save(model, os.path.join(args.save, 'weights.pt'))
Exemple #4
0
def main():
    """Evaluate a fixed-weight network under randomly re-drawn architectures.

    All settings come from the module-level ``args`` namespace. The model is
    initialised once with ``weights_init`` and never trained here. Each epoch
    logs the current genotype, runs ``infer`` on four cached batch lists
    (clean/noisy x train/valid), saves the weights to
    ``<args.save>/weights.pt`` and then overwrites ``alphas_normal`` and
    ``alphas_reduce`` with fresh ``torch.randn`` samples.

    Exits the process with status 1 when CUDA is unavailable.

    Raises:
        ValueError: If ``args.loss_func`` is not ``'cce'``/``'rll'`` or
            ``args.dataset`` is not ``'cifar10'``/``'cifar100'``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    else:
        # The literal braces must be doubled: un-escaped, str.format() treats
        # {'cce', 'rll'} as a replacement field and raises KeyError instead
        # of producing this message.
        raise ValueError(
            "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
                args.loss_func))

    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()
    model.train()
    model.apply(weights_init)
    # NOTE(review): no backward pass has run at this point, so this clipping
    # call presumably has nothing to clip - confirm whether it is needed.
    nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        dataset_cls = CIFAR10
    elif args.dataset == 'cifar100':
        dataset_cls = CIFAR100
    else:
        # Previously an unknown name surfaced later as a NameError.
        raise ValueError(
            "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
                args.dataset))

    # Both copies share everything except the gold settings.
    shared_kwargs = dict(root=args.data,
                         train=True,
                         corruption_prob=args.corruption_prob,
                         corruption_type=args.corruption_type,
                         transform=train_transform,
                         download=True,
                         seed=args.seed)
    train_data = dataset_cls(gold=False, gold_fraction=0.0, **shared_kwargs)
    gold_train_data = dataset_cls(gold=True, gold_fraction=1.0, **shared_kwargs)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    def make_queue(data, subset):
        # One loader per (dataset copy, index subset) pair.
        return torch.utils.data.DataLoader(
            data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(subset),
            pin_memory=True,
            num_workers=0)

    # Built in the same order as before so the samplers consume the RNG
    # identically.
    clean_train_queue = make_queue(gold_train_data, indices[:split])
    noisy_train_queue = make_queue(train_data, indices[:split])
    clean_valid_queue = make_queue(gold_train_data, indices[split:])
    noisy_valid_queue = make_queue(train_data, indices[split:])

    # Materialise every batch on the GPU once; the epoch loop below only
    # re-evaluates these cached lists.
    clean_train_list, clean_valid_list, noisy_train_list, noisy_valid_list = [], [], [], []
    for dst_list, queue in [
        (clean_train_list, clean_train_queue),
        (clean_valid_list, clean_valid_queue),
        (noisy_train_list, noisy_train_queue),
        (noisy_valid_list, noisy_valid_queue),
    ]:
        for inputs, target in queue:
            # `async` is a reserved word since Python 3.7, so the old
            # `.cuda(async=True)` spelling no longer parses; `non_blocking`
            # is its replacement (PyTorch >= 0.4). The removed
            # `Variable(..., volatile=True)` wrapper has no effect there.
            inputs = inputs.cuda()
            target = target.cuda(non_blocking=True)
            dst_list.append((inputs, target))

    for epoch in range(args.epochs):
        logging.info('Epoch %d, random architecture with fix weights', epoch)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        logging.info(F.softmax(model.alphas_normal, dim=-1))
        logging.info(F.softmax(model.alphas_reduce, dim=-1))

        # training
        clean_train_acc, clean_train_obj = infer(clean_train_list,
                                                 model,
                                                 criterion,
                                                 kind='clean_train')
        logging.info('clean_train_acc %f, clean_train_loss %f',
                     clean_train_acc, clean_train_obj)

        noisy_train_acc, noisy_train_obj = infer(noisy_train_list,
                                                 model,
                                                 criterion,
                                                 kind='noisy_train')
        logging.info('noisy_train_acc %f, noisy_train_loss %f',
                     noisy_train_acc, noisy_train_obj)

        # validation
        clean_valid_acc, clean_valid_obj = infer(clean_valid_list,
                                                 model,
                                                 criterion,
                                                 kind='clean_valid')
        logging.info('clean_valid_acc %f, clean_valid_loss %f',
                     clean_valid_acc, clean_valid_obj)

        noisy_valid_acc, noisy_valid_obj = infer(noisy_valid_list,
                                                 model,
                                                 criterion,
                                                 kind='noisy_valid')
        logging.info('noisy_valid_acc %f, noisy_valid_loss %f',
                     noisy_valid_acc, noisy_valid_obj)

        utils.save(model, os.path.join(args.save, 'weights.pt'))

        # Randomly change the alphas
        # k counts one row per (step, input) pair: 2 + i inputs at step i.
        k = sum(1 for i in range(model._steps) for n in range(2 + i))
        num_ops = len(PRIMITIVES)
        model.alphas_normal.data.copy_(torch.randn(k, num_ops))
        model.alphas_reduce.data.copy_(torch.randn(k, num_ops))
Exemple #5
0
def main():
    """Train a fixed architecture and report train/valid/test accuracy.

    All settings come from the module-level ``args`` namespace. The model is
    a ResNet variant when ``args.arch`` is ``'resnet'``, ``'resnet34'`` or
    ``'resnet50'``; otherwise ``args.arch`` names a genotype in the
    ``genotypes`` module and a ``Network`` is built from it. Training data is
    a noisy or gold copy of the chosen CIFAR training set, split into train
    and validation indices by ``args.train_portion``; the test set comes from
    ``dset``. Weights are saved to ``<args.save>/weights.pt`` after every
    epoch.

    Exits the process with status 1 when CUDA is unavailable.

    Raises:
        ValueError: If ``args.loss_func`` is not one of ``'cce'``, ``'rll'``,
            ``'forward_gold'``, or ``args.dataset`` is not
            ``'cifar10'``/``'cifar100'``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(
        args.data_seed)  # cutout and load_corrupted_data use np.random
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # The ResNet variants do not use the auxiliary head, so it is switched
    # off for them.
    if args.arch == 'resnet':
        model = ResNet18(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet50':
        model = ResNet50(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet34':
        model = ResNet34(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    else:
        # NOTE(review): eval() on a command-line value executes arbitrary
        # code; getattr(genotypes, args.arch) would be the safe lookup if
        # only plain attribute names are ever passed - confirm with callers.
        genotype = eval("genotypes.%s" % args.arch)
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        args.auxiliary, genotype)
        model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, test_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        train_cls, test_cls = CIFAR10, dset.CIFAR10
    elif args.dataset == 'cifar100':
        train_cls, test_cls = CIFAR100, dset.CIFAR100
    else:
        # Previously an unknown name surfaced later as a NameError.
        raise ValueError(
            "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
                args.dataset))

    # Both training copies share everything except the gold settings.
    shared_kwargs = dict(root=args.data,
                         train=True,
                         corruption_prob=args.corruption_prob,
                         corruption_type=args.corruption_type,
                         transform=train_transform,
                         download=True,
                         seed=args.data_seed)
    noisy_train_data = train_cls(gold=False, gold_fraction=0.0, **shared_kwargs)
    gold_train_data = train_cls(gold=True, gold_fraction=1.0, **shared_kwargs)
    test_data = test_cls(root=args.data,
                         train=False,
                         download=True,
                         transform=test_transform)

    num_train = len(gold_train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    # Train on the gold copy only when the whole set is declared gold.
    train_data = gold_train_data if args.gold_fraction == 1.0 else noisy_train_data
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=0)

    valid_data = gold_train_data if args.clean_valid else noisy_train_data
    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=0)

    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    elif args.loss_func == 'forward_gold':
        corruption_matrix = train_data.corruption_matrix
        criterion = utils.ForwardGoldLoss(corruption_matrix=corruption_matrix)
    else:
        # The literal braces must be doubled: un-escaped, str.format() treats
        # the set literal as a replacement field and raises KeyError. The
        # message now also lists 'forward_gold', which is accepted above.
        raise ValueError(
            "Invalid loss function '{}' given. "
            "Must be in {{'cce', 'rll', 'forward_gold'}}".format(
                args.loss_func))

    for epoch in range(args.epochs):
        # NOTE(review): stepping the scheduler at the start of the epoch and
        # reading get_lr() follows the older PyTorch convention; confirm
        # against the installed version before changing the order.
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        # Drop-path probability grows linearly with the epoch index.
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer_valid(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        test_acc, test_obj = infer(test_queue, model, criterion)
        logging.info('test_acc %f', test_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
Exemple #6
0
def main():
    """Train a genotype-defined ``Network`` on CIFAR and validate on the test split.

    All settings come from the module-level ``args`` namespace and the
    module-level ``device``. The training set is the project's corrupted
    CIFAR dataset (gold labels are requested whenever ``args.gold_fraction``
    is non-zero); validation uses the ``dset`` test split. In each epoch
    validation runs before training, and the model is saved to
    ``<args.save>/trained.pt`` afterwards.

    Raises:
        ValueError: If ``args.loss_func`` is not ``'cce'``/``'rll'`` or
            ``args.dataset`` is not ``'cifar10'``/``'cifar100'``.
    """
    np.random.seed(args.seed)
    torch.cuda.set_device(device)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        train_cls, valid_cls, num_classes = CIFAR10, dset.CIFAR10, 10
    elif args.dataset == 'cifar100':
        train_cls, valid_cls, num_classes = CIFAR100, dset.CIFAR100, 100
    else:
        # Previously an unknown name surfaced later as a NameError.
        raise ValueError(
            "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
                args.dataset))

    # The two original branches differed only in the `gold` flag.
    train_data = train_cls(root=args.data,
                           train=True,
                           gold=args.gold_fraction != 0,
                           gold_fraction=args.gold_fraction,
                           corruption_prob=args.corruption_prob,
                           corruption_type=args.corruption_type,
                           transform=train_transform,
                           download=True,
                           seed=args.seed)
    valid_data = valid_cls(root=args.data,
                           train=False,
                           download=True,
                           transform=valid_transform)

    # NOTE(review): eval() on a command-line value executes arbitrary code;
    # getattr(genotypes, args.arch) would be the safe lookup if only plain
    # attribute names are ever passed - confirm with callers.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_ch, num_classes, args.layers, args.auxiliary,
                    genotype).cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().to(device)
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().to(device)
    else:
        # The literal braces must be doubled: un-escaped, str.format() treats
        # {'cce', 'rll'} as a replacement field and raises KeyError instead
        # of producing this message.
        raise ValueError(
            "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
                args.loss_func))
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.wd,
                                nesterov=True)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batchsz,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=2)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batchsz,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    for epoch in range(args.epochs):
        # NOTE(review): stepping the scheduler at the start of the epoch and
        # reading get_lr() follows the older PyTorch convention; confirm
        # against the installed version before changing the order.
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        # Drop-path probability grows linearly with the epoch index.
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc: %f', train_acc)

        utils.save(model, os.path.join(args.save, 'trained.pt'))
        print('saved to: trained.pt')
Exemple #7
0
        args.corruption_type,
        transform=test_transform,
        download=True,
        shuffle_indices=train_data_gold.shuffle_indices)
    test_data = CIFAR10(args.data_path,
                        train=False,
                        transform=test_transform,
                        download=True)
    num_classes = 10

elif args.dataset == 'cifar100':
    train_data_gold = CIFAR100(args.data_path,
                               True,
                               True,
                               args.gold_fraction,
                               args.corruption_prob,
                               args.corruption_type,
                               transform=train_transform,
                               download=True,
                               distinguish_gold=False)
    train_data_silver = CIFAR100(
        args.data_path,
        True,
        False,
        args.gold_fraction,
        args.corruption_prob,
        args.corruption_type,
        transform=train_transform,
        download=True,
        shuffle_indices=train_data_gold.shuffle_indices)
    train_data_gold_deterministic = CIFAR100(
Exemple #8
0
def main():
    """Train a ``ResNet18`` on CIFAR and validate on the test split.

    All settings come from the module-level ``args`` namespace. The training
    set is the project's corrupted CIFAR dataset (gold labels are requested
    whenever ``args.gold_fraction`` is non-zero); validation uses the ``dset``
    test split. Weights are saved to ``<args.save>/weights.pt`` after every
    epoch.

    Exits the process with status 1 when CUDA is unavailable.

    Raises:
        ValueError: If ``args.loss_func`` is not ``'cce'``/``'rll'`` or
            ``args.dataset`` is not ``'cifar10'``/``'cifar100'``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    model = ResNet18()
    model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().cuda()
    else:
        # The literal braces must be doubled: un-escaped, str.format() treats
        # {'cce', 'rll'} as a replacement field and raises KeyError instead
        # of producing this message.
        raise ValueError(
            "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
                args.loss_func))
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    if args.dataset == 'cifar10':
        train_cls, valid_cls = CIFAR10, dset.CIFAR10
    elif args.dataset == 'cifar100':
        train_cls, valid_cls = CIFAR100, dset.CIFAR100
    else:
        # Previously an unknown name surfaced later as a NameError.
        raise ValueError(
            "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
                args.dataset))

    # The two original branches differed only in the `gold` flag.
    train_data = train_cls(root=args.data,
                           train=True,
                           gold=args.gold_fraction != 0,
                           gold_fraction=args.gold_fraction,
                           corruption_prob=args.corruption_prob,
                           corruption_type=args.corruption_type,
                           transform=train_transform,
                           download=True,
                           seed=args.seed)
    valid_data = valid_cls(root=args.data,
                           train=False,
                           download=True,
                           transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=2)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))
    for epoch in range(args.epochs):
        # NOTE(review): stepping the scheduler at the start of the epoch and
        # reading get_lr() follows the older PyTorch convention; confirm
        # against the installed version before changing the order.
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
Exemple #9
0
                                   corruption_type='unif',
                                   transform=test_transform)
        val_indices = np.random.choice(len(train_data),
                                       args.val_set_size,
                                       replace=False)
        train_indices = np.array(
            list(set(range(len(train_data))) - set(val_indices)))
        val_data = torch.utils.data.Subset(train_data_clean, val_indices)
        train_data = torch.utils.data.Subset(train_data, train_indices)
    test_data = CIFAR10(args.data_path, train=False, transform=test_transform)
    num_classes = 10
else:
    train_data = CIFAR100(args.data_path,
                          True,
                          False,
                          0,
                          corruption_prob=args.corruption_prob,
                          corruption_type='unif',
                          transform=train_transform)
    if args.val_set_size > 0:
        train_data_clean = CIFAR100(args.data_path,
                                    True,
                                    False,
                                    0,
                                    corruption_prob=0,
                                    corruption_type='unif',
                                    transform=test_transform)
        val_indices = np.random.choice(len(train_data),
                                       args.val_set_size,
                                       replace=False)
        train_indices = np.array(
def main():
    """Run the architecture-search loop on a (possibly label-corrupted) CIFAR set.

    Everything is driven by module-level state: the parsed ``args`` namespace,
    ``device_ids`` and ``batchsz``. The function

    1. seeds NumPy / PyTorch and enables cuDNN benchmarking,
    2. prints total/used memory for every GPU listed in ``device_ids``,
    3. builds the CIFAR10 or CIFAR100 training set (plus an optional
       ``gold=True, gold_fraction=1.0`` copy used for validation when
       ``args.clean_valid`` is set),
    4. splits the indices into train / validation parts by
       ``args.train_portion``,
    5. builds the loss, the ``Network``, an SGD optimizer with cosine
       annealing and the ``Arch`` helper,
    6. runs up to ``args.epochs`` epochs, stopping early once
       ``args.time_limit`` seconds have elapsed, saving the model after
       every epoch.

    Side effects: sets ``args.unrolled = True`` and the module-level
    ``start`` timestamp.

    Raises:
        ValueError: if ``args.dataset`` is not ``'cifar10'``/``'cifar100'``,
            if ``args.loss_func`` is not ``'cce'``/``'rll'``, or if
            ``args.clean_valid`` is combined with a non-zero
            ``args.gold_fraction`` (no clean validation set is built in that
            configuration).
    """
    np.random.seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)

    # Report memory for each configured GPU. nvidia-smi is queried once and
    # the pipe is closed deterministically; each output row is one GPU.
    if len(device_ids) > 0:
        with os.popen(
                'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
        ) as pipe:
            gpu_rows = pipe.read().split('\n')
        for device_id in device_ids:
            total, used = gpu_rows[device_id].split(',')
            print('GPU ({}) mem:'.format(device_id), total, 'used:', used)

    args.unrolled = True

    logging.info('GPU device = %s', args.gpu)
    logging.info("args = %s", args)

    # Only the training transform is used; validation batches are drawn from
    # datasets built with the training transform as well.
    train_transform, _ = utils._data_transforms_cifar10(args)

    # Pick the dataset class and its class count.
    if args.dataset == 'cifar10':
        dataset_cls, num_classes = CIFAR10, 10
    elif args.dataset == 'cifar100':
        dataset_cls, num_classes = CIFAR100, 100
    else:
        raise ValueError(
            "Invalid dataset '{}' given. Must be in {{'cifar10', 'cifar100'}}".format(args.dataset))

    use_gold = args.gold_fraction != 0
    if args.clean_valid and use_gold:
        # The clean validation set is only constructed for gold_fraction == 0;
        # fail early with a clear message instead of a NameError later on.
        raise ValueError(
            "clean_valid is only supported together with gold_fraction == 0")

    # Arguments shared by every dataset instance built below.
    common_kwargs = dict(
        root=args.data, train=True,
        corruption_prob=args.corruption_prob, corruption_type=args.corruption_type,
        transform=train_transform, download=True, seed=args.seed)

    train_data = dataset_cls(
        gold=use_gold, gold_fraction=args.gold_fraction, **common_kwargs)
    gold_train_data = None
    if args.clean_valid:
        gold_train_data = dataset_cls(
            gold=True, gold_fraction=1.0, **common_kwargs)

    # Split data to train and validation
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batchsz,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)

    # Validation uses the held-out index range, taken from the gold copy when
    # clean validation was requested and from the training set otherwise.
    valid_source = gold_train_data if args.clean_valid else train_data
    valid_queue = torch.utils.data.DataLoader(
        valid_source, batch_size=batchsz,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True, num_workers=2)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().cuda()
    else:
        # Literal braces must be doubled inside a str.format template.
        raise ValueError(
            "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(args.loss_func))

    model = Network(args.init_ch, num_classes, args.layers, criterion)
    if len(device_ids) > 1:
        model = MyDataParallel(model).cuda()
    else:
        model.cuda()

    logging.info("Total param size = %f MB", utils.count_parameters_in_MB(model))

    # Optimizer for the parameters returned by model.parameters().
    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                          weight_decay=args.wd, nesterov=True)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.lr_min)

    arch = Arch(model, args)

    global start
    start = time.time()
    for epoch in range(args.epochs):
        # Stop once the wall-clock budget is used up.
        if time.time() - start >= args.time_limit:
            break

        # NOTE(review): stepping the scheduler before the epoch's optimizer
        # steps and reading get_lr() follows the pre-1.1 PyTorch convention;
        # newer releases warn about this order and recommend get_last_lr().
        # Left unchanged so the learning-rate schedule is not shifted.
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('\nEpoch: %d lr: %e', epoch, lr)

        genotype = model.genotype()
        logging.info('Genotype: %s', genotype)

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model, arch, criterion, optimizer, lr)
        logging.info('train acc: %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid acc: %f', valid_acc)

        # The same file is overwritten after every epoch.
        utils.save(model, os.path.join(args.exp_path, 'search_epoch1.pt'))