예제 #1
0
def build_dataset():
    """Build the train, meta-train and test DataLoaders for CIFAR-10/100.

    All settings are read from the module-level ``args`` namespace
    (``augment``, ``dataset``, ``num_meta``, ``corruption_prob``,
    ``corruption_type``, ``seed``, ``batch_size``, ``prefetch``).

    Returns:
        tuple: ``(train_loader, train_meta_loader, test_loader)``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'cifar10'`` nor
            ``'cifar100'``.
    """
    # Per-channel statistics are written on the 0-255 scale and rescaled
    # here by dividing by 255.
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            # Reflect-pad 4 pixels on every side. F.pad is applied to a
            # batched (1, C, H, W) view; squeeze(0) removes only that
            # temporary batch axis again.
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4),
                                              mode='reflect').squeeze(0)),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Resolve the dataset class lazily so that only the requested class
    # name has to be importable.
    if args.dataset == 'cifar10':
        dataset_cls = CIFAR10
    elif args.dataset == 'cifar100':
        dataset_cls = CIFAR100
    else:
        # Previously this fell through and failed later with an unbound
        # local variable; fail early with an explicit message instead.
        raise ValueError(
            "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
                args.dataset))

    corrupted_kwargs = dict(
        root='../data',
        train=True,
        num_meta=args.num_meta,
        corruption_prob=args.corruption_prob,
        corruption_type=args.corruption_type,
        transform=train_transform,
        download=True,
    )
    # NOTE(review): the meta split is built without `seed`, while the
    # training split receives args.seed. Kept exactly as in the original;
    # confirm that this asymmetry is intentional.
    train_data_meta = dataset_cls(meta=True, **corrupted_kwargs)
    train_data = dataset_cls(meta=False, seed=args.seed, **corrupted_kwargs)
    test_data = dataset_cls(root='../data', train=False,
                            transform=test_transform, download=True)

    loader_kwargs = dict(batch_size=args.batch_size,
                         num_workers=args.prefetch,
                         pin_memory=True)
    train_loader = torch.utils.data.DataLoader(
        train_data, shuffle=True, **loader_kwargs)
    train_meta_loader = torch.utils.data.DataLoader(
        train_data_meta, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_data, shuffle=False, **loader_kwargs)

    return train_loader, train_meta_loader, test_loader
