Example #1
                                               shuffle=True,
                                               num_workers=args.prefetch,
                                               pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=args.test_bs,
                                          shuffle=False,
                                          num_workers=args.prefetch,
                                          pin_memory=True)

# Create the checkpoint directory
if not os.path.isdir(args.save):
    os.makedirs(args.save)

# Init model, criterion, and optimizer
net = wrn.WideResNet(args.layers,
                     num_classes,
                     args.widen_factor,
                     dropRate=args.droprate)
print(net)

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

if args.ngpu > 0:
    net.cuda()

torch.manual_seed(1)
if args.ngpu > 0:
    torch.cuda.manual_seed(1)

optimizer = torch.optim.SGD(net.parameters(),
                            state['learning_rate'],

Example #2
def main():
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard: configure("runs/%s" % (args.name))

    # Data loading code
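    # per-channel CIFAR mean/std, given on the 0-255 scale and rescaled to [0, 1]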
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=False,
                                                transform=transform_test),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    # create model
    model = wrn.WideResNet(args.layers,
                           10 if args.dataset == 'cifar10' else 100,
                           args.widen_factor,
                           dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # let cuDNN benchmark and cache the fastest convolution algorithms (fixed 32x32 inputs)
    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
    print('Best accuracy: ', best_prec1)
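
# The adjust_learning_rate helper called in the loop above is not shown in this
# example. A minimal sketch of a typical WideResNet step schedule, assuming decay
# by a factor of 5 at epochs 60, 120 and 160 (an assumption, not the author's code):
def adjust_learning_rate(optimizer, epoch):
    lr = args.lr * (0.2 ** int(epoch >= 60)) * (0.2 ** int(epoch >= 120)) * (0.2 ** int(epoch >= 160))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr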
Example #3
    def __init__(self, root='', train=True, meta=True, num_meta=1000,
                 corruption_prob=0, corruption_type='unif', transform=None, target_transform=None,
                 download=False, seed=1):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.meta = meta
        self.corruption_prob = corruption_prob
        self.num_meta = num_meta

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # now load the picked numpy arrays
        if self.train:
            self.train_data = []
            self.train_labels = []
            self.train_coarse_labels = []
            for fentry in self.train_list:
                f = fentry[0]
                file = os.path.join(root, self.base_folder, f)
                fo = open(file, 'rb')
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    entry = pickle.load(fo, encoding='latin1')
                self.train_data.append(entry['data'])
                if 'labels' in entry:
                    self.train_labels += entry['labels']
                    img_num_list = [int(self.num_meta/10)] * 10
                    num_classes = 10
                else:
                    self.train_labels += entry['fine_labels']
                    self.train_coarse_labels += entry['coarse_labels']
                    img_num_list = [int(self.num_meta/100)] * 100
                    num_classes = 100
                fo.close()

            self.train_data = np.concatenate(self.train_data)
            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
            self.train_data = self.train_data.transpose((0, 2, 3, 1))   # convert to HWC

            data_list_val = {}
            for j in range(num_classes):
                data_list_val[j] = [i for i, label in enumerate(self.train_labels) if label == j]


            idx_to_meta = []
            idx_to_train = []
            print(img_num_list)

            for cls_idx, img_id_list in data_list_val.items():
                np.random.shuffle(img_id_list)
                img_num = img_num_list[int(cls_idx)]
                idx_to_meta.extend(img_id_list[:img_num])
                idx_to_train.extend(img_id_list[img_num:])
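            # idx_to_meta: a small, class-balanced subset kept clean as the meta set;
            # idx_to_train: the remaining images, whose labels may be corrupted below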


            if meta is True:
                self.train_data = self.train_data[idx_to_meta]
                self.train_labels = list(np.array(self.train_labels)[idx_to_meta])
            else:
                self.train_data = self.train_data[idx_to_train]
                self.train_labels = list(np.array(self.train_labels)[idx_to_train])
                if corruption_type == 'hierarchical':
                    self.train_coarse_labels = list(np.array(self.train_coarse_labels)[idx_to_meta])

                if corruption_type == 'unif':
                    C = uniform_mix_C(self.corruption_prob, num_classes)
                    print(C)
                    self.C = C
                elif corruption_type == 'flip':
                    C = flip_labels_C(self.corruption_prob, num_classes)
                    print(C)
                    self.C = C
                elif corruption_type == 'flip2':
                    C = flip_labels_C_two(self.corruption_prob, num_classes)
                    print(C)
                    self.C = C
                elif corruption_type == 'hierarchical':
                    assert num_classes == 100, 'You must use CIFAR-100 with the hierarchical corruption.'
                    coarse_fine = []
                    for i in range(20):
                        coarse_fine.append(set())
                    for i in range(len(self.train_labels)):
                        coarse_fine[self.train_coarse_labels[i]].add(self.train_labels[i])
                    for i in range(20):
                        coarse_fine[i] = list(coarse_fine[i])

                    # flip each label, with probability corruption_prob, uniformly to
                    # another fine class within the same coarse (super)class
                    C = np.eye(num_classes) * (1 - corruption_prob)

                    for i in range(20):
                        tmp = np.copy(coarse_fine[i])
                        for j in range(len(tmp)):
                            tmp2 = np.delete(np.copy(tmp), j)
                            C[tmp[j], tmp2] += corruption_prob * 1/len(tmp2)
                    self.C = C
                    print(C)
                elif corruption_type == 'clabels':
                    net = wrn.WideResNet(40, num_classes, 2, dropRate=0.3).cuda()
                    model_name = './cifar{}_labeler'.format(num_classes)
                    net.load_state_dict(torch.load(model_name))
                    net.eval()
                else:
                    assert False, "Invalid corruption type '{}' given. Must be one of 'unif', 'flip', 'flip2', 'hierarchical' or 'clabels'".format(corruption_type)

                np.random.seed(seed)
                if corruption_type == 'clabels':
                    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
                    std = [x / 255 for x in [63.0, 62.1, 66.7]]

                    test_transform = transforms.Compose(
                        [transforms.ToTensor(), transforms.Normalize(mean, std)])

                    # obtain sampling probabilities
                    sampling_probs = []
                    print('Starting labeling')

                    for i in range((len(self.train_labels) // 64) + 1):
                        current = self.train_data[i*64:(i+1)*64]
                        current = [Image.fromarray(current[i]) for i in range(len(current))]
                        current = torch.cat([test_transform(current[i]).unsqueeze(0) for i in range(len(current))], dim=0)

                        data = V(current).cuda()
                        logits = net(data)
                        smax = F.softmax(logits / 5, dim=1)  # softmax with temperature 5
                        sampling_probs.append(smax.data.cpu().numpy())


                    sampling_probs = np.concatenate(sampling_probs, 0)
                    print('Finished labeling 1')

                    new_labeling_correct = 0
                    argmax_labeling_correct = 0
                    for i in range(len(self.train_labels)):
                        old_label = self.train_labels[i]
                        new_label = np.random.choice(num_classes, p=sampling_probs[i])
                        self.train_labels[i] = new_label
                        if old_label == new_label:
                            new_labeling_correct += 1
                        if old_label == np.argmax(sampling_probs[i]):
                            argmax_labeling_correct += 1
                    print('Finished labeling 2')
                    print('New labeling accuracy:', new_labeling_correct / len(self.train_labels))
                    print('Argmax labeling accuracy:', argmax_labeling_correct / len(self.train_labels))
                else:    
                    for i in range(len(self.train_labels)):
                        self.train_labels[i] = np.random.choice(num_classes, p=C[self.train_labels[i]])
                    self.corruption_matrix = C

        else:
            f = self.test_list[0][0]
            file = os.path.join(root, self.base_folder, f)
            fo = open(file, 'rb')
            if sys.version_info[0] == 2:
                entry = pickle.load(fo)
            else:
                entry = pickle.load(fo, encoding='latin1')
            self.test_data = entry['data']
            if 'labels' in entry:
                self.test_labels = entry['labels']
            else:
                self.test_labels = entry['fine_labels']
            fo.close()
            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
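
# Usage sketch for the constructor above (the concrete class name CIFAR10 and the
# noise settings here are assumptions for illustration): build a small clean "meta"
# split and a label-corrupted training split from the same CIFAR-10 files.
meta_set = CIFAR10(root='../data', train=True, meta=True, num_meta=1000,
                   corruption_prob=0.4, corruption_type='unif', download=True)
train_set = CIFAR10(root='../data', train=True, meta=False, num_meta=1000,
                    corruption_prob=0.4, corruption_type='unif', download=True)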

Example #4
def main():
    global args, best_prec1, best_test_prec1
    global acc1_tr, losses_tr
    global losses_cl_tr
    global acc1_val, losses_val, losses_et_val
    global acc1_test, losses_test, losses_et_test
    global weights_cl
    args = parser.parse_args()
    print(args)
    # same WideResNet hyperparameters for SVHN and CIFAR
    drop_rate = 0.3
    widen_factor = 3

    # create model
    if args.arch == 'preresnet':
        print("Model: %s" % args.arch)
        model = preresnet_cifar.resnet(depth=32, num_classes=args.num_classes)
    elif args.arch == 'wideresnet':
        print("Model: %s" % args.arch)
        model = wideresnet.WideResNet(28,
                                      args.num_classes,
                                      widen_factor=widen_factor,
                                      dropRate=drop_rate,
                                      leakyRate=0.1)
    else:
        assert (False)

    if args.model == 'mt':
        import copy
        model_teacher = copy.deepcopy(model)
        model_teacher = torch.nn.DataParallel(model_teacher).cuda()

    model = torch.nn.DataParallel(model).cuda()
    print(model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']

            model.load_state_dict(checkpoint['state_dict'])
            if args.model == 'mt':
                model_teacher.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.optim not in ('sgd', 'adam'):
        print('Not Implemented Optimizer')
        assert (False)

    ckpt_dir = args.ckpt + '_' + args.dataset + '_' + args.arch + '_' + args.model + '_' + args.optim
    ckpt_dir = ckpt_dir + '_e%d' % (args.epochs)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    print(ckpt_dir)
    cudnn.benchmark = True

    # Data loading code
    if args.dataset == 'cifar10':
        dataloader = cifar.CIFAR10
        num_classes = 10
        data_dir = '/tmp/'

        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    elif args.dataset == 'cifar10_zca':
        dataloader = cifar_zca.CIFAR10
        num_classes = 10
        data_dir = 'cifar10_zca/cifar10_gcn_zca_v2.npz'

        # transform is implemented inside zca dataloader
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
        ])

    elif args.dataset == 'svhn':
        dataloader = svhn.SVHN
        num_classes = 10
        data_dir = '/tmp/'

        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                         std=[0.5, 0.5, 0.5])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=2),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    labelset = dataloader(root=data_dir,
                          split='label',
                          download=True,
                          transform=transform_train,
                          boundary=args.boundary)
    unlabelset = dataloader(root=data_dir,
                            split='unlabel',
                            download=True,
                            transform=transform_train,
                            boundary=args.boundary)
    batch_size_label = args.batch_size // 2
    batch_size_unlabel = args.batch_size // 2
    if args.model == 'baseline': batch_size_label = args.batch_size

    label_loader = data.DataLoader(labelset,
                                   batch_size=batch_size_label,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)
    label_iter = iter(label_loader)

    unlabel_loader = data.DataLoader(unlabelset,
                                     batch_size=batch_size_unlabel,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True)
    unlabel_iter = iter(unlabel_loader)

    print("Batch size (label): ", batch_size_label)
    print("Batch size (unlabel): ", batch_size_unlabel)

    validset = dataloader(root=data_dir,
                          split='valid',
                          download=True,
                          transform=transform_test,
                          boundary=args.boundary)
    val_loader = data.DataLoader(validset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)

    testset = dataloader(root=data_dir,
                         split='test',
                         download=True,
                         transform=transform_test)
    test_loader = data.DataLoader(testset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers,
                                  pin_memory=True)

    # define loss functions (criterion) and optimizer;
    # size_average=False sums per-sample losses (the pre-1.0 spelling of reduction='sum')
    criterion = nn.CrossEntropyLoss(size_average=False).cuda()
    criterion_mse = nn.MSELoss(size_average=False).cuda()
    criterion_kl = nn.KLDivLoss(size_average=False).cuda()
    criterion_l1 = nn.L1Loss(size_average=False).cuda()

    criterions = (criterion, criterion_mse, criterion_kl, criterion_l1)

    if args.optim == 'adam':
        print('Using Adam optimizer')
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     betas=(0.9, 0.999),
                                     weight_decay=args.weight_decay)
    elif args.optim == 'sgd':
        print('Using SGD optimizer')
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        if args.optim == 'adam':
            print('Learning rate schedule for Adam')
            lr = adjust_learning_rate_adam(optimizer, epoch)
        elif args.optim == 'sgd':
            print('Learning rate schedule for SGD')
            lr = adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        if args.model == 'baseline':
            print('Supervised Training')
            # the supervised baseline repeats the pass 10 times per epoch,
            # since the labeled training set is small
            for i in range(10):
                prec1_tr, loss_tr = train_sup(label_loader, model, criterions,
                                              optimizer, epoch, args)
                weight_cl = 0.0
        elif args.model == 'pi':
            print('Pi model')
            prec1_tr, loss_tr, loss_cl_tr, weight_cl = train_pi(
                label_loader, unlabel_loader, model, criterions, optimizer,
                epoch, args)
        elif args.model == 'mt':
            print('Mean Teacher model')
            prec1_tr, loss_tr, loss_cl_tr, prec1_t_tr, weight_cl = train_mt(
                label_loader, unlabel_loader, model, model_teacher, criterions,
                optimizer, epoch, args)
        else:
            print("Not Implemented ", args.model)
            assert (False)

        # evaluate on validation set
        prec1_val, loss_val = validate(val_loader, model, criterions, args,
                                       'valid')
        prec1_test, loss_test = validate(test_loader, model, criterions, args,
                                         'test')
        if args.model == 'mt':
            prec1_t_val, loss_t_val = validate(val_loader, model_teacher,
                                               criterions, args, 'valid')
            prec1_t_test, loss_t_test = validate(test_loader, model_teacher,
                                                 criterions, args, 'test')

        # append values
        acc1_tr.append(prec1_tr)
        losses_tr.append(loss_tr)
        acc1_val.append(prec1_val)
        losses_val.append(loss_val)
        acc1_test.append(prec1_test)
        losses_test.append(loss_test)
        if args.model != 'baseline':
            losses_cl_tr.append(loss_cl_tr)
        if args.model == 'mt':
            acc1_t_tr.append(prec1_t_tr)
            acc1_t_val.append(prec1_t_val)
            acc1_t_test.append(prec1_t_test)
        weights_cl.append(weight_cl)
        learning_rate.append(lr)

        # remember best prec@1 and save checkpoint
        if args.model == 'mt':
            is_best = prec1_t_val > best_prec1
            if is_best:
                best_test_prec1_t = prec1_t_test
                best_test_prec1 = prec1_test
            print("Best test precision: %.3f" % best_test_prec1_t)
            best_prec1 = max(prec1_t_val, best_prec1)
            dict_checkpoint = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'best_test_prec1': best_test_prec1,
                'acc1_tr': acc1_tr,
                'losses_tr': losses_tr,
                'losses_cl_tr': losses_cl_tr,
                'acc1_val': acc1_val,
                'losses_val': losses_val,
                'acc1_test': acc1_test,
                'losses_test': losses_test,
                'acc1_t_tr': acc1_t_tr,
                'acc1_t_val': acc1_t_val,
                'acc1_t_test': acc1_t_test,
                'state_dict_teacher': model_teacher.state_dict(),
                'best_test_prec1_t': best_test_prec1_t,
                'weights_cl': weights_cl,
                'learning_rate': learning_rate,
            }

        else:
            is_best = prec1_val > best_prec1
            if is_best:
                best_test_prec1 = prec1_test
            print("Best test precision: %.3f" % best_test_prec1)
            best_prec1 = max(prec1_val, best_prec1)
            dict_checkpoint = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'best_test_prec1': best_test_prec1,
                'acc1_tr': acc1_tr,
                'losses_tr': losses_tr,
                'losses_cl_tr': losses_cl_tr,
                'acc1_val': acc1_val,
                'losses_val': losses_val,
                'acc1_test': acc1_test,
                'losses_test': losses_test,
                'weights_cl': weights_cl,
                'learning_rate': learning_rate,
            }

        save_checkpoint(dict_checkpoint,
                        is_best,
                        args.arch.lower() + str(args.boundary),
                        dirname=ckpt_dir)
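
# The consistency weight weight_cl returned by train_pi / train_mt is not computed in
# this example. Pi-model / Mean Teacher setups usually ramp it up with a sigmoid
# schedule over the first epochs; a minimal sketch under that assumption
# (max_weight and rampup_length are illustrative values, not the author's flags):
import math

def consistency_weight(epoch, max_weight=10.0, rampup_length=80):
    phase = 1.0 - min(max(epoch, 0), rampup_length) / rampup_length
    return max_weight * math.exp(-5.0 * phase * phase)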