예제 #2
0
def main():
    """Run the search loop on label-corrupted CIFAR-10.

    Reads every setting from the module-level ``args`` namespace. The
    function seeds the RNGs, builds the loss, the search ``Network``, the
    SGD optimizer and the train/validation queues, then for each epoch
    logs the current genotype and architecture softmaxes, calls
    ``train`` and ``infer`` and saves the weights to
    ``<args.save>/weights.pt``.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if ``args.loss_func`` is not ``'cce'`` or ``'rll'``, or
            if ``args.clean_valid`` is combined with a non-zero
            ``args.gold_fraction`` (no clean validation set is built in
            that configuration).
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().cuda()
    else:
        # The literal braces must be doubled: the previous message made
        # str.format raise KeyError instead of reporting the bad value.
        raise ValueError(
            "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
                args.loss_func))

    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, _ = utils._data_transforms_cifar10(args)

    # Load dataset
    use_gold = args.gold_fraction != 0
    if args.clean_valid and use_gold:
        # The clean validation set is only constructed for
        # gold_fraction == 0; this combination used to crash further down
        # with an undefined `gold_train_data`.
        raise ValueError(
            "clean_valid is only supported together with gold_fraction == 0")

    corrupted_kwargs = dict(root=args.data,
                            train=True,
                            corruption_prob=args.corruption_prob,
                            corruption_type=args.corruption_type,
                            transform=train_transform,
                            download=True,
                            seed=args.seed)
    train_data = CIFAR10(gold=use_gold,
                         gold_fraction=args.gold_fraction,
                         **corrupted_kwargs)
    if args.clean_valid:
        gold_train_data = CIFAR10(gold=True,
                                  gold_fraction=1.0,
                                  **corrupted_kwargs)

    # The first `train_portion` of the indices trains the weights; the
    # remainder is used as the validation split.
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)

    valid_source = gold_train_data if args.clean_valid else train_data
    valid_queue = torch.utils.data.DataLoader(
        valid_source,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    for epoch in range(args.epochs):
        # NOTE(review): stepping the scheduler at the start of the epoch and
        # reading get_lr() follows the older PyTorch convention; newer
        # releases expect step() after optimizer.step() and get_last_lr().
        # Left unchanged to keep the learning-rate schedule identical.
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        print(F.softmax(model.alphas_normal, dim=-1))
        print(F.softmax(model.alphas_reduce, dim=-1))

        # training
        train_acc, _ = train(train_queue, valid_queue, model,
                             architect, criterion, optimizer, lr)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, _ = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
예제 #3
0
# Training-time pipeline: random horizontal flip, then a random 32x32 crop
# taken after 4-pixel padding, followed by tensor conversion and
# normalisation. `mean` and `std` are not defined in this excerpt.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
# Evaluation-time pipeline: no augmentation, only tensor conversion and the
# same normalisation as above.
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean, std)])

if args.dataset == 'cifar10':
    train_data_gold = CIFAR10(args.data_path,
                              True,
                              True,
                              args.gold_fraction,
                              args.corruption_prob,
                              args.corruption_type,
                              transform=train_transform,
                              download=True)
    train_data_silver = CIFAR10(args.data_path,
                                True,
                                False,
                                args.gold_fraction,
                                args.corruption_prob,
                                args.corruption_type,
                                transform=train_transform,
                                download=True)
    train_data_gold_deterministic = CIFAR10(args.data_path,
                                            True,
                                            True,
                                            args.gold_fraction,
예제 #4
0
def main():
    """Evaluate randomly drawn architectures with fixed weights.

    Reads every setting from the module-level ``args`` namespace. The
    model is initialised once via ``weights_init``; the clean and noisy
    train/validation splits are iterated a single time and cached on the
    GPU. Each epoch then logs the genotype, evaluates all four cached
    splits with ``infer``, saves the weights and overwrites the
    architecture parameters with fresh Gaussian samples.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if ``args.loss_func`` or ``args.dataset`` is not one of
            the supported values.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    else:
        # The literal braces must be doubled: the previous message made
        # str.format raise KeyError instead of reporting the bad value.
        raise ValueError(
            "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
                args.loss_func))

    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()
    model.train()
    model.apply(weights_init)
    # NOTE(review): no backward pass has run at this point, so this call
    # looks like a no-op; kept unchanged - confirm whether it is needed.
    nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    train_transform, _ = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        dataset_cls = CIFAR10
    elif args.dataset == 'cifar100':
        dataset_cls = CIFAR100
    else:
        # Previously an unknown name left `train_data` unbound.
        raise ValueError(
            "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
                args.dataset))

    corrupted_kwargs = dict(root=args.data,
                            train=True,
                            corruption_prob=args.corruption_prob,
                            corruption_type=args.corruption_type,
                            transform=train_transform,
                            download=True,
                            seed=args.seed)
    train_data = dataset_cls(gold=False, gold_fraction=0.0,
                             **corrupted_kwargs)
    gold_train_data = dataset_cls(gold=True, gold_fraction=1.0,
                                  **corrupted_kwargs)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    def make_queue(dataset, subset):
        # One loader over `dataset`, restricted to the given index subset.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(subset),
            pin_memory=True,
            num_workers=0)

    clean_train_queue = make_queue(gold_train_data, indices[:split])
    noisy_train_queue = make_queue(train_data, indices[:split])
    clean_valid_queue = make_queue(gold_train_data, indices[split:])
    noisy_valid_queue = make_queue(train_data, indices[split:])

    # Iterate every queue once and keep the batches on the GPU so that all
    # epochs evaluate exactly the same batches.
    clean_train_list, clean_valid_list, noisy_train_list, noisy_valid_list = [], [], [], []
    for dst_list, queue in [
        (clean_train_list, clean_train_queue),
        (clean_valid_list, clean_valid_queue),
        (noisy_train_list, noisy_train_queue),
        (noisy_valid_list, noisy_valid_queue),
    ]:
        for images, labels in queue:
            # `async` is a reserved word since Python 3.7; `non_blocking`
            # is the replacement keyword. The Variable(volatile=True)
            # wrapper is dropped because recent PyTorch ignores it.
            # NOTE(review): requires PyTorch >= 0.4; confirm `infer`
            # disables gradient tracking itself.
            dst_list.append((images.cuda(), labels.cuda(non_blocking=True)))

    # Shape of the architecture parameters: one row per edge, one column
    # per primitive operation.
    k = sum(2 + i for i in range(model._steps))
    num_ops = len(PRIMITIVES)

    for epoch in range(args.epochs):
        logging.info('Epoch %d, random architecture with fix weights', epoch)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        logging.info(F.softmax(model.alphas_normal, dim=-1))
        logging.info(F.softmax(model.alphas_reduce, dim=-1))

        # training
        clean_train_acc, clean_train_obj = infer(clean_train_list,
                                                 model,
                                                 criterion,
                                                 kind='clean_train')
        logging.info('clean_train_acc %f, clean_train_loss %f',
                     clean_train_acc, clean_train_obj)

        noisy_train_acc, noisy_train_obj = infer(noisy_train_list,
                                                 model,
                                                 criterion,
                                                 kind='noisy_train')
        logging.info('noisy_train_acc %f, noisy_train_loss %f',
                     noisy_train_acc, noisy_train_obj)

        # validation
        clean_valid_acc, clean_valid_obj = infer(clean_valid_list,
                                                 model,
                                                 criterion,
                                                 kind='clean_valid')
        logging.info('clean_valid_acc %f, clean_valid_loss %f',
                     clean_valid_acc, clean_valid_obj)

        noisy_valid_acc, noisy_valid_obj = infer(noisy_valid_list,
                                                 model,
                                                 criterion,
                                                 kind='noisy_valid')
        logging.info('noisy_valid_acc %f, noisy_valid_loss %f',
                     noisy_valid_acc, noisy_valid_obj)

        utils.save(model, os.path.join(args.save, 'weights.pt'))

        # Randomly change the alphas
        model.alphas_normal.data.copy_(torch.randn(k, num_ops))
        model.alphas_reduce.data.copy_(torch.randn(k, num_ops))
예제 #5
0
파일: train.py 프로젝트: YIWEI-CHEN/darts
def main():
    """Train a fixed architecture on CIFAR-10/100 with corrupted labels.

    Reads every setting from the module-level ``args`` namespace. The
    function builds the model selected by ``args.arch``, the noisy and
    gold training sets plus the standard test set, and the loss selected
    by ``args.loss_func``. Each epoch it calls ``train``, ``infer_valid``
    and ``infer`` and saves the weights to ``<args.save>/weights.pt``.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if ``args.dataset`` or ``args.loss_func`` is not one of
            the supported values.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(
        args.data_seed)  # cutout and load_corrupted_data use np.random
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.arch == 'resnet':
        model = ResNet18(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet50':
        model = ResNet50(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet34':
        model = ResNet34(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    else:
        # NOTE(review): eval() runs on a command-line supplied name; a
        # getattr(genotypes, args.arch) lookup would avoid evaluating
        # arbitrary expressions. Left unchanged pending confirmation.
        genotype = eval("genotypes.%s" % args.arch)
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        args.auxiliary, genotype)
        model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, test_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        corrupted_cls, test_cls = CIFAR10, dset.CIFAR10
    elif args.dataset == 'cifar100':
        corrupted_cls, test_cls = CIFAR100, dset.CIFAR100
    else:
        # Previously an unknown name left the datasets unbound.
        raise ValueError(
            "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
                args.dataset))

    corrupted_kwargs = dict(root=args.data,
                            train=True,
                            corruption_prob=args.corruption_prob,
                            corruption_type=args.corruption_type,
                            transform=train_transform,
                            download=True,
                            seed=args.data_seed)
    noisy_train_data = corrupted_cls(gold=False, gold_fraction=0.0,
                                     **corrupted_kwargs)
    gold_train_data = corrupted_cls(gold=True, gold_fraction=1.0,
                                    **corrupted_kwargs)
    test_data = test_cls(root=args.data,
                         train=False,
                         download=True,
                         transform=test_transform)

    # The first `train_portion` of the indices is used for training, the
    # remainder for validation.
    num_train = len(gold_train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_data = gold_train_data if args.gold_fraction == 1.0 else noisy_train_data
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=0)

    valid_data = gold_train_data if args.clean_valid else noisy_train_data
    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=0)

    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    elif args.loss_func == 'forward_gold':
        corruption_matrix = train_data.corruption_matrix
        criterion = utils.ForwardGoldLoss(corruption_matrix=corruption_matrix)
    else:
        # The literal braces must be doubled (the old message made
        # str.format raise KeyError); the accepted set also lacked
        # 'forward_gold'.
        raise ValueError(
            "Invalid loss function '{}' given. "
            "Must be in {{'cce', 'rll', 'forward_gold'}}".format(
                args.loss_func))

    for epoch in range(args.epochs):
        # NOTE(review): stepping the scheduler at the start of the epoch and
        # reading get_lr() follows the older PyTorch convention; left
        # unchanged to keep the learning-rate schedule identical.
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, _ = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc, _ = infer_valid(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        test_acc, _ = infer(test_queue, model, criterion)
        logging.info('test_acc %f', test_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
예제 #6
0
def main():
    """Compare clean and noisy losses while training or re-randomising.

    Reads every setting from the module-level ``args`` namespace. The
    function builds the model selected by ``args.arch`` together with
    clean (gold) and noisy CIFAR-10 training sets, each split into a
    train and a validation part. Per epoch it either calls
    ``infer_random_weight`` and draws a new module-level ``gain`` (when
    ``args.random_weight`` is set) or runs one ``train`` pass and saves
    the weights. Both validation splits are evaluated afterwards in
    either mode.

    Raises:
        SystemExit: if no CUDA device is available.
        ValueError: if ``args.loss_func`` is not a supported value.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)
    cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(args.seed)

    if args.arch == 'resnet':
        model = ResNet18(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet50':
        model = ResNet50(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    elif args.arch == 'resnet34':
        model = ResNet34(CIFAR_CLASSES).cuda()
        args.auxiliary = False
    else:
        # NOTE(review): eval() runs on a command-line supplied name; a
        # getattr(genotypes, args.arch) lookup would avoid evaluating
        # arbitrary expressions. Left unchanged pending confirmation.
        genotype = eval("genotypes.%s" % args.arch)
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        args.auxiliary, genotype)
        model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, _ = utils._data_transforms_cifar10(args)

    corrupted_kwargs = dict(root=args.data,
                            train=True,
                            corruption_prob=args.corruption_prob,
                            corruption_type=args.corruption_type,
                            transform=train_transform,
                            download=True,
                            seed=args.seed)
    train_data = CIFAR10(gold=False, gold_fraction=0.0, **corrupted_kwargs)
    gold_train_data = CIFAR10(gold=True, gold_fraction=1.0,
                              **corrupted_kwargs)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    def make_queue(dataset, subset):
        # One loader over `dataset`, restricted to the given index subset.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(subset),
            pin_memory=True,
            num_workers=2)

    clean_train_queue = make_queue(gold_train_data, indices[:split])
    noisy_train_queue = make_queue(train_data, indices[:split])
    clean_valid_queue = make_queue(gold_train_data, indices[split:])
    noisy_valid_queue = make_queue(train_data, indices[split:])

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    elif args.loss_func == 'forward_gold':
        corruption_matrix = train_data.corruption_matrix
        criterion = utils.ForwardGoldLoss(corruption_matrix=corruption_matrix)
    else:
        # The literal braces must be doubled (the old message made
        # str.format raise KeyError); the accepted set also lacked
        # 'forward_gold'.
        raise ValueError(
            "Invalid loss function '{}' given. "
            "Must be in {{'cce', 'rll', 'forward_gold'}}".format(
                args.loss_func))

    # `gain` is a module-level value that is re-drawn in random-weight mode.
    global gain
    for epoch in range(args.epochs):
        if args.random_weight:
            logging.info('Epoch %d, Randomly assign weights', epoch)
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
            clean_obj, noisy_obj = infer_random_weight(clean_train_queue,
                                                       noisy_train_queue,
                                                       model, criterion)
            logging.info('clean loss %f, noisy loss %f', clean_obj, noisy_obj)
            gain = np.random.randint(1, args.grad_clip, size=1)[0]
        else:
            # NOTE(review): stepping the scheduler at the start of the epoch
            # and reading get_lr() follows the older PyTorch convention;
            # left unchanged to keep the schedule identical.
            scheduler.step()
            logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

            train_acc, train_obj, another_obj = train(clean_train_queue,
                                                      noisy_train_queue, model,
                                                      criterion, optimizer)
            # The two loss values are logged in swapped order depending on
            # args.clean_train, under the same clean/noisy labels.
            if args.clean_train:
                logging.info('train_acc %f, clean_loss %f, noisy_loss %f',
                             train_acc, train_obj, another_obj)
            else:
                logging.info('train_acc %f, clean_loss %f, noisy_loss %f',
                             train_acc, another_obj, train_obj)

            utils.save(model, os.path.join(args.save, 'weights.pt'))

        clean_valid_acc, clean_valid_obj = infer_valid(clean_valid_queue,
                                                       model, criterion)
        logging.info('clean_valid_acc %f, clean_valid_loss %f',
                     clean_valid_acc, clean_valid_obj)

        noisy_valid_acc, noisy_valid_obj = infer_valid(noisy_valid_queue,
                                                       model, criterion)
        logging.info('noisy_valid_acc %f, noisy_valid_loss %f',
                     noisy_valid_acc, noisy_valid_obj)
예제 #7
0
def main():
    """Train a fixed genotype on corrupted CIFAR and evaluate on the test set.

    Reads every setting from the module-level ``args`` namespace and the
    module-level ``device``. The training set is built with ``gold=True``
    whenever ``args.gold_fraction`` is non-zero; the standard torchvision
    test split is used for validation. Within each epoch ``infer`` runs
    before ``train``, and the weights are then saved to
    ``<args.save>/trained.pt``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.loss_func`` is not one of
            the supported values.
    """
    np.random.seed(args.seed)
    torch.cuda.set_device(device)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        corrupted_cls, clean_cls, num_classes = CIFAR10, dset.CIFAR10, 10
    elif args.dataset == 'cifar100':
        corrupted_cls, clean_cls, num_classes = CIFAR100, dset.CIFAR100, 100
    else:
        # Previously an unknown name left `train_data` and `num_classes`
        # unbound.
        raise ValueError(
            "Unsupported dataset '{}'; expected 'cifar10' or 'cifar100'".format(
                args.dataset))

    train_data = corrupted_cls(root=args.data,
                               train=True,
                               gold=(args.gold_fraction != 0),
                               gold_fraction=args.gold_fraction,
                               corruption_prob=args.corruption_prob,
                               corruption_type=args.corruption_type,
                               transform=train_transform,
                               download=True,
                               seed=args.seed)
    valid_data = clean_cls(root=args.data,
                           train=False,
                           download=True,
                           transform=valid_transform)

    # NOTE(review): eval() runs on a command-line supplied name; a
    # getattr(genotypes, args.arch) lookup would avoid evaluating arbitrary
    # expressions. Left unchanged pending confirmation.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_ch, num_classes, args.layers, args.auxiliary,
                    genotype).cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().to(device)
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().to(device)
    else:
        # The literal braces must be doubled: the previous message made
        # str.format raise KeyError instead of reporting the bad value.
        raise ValueError(
            "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
                args.loss_func))
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.wd,
                                nesterov=True)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batchsz,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=2)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batchsz,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    for epoch in range(args.epochs):
        # NOTE(review): stepping the scheduler at the start of the epoch and
        # reading get_lr() follows the older PyTorch convention; left
        # unchanged to keep the learning-rate schedule identical.
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        # Evaluation deliberately precedes training here, as in the
        # original ordering.
        valid_acc, _ = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)

        train_acc, _ = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc: %f', train_acc)

        utils.save(model, os.path.join(args.save, 'trained.pt'))
        print('saved to: trained.pt')
예제 #8
0
def build_dataset_continual(args):
    """Build class-incremental CIFAR-10 training loaders plus meta/test loaders.

    The ten CIFAR-10 labels are partitioned into consecutive groups of
    ``classes_per_task`` labels (0-1, 2-3, ...). For each group a shuffled
    ``DataLoader`` over the matching ``Subset`` of the torchvision CIFAR-10
    training set is created.

    Args:
        args: Namespace read for ``dataset``, ``augment``, ``num_meta``,
            ``corruption_prob``, ``corruption_type``, ``batch_size`` and
            ``prefetch`` (passed to the loaders as ``num_workers``).

    Returns:
        A tuple ``(train_loaders, train_meta_loader, test_loader)`` where
        ``train_loaders`` is a list with one loader per task.

    Raises:
        ValueError: If ``args.dataset`` is not ``'cifar10'``; no other dataset
            is wired up in this function.
    """
    # Fail fast with a clear message; without this guard an unsupported
    # dataset only surfaced later as an UnboundLocalError on `train_data`.
    if args.dataset != 'cifar10':
        raise ValueError(
            "build_dataset_continual only supports dataset 'cifar10', "
            "got {!r}".format(args.dataset))

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    if args.augment:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            # Reflect-pad by 4 pixels per side so the random 32x32 crop below
            # can shift the image.
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                              mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_data_meta = CIFAR10(root='../data',
                              train=True,
                              meta=True,
                              num_meta=args.num_meta,
                              corruption_prob=args.corruption_prob,
                              corruption_type=args.corruption_type,
                              transform=train_transform,
                              download=True)
    # The per-task training data comes from the stock torchvision dataset
    # rather than the project's corrupted CIFAR10 wrapper.
    # NOTE(review): this uses root "./data" while the other datasets use
    # '../data' -- confirm the differing directory is intentional.
    train_data = torchvision.datasets.CIFAR10(root="./data",
                                              download=True,
                                              train=True,
                                              transform=train_transform)
    test_data = CIFAR10(root='../data',
                        train=False,
                        transform=test_transform,
                        download=True)

    num_classes = 10
    classes_per_task = 2
    targets = np.array(train_data.targets)

    train_loaders = []
    # Step by the number of classes per task. The previous loop stepped by the
    # number of tasks (10 // 2 == 5), which silently produced two 5-class
    # loaders instead of five 2-class loaders.
    for first_class in range(0, num_classes, classes_per_task):
        task_indices = np.concatenate([
            np.flatnonzero(targets == label)
            for label in range(first_class, first_class + classes_per_task)
        ])
        task_subset = torch.utils.data.Subset(train_data, task_indices)
        train_loaders.append(
            torch.utils.data.DataLoader(task_subset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.prefetch,
                                        pin_memory=True))

    train_meta_loader = torch.utils.data.DataLoader(train_data_meta,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=args.prefetch,
                                                    pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.prefetch,
                                              pin_memory=True)

    return train_loaders, train_meta_loader, test_loader
예제 #9
0
파일: train_forward.py 프로젝트: sanlb/glc
# Training-time preprocessing: random horizontal flip, random 32x32 crop taken
# after 4 pixels of padding, tensor conversion, then normalization with
# `mean` / `std` (both defined outside this excerpt).
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
# Evaluation-time preprocessing: no augmentation, only tensor conversion and
# the same normalization.
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean, std)])

if args.dataset == 'cifar10':
    # NOTE(review): the positional arguments follow the project's CIFAR10
    # signature, which is not visible here -- presumably
    # (root, train, gold, gold_fraction, corruption_prob, corruption_type);
    # confirm against the dataset class.
    train_data_gold = CIFAR10(args.data_path,
                              True,
                              True,
                              args.gold_fraction,
                              args.corruption_prob,
                              args.corruption_type,
                              transform=train_transform,
                              download=True,
                              distinguish_gold=False)
    # Built with the third positional flag flipped to False and reusing the
    # shuffle indices of `train_data_gold`, so both datasets are derived from
    # the same shuffled ordering.
    train_data_silver = CIFAR10(
        args.data_path,
        True,
        False,
        args.gold_fraction,
        args.corruption_prob,
        args.corruption_type,
        transform=train_transform,
        download=True,
        shuffle_indices=train_data_gold.shuffle_indices)
    train_data_gold_deterministic = CIFAR10(
예제 #10
0
def main():
    """Train a ResNet18 on CIFAR-10/100 and checkpoint it after every epoch.

    All configuration is read from the module-level ``args``. The training
    set is the project's ``CIFAR10`` / ``CIFAR100`` wrapper (configured with
    ``gold_fraction``, ``corruption_prob`` and ``corruption_type``); the
    validation set is the stock torchvision test split. After each epoch the
    model is written to ``<args.save>/weights.pt``.

    Exits the process with status 1 when no CUDA device is available.

    Raises:
        ValueError: If ``args.loss_func`` is not ``'cce'`` or ``'rll'``, or if
            ``args.dataset`` is not ``'cifar10'`` or ``'cifar100'``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d', args.gpu)
    logging.info("args = %s", args)

    model = ResNet18()
    model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().cuda()
    else:
        # Raise instead of `assert False` so the check survives `python -O`.
        # The literal braces are doubled: un-escaped, str.format() treated
        # {'cce', 'rll'} as a replacement field and raised KeyError instead
        # of reporting the invalid value.
        raise ValueError(
            "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
                args.loss_func))
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    if args.dataset == 'cifar10':
        train_dataset_cls, valid_dataset_cls = CIFAR10, dset.CIFAR10
    elif args.dataset == 'cifar100':
        train_dataset_cls, valid_dataset_cls = CIFAR100, dset.CIFAR100
    else:
        # Previously an unknown dataset fell through and crashed later with
        # an UnboundLocalError on `train_data`.
        raise ValueError(
            "Invalid dataset '{}' given. Must be in {{'cifar10', 'cifar100'}}".format(
                args.dataset))

    # `gold` is enabled exactly when a non-zero gold fraction is requested.
    train_data = train_dataset_cls(root=args.data,
                                   train=True,
                                   gold=args.gold_fraction != 0,
                                   gold_fraction=args.gold_fraction,
                                   corruption_prob=args.corruption_prob,
                                   corruption_type=args.corruption_type,
                                   transform=train_transform,
                                   download=True,
                                   seed=args.seed)
    valid_data = valid_dataset_cls(root=args.data,
                                   train=False,
                                   download=True,
                                   transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=2)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))
    for epoch in range(args.epochs):
        # NOTE(review): the scheduler is stepped at the start of the epoch and
        # queried with get_lr(); recent PyTorch releases expect
        # scheduler.step() after the optimizer steps and get_last_lr() for
        # reading the rate -- revisit once the target PyTorch version is known.
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        # Overwrite the same checkpoint every epoch.
        utils.save(model, os.path.join(args.save, 'weights.pt'))
예제 #11
0
# # mean and standard deviation of channels of CIFAR-10 images
# mean = [x / 255 for x in [125.3, 123.0, 113.9]]
# std = [x / 255 for x in [63.0, 62.1, 66.7]]

# Training-time preprocessing: random horizontal flip, random 32x32 crop taken
# after 4 pixels of padding, then tensor conversion. No normalization step is
# part of this pipeline.
train_transform = trn.Compose([
    trn.RandomHorizontalFlip(),
    trn.RandomCrop(32, padding=4),
    trn.ToTensor()
])
# Evaluation-time preprocessing: tensor conversion only.
test_transform = trn.Compose([trn.ToTensor()])

if args.dataset == 'cifar10':
    # NOTE(review): the positional arguments follow the project's CIFAR10
    # signature, which is not visible here -- presumably
    # (root, train, gold, gold_fraction); confirm against the dataset class.
    train_data = CIFAR10(args.data_path,
                         True,
                         False,
                         0,
                         corruption_prob=args.corruption_prob,
                         corruption_type='unif',
                         transform=train_transform)
    if args.val_set_size > 0:
        # Same construction but with corruption_prob=0 and the evaluation
        # transform.
        train_data_clean = CIFAR10(args.data_path,
                                   True,
                                   False,
                                   0,
                                   corruption_prob=0,
                                   corruption_type='unif',
                                   transform=test_transform)
        # Draw `val_set_size` distinct indices from the training set.
        val_indices = np.random.choice(len(train_data),
                                       args.val_set_size,
                                       replace=False)
        train_indices = np.array(
예제 #12
0
def main():
    """Run the architecture search loop on CIFAR-10/100.

    Configuration is read from the module-level ``args``; ``device_ids`` and
    ``batchsz`` are likewise module-level names. The training set is split by
    ``args.train_portion`` into a training part and a validation part, and
    each epoch logs the current genotype, runs ``train`` and ``infer`` and
    saves the model to ``<args.exp_path>/search_epoch1.pt``. The loop stops
    early once ``args.time_limit`` seconds have elapsed since it started.

    Raises:
        ValueError: If ``args.loss_func`` is not ``'cce'`` or ``'rll'``, or if
            ``args.dataset`` is not ``'cifar10'`` or ``'cifar100'``.
    """
    np.random.seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)

    # Report total/used memory of every configured GPU. The pipe is closed
    # explicitly via the context manager instead of being leaked.
    for device_id in device_ids:
        with os.popen(
            'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
        ) as pipe:
            total, used = pipe.read().split('\n')[device_id].split(',')
        print('GPU ({}) mem:'.format(device_id), total, 'used:', used)

    args.unrolled = True

    logging.info('GPU device = %s', args.gpu)
    logging.info("args = %s", args)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        dataset_cls, num_classes = CIFAR10, 10
    elif args.dataset == 'cifar100':
        dataset_cls, num_classes = CIFAR100, 100
    else:
        # Previously an unknown dataset fell through and crashed later with
        # an UnboundLocalError on `train_data`.
        raise ValueError(
            "Invalid dataset '{}' given. Must be in {{'cifar10', 'cifar100'}}".format(
                args.dataset))

    shared_dataset_kwargs = dict(
        root=args.data, train=True,
        corruption_prob=args.corruption_prob, corruption_type=args.corruption_type,
        transform=train_transform, download=True, seed=args.seed)

    # `gold` is enabled exactly when a non-zero gold fraction is requested.
    train_data = dataset_cls(
        gold=args.gold_fraction != 0, gold_fraction=args.gold_fraction,
        **shared_dataset_kwargs)
    if args.gold_fraction == 0 and args.clean_valid:
        gold_train_data = dataset_cls(
            gold=True, gold_fraction=1.0, **shared_dataset_kwargs)
    # NOTE(review): with clean_valid set and gold_fraction != 0,
    # `gold_train_data` is never created and the validation loader below
    # fails -- confirm which dataset is intended in that configuration.

    # Split data to train and validation
    num_train = len(train_data)  # 50000
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))  # 45000

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batchsz,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)

    # The validation indices are the tail of the split, drawn either from the
    # gold_fraction=1.0 copy of the data or from the training data itself.
    valid_source = gold_train_data if args.clean_valid else train_data
    valid_queue = torch.utils.data.DataLoader(
        valid_source, batch_size=batchsz,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True, num_workers=2)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss().cuda()
    else:
        # Raise instead of `assert False` so the check survives `python -O`.
        # The literal braces are doubled: un-escaped, str.format() treated
        # {'cce', 'rll'} as a replacement field and raised KeyError instead
        # of reporting the invalid value.
        raise ValueError(
            "Invalid loss function '{}' given. Must be in {{'cce', 'rll'}}".format(
                args.loss_func))
    model = Network(args.init_ch, num_classes, args.layers, criterion)
    if len(device_ids) > 1:
        model = MyDataParallel(model).cuda()
    else:
        model.cuda()

    logging.info("Total param size = %f MB", utils.count_parameters_in_MB(model))

    # this is the optimizer to optimize
    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, float(args.epochs), eta_min=args.lr_min)

    arch = Arch(model, args)

    # `start` is published as a module-level global.
    global start
    start = time.time()
    for epoch in range(args.epochs):
        current_time = time.time()
        # Stop searching once the wall-clock budget is used up.
        if current_time - start >= args.time_limit:
            break
        # NOTE(review): the scheduler is stepped at the start of the epoch and
        # queried with get_lr(); recent PyTorch releases expect
        # scheduler.step() after the optimizer steps and get_last_lr() for
        # reading the rate -- revisit once the target PyTorch version is known.
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('\nEpoch: %d lr: %e', epoch, lr)

        genotype = model.genotype()
        logging.info('Genotype: %s', genotype)

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model, arch, criterion, optimizer, lr)
        logging.info('train acc: %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid acc: %f', valid_acc)

        utils.save(model, os.path.join(args.exp_path, 'search_epoch1.pt'))