Пример #1
0
    def setup(self, flags):
        """Seed the run, build the requested network on the GPU and log the flags.

        Reads ``flags.deterministic``, ``flags.seed``, ``flags.dataset``,
        ``flags.model`` and ``flags.logs``; the 'wrn' model additionally reads
        ``flags.layers``, ``flags.widen_factor`` and ``flags.droprate``.
        The built network is stored on ``self.network``.

        Raises:
            ValueError: if ``flags.model`` is not one of 'densenet', 'wrn',
                'allconv' or 'resnext'.
        """
        torch.backends.cudnn.deterministic = flags.deterministic
        print('torch.backends.cudnn.deterministic:', torch.backends.cudnn.deterministic)
        fix_all_seed(flags.seed)

        # Every dataset other than 'cifar10' is treated as a 100-class problem.
        num_classes = 10 if flags.dataset == 'cifar10' else 100

        if flags.model == 'densenet':
            self.network = densenet(num_classes=num_classes)
        elif flags.model == 'wrn':
            self.network = WideResNet(flags.layers, num_classes, flags.widen_factor, flags.droprate)
        elif flags.model == 'allconv':
            self.network = AllConvNet(num_classes)
        elif flags.model == 'resnext':
            self.network = resnext29(num_classes=num_classes)
        else:
            # ValueError is a subclass of Exception, so handlers that caught
            # the previous bare Exception keep working.
            raise ValueError('Unknown model.')
        self.network = self.network.cuda()

        print(self.network)
        print('flags:', flags)
        # exist_ok=True avoids the check-then-create race of exists() + makedirs().
        os.makedirs(flags.logs, exist_ok=True)

        flags_log = os.path.join(flags.logs, 'flags_log.txt')
        write_log(flags, flags_log)
Пример #2
0
    num_classes = 100

# Data loaders: only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.prefetch,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=args.test_bs,
                                          shuffle=False,
                                          num_workers=args.prefetch,
                                          pin_memory=True)

# Create model
if args.model == 'allconv':
    net = AllConvNet(num_classes).cuda()
else:
    # The WideResNet branch is disabled in this script, so 'allconv' is the
    # only model that gets built. Fail here with a clear message instead of
    # hitting a NameError on `net` when the optimizer is created below.
    raise ValueError('Unsupported model: {!r}'.format(args.model))

start_epoch = 0

cudnn.benchmark = True  # fire on all cylinders

# SGD with Nesterov momentum; hyper-parameters come from the `state` dict.
optimizer = torch.optim.SGD(net.parameters(),
                            state['learning_rate'],
                            momentum=state['momentum'],
                            weight_decay=state['decay'],
                            nesterov=True)


def cosine_annealing(step, total_steps, lr_max, lr_min):
Пример #3
0
num_classes = 1000

# Input pipelines: the training split is shuffled, the test split is read
# in a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)

# Create model: 'allconv' selects AllConvNet, any other value a WideResNet.
net = (AllConvNet(num_classes) if args.model == 'allconv'
       else WideResNet(args.layers, num_classes, args.widen_factor,
                       dropRate=args.droprate))

start_epoch = 0

# Restore model if desired
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(
            args.load, args.dataset + '_' + args.model + '_baseline_epoch_' +
            str(i) + '.pt')
        if os.path.isfile(model_name):
Пример #4
0
# Data loaders: only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)

# Create model. Some architectures override command-line hyper-parameters.
if args.model == 'densenet':
    args.decay = 0.0001
    args.epochs = 200
    net = densenet(num_classes=num_classes)
elif args.model == 'wrn':
    net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
elif args.model == 'allconv':
    net = AllConvNet(num_classes)
elif args.model == 'resnext':
    args.epochs = 200
    net = resnext29(num_classes=num_classes)
else:
    # Without this branch an unknown model leaves `net` unbound and the
    # script fails later with an unrelated-looking NameError.
    raise ValueError('Unknown model: {!r}'.format(args.model))

# Snapshot of the arguments, taken after the per-model overrides above so it
# contains the values that were actually assigned.
state = dict(args._get_kwargs())
print(state)

start_epoch = 0

# Restore model if desired
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args.load, args.dataset + '_' + args.model +
                                  '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
Пример #5
0
#     num_classes = 100

# Input pipelines: the training split is shuffled, the test split is not.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)

# Create model: 'allconv' selects AllConvNet, any other value a WideResNet.
net = (AllConvNet(num_classes) if args.model == 'allconv'
       else WideResNet(args.layers, num_classes, args.widen_factor,
                       dropRate=args.droprate))

# GPU placement: wrap in DataParallel first when several GPUs are requested,
# then move the model to CUDA and seed the CUDA RNG.
if args.ngpu > 0:
    if args.ngpu > 1:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    net.cuda()
    torch.cuda.manual_seed(1)

cudnn.benchmark = True  # fire on all cylinders
Пример #6
0
# Hold out a share of the training set for validation (val_share=0.1).
train_data, val_data = validation_split(train_data, val_share=0.1)

# Evaluation loaders: both are read in a fixed order with the test batch size.
val_loader = torch.utils.data.DataLoader(
    val_data, batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)

# Create model: the architecture is inferred from the method name.
net = (AllConvNet(num_classes) if 'allconv' in args.method_name
       else WideResNet(args.layers, num_classes, args.widen_factor,
                       dropRate=args.droprate))

start_epoch = 0

# Restore model
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        if 'baseline' in args.method_name:
            subdir = 'baseline'
        elif 'oe_tune' in args.method_name:
            subdir = 'oe_tune'
Пример #7
0
    # train_indices = np.array(list(set(np.arange(50000)) - set(test_indices)))
    # test_data.test_data = np.copy(train_data.train_data[test_indices])
    # train_data.train_data = np.copy(train_data.train_data[train_indices])
    # #


# Input pipelines: the training split is shuffled, the test split is not.
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.prefetch,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=args.test_bs,
                                          shuffle=False,
                                          num_workers=args.prefetch,
                                          pin_memory=True)

# Create model, sized by the class count given in
# args.num_classes_pretrained_net rather than by the current dataset.
net = (AllConvNet(args.num_classes_pretrained_net)
       if args.model == 'allconv'
       else WideResNet(args.layers,
                       args.num_classes_pretrained_net,
                       args.widen_factor,
                       dropRate=args.droprate))

# Wrap in DataParallel before any checkpoint is loaded into `net`.
net = nn.DataParallel(net)

model_found = False
# Restore model if desired
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args.load, 'imagenet_' + args.model +
                                  '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
Пример #8
0
# Select the test split; any dataset other than 'cifar10' is treated as CIFAR-100.
if args.dataset == 'cifar10':
    test_data = dset.CIFAR10('/data/sauravkadavath/cifar10-dataset/', train=False, transform=test_transform)
    num_classes = 10
else:
    # NOTE(review): CIFAR-100 is read from the same 'cifar10-dataset'
    # directory as CIFAR-10 -- confirm this is intended.
    test_data = dset.CIFAR100('/data/sauravkadavath/cifar10-dataset/', train=False, transform=test_transform)
    num_classes = 100


test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.batch_size, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)

# Create model
if args.model == 'allconv':
    net = AllConvNet(num_classes).cuda()
else:
    # Only 'allconv' is built by this script. Fail with a clear message
    # instead of a NameError on `net` when the checkpoint is loaded below.
    raise ValueError('Unsupported model: {!r}'.format(args.model))

# args.load is used directly as the checkpoint file path here.
net.load_state_dict(torch.load(args.load))

# PGD adversary: epsilon 8/255, 10 steps, step size 0.5/255.
adversary = attacks.PGD(epsilon=8./255, num_steps=10, step_size=0.5/255).cuda()

# def train():
#     net.train()  # enter train mode
#     loss_avg = 0.0
#     for bx, by in tqdm(train_loader):

#         bx, by, = bx.cuda(), by.cuda()

#         adv_bx = adversary(net, bx, by)

#         # print(torch.max(bx), torch.min(bx), torch.mean(bx))
Пример #9
0
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)

# Loader over ood_data, read unshuffled with its own batch size
# (args.oe_batch_size).
train_loader_out = torch.utils.data.DataLoader(ood_data,
                                               batch_size=args.oe_batch_size,
                                               shuffle=False,
                                               num_workers=args.prefetch,
                                               pin_memory=True)

# Test loader, read in a fixed order.
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.prefetch,
                                          pin_memory=True)

# Create model: 'allconv' selects AllConvNet, any other value a WideResNet.
net = (AllConvNet(num_classes) if args.model == 'allconv'
       else WideResNet(args.layers, num_classes, args.widen_factor,
                       dropRate=args.droprate))

start_epoch = 0

# Restore model if desired
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args.load, calib_indicator + args.model + '_oe_scratch_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break
    if start_epoch == 0:
Пример #10
0
class ModelBaseline(object):
    """Baseline trainer: builds a CIFAR classifier, trains it with SGD and a
    cosine learning-rate schedule, and evaluates it after every epoch on the
    corrupted test sets listed in ``CORRUPTIONS``.
    """

    def __init__(self, flags):
        """Run the three setup stages (network, data, optimizer) in order."""
        # Defined up front so test_workflow() can be called without train().
        self.best_accuracy_test = -1

        self.setup(flags)
        self.setup_path(flags)
        self.configure(flags)

    def setup(self, flags):
        """Seed the run, build the requested network on the GPU and log the flags.

        Reads ``flags.deterministic``, ``flags.seed``, ``flags.dataset``,
        ``flags.model`` and ``flags.logs``; the 'wrn' model additionally reads
        ``flags.layers``, ``flags.widen_factor`` and ``flags.droprate``.

        Raises:
            ValueError: if ``flags.model`` is not one of 'densenet', 'wrn',
                'allconv' or 'resnext'.
        """
        torch.backends.cudnn.deterministic = flags.deterministic
        print('torch.backends.cudnn.deterministic:', torch.backends.cudnn.deterministic)
        fix_all_seed(flags.seed)

        # Every dataset other than 'cifar10' is treated as a 100-class problem.
        num_classes = 10 if flags.dataset == 'cifar10' else 100

        if flags.model == 'densenet':
            self.network = densenet(num_classes=num_classes)
        elif flags.model == 'wrn':
            self.network = WideResNet(flags.layers, num_classes, flags.widen_factor, flags.droprate)
        elif flags.model == 'allconv':
            self.network = AllConvNet(num_classes)
        elif flags.model == 'resnext':
            self.network = resnext29(num_classes=num_classes)
        else:
            # ValueError is a subclass of Exception, so handlers that caught
            # the previous bare Exception keep working.
            raise ValueError('Unknown model.')
        self.network = self.network.cuda()

        print(self.network)
        print('flags:', flags)
        # exist_ok=True avoids the check-then-create race of exists() + makedirs().
        os.makedirs(flags.logs, exist_ok=True)

        flags_log = os.path.join(flags.logs, 'flags_log.txt')
        write_log(flags, flags_log)

    def setup_path(self, flags):
        """Create the transforms, the CIFAR train/test datasets and the train loader.

        Datasets are stored under the local 'data' folder and downloaded when
        missing (``download=True``). ``self.base_c_path`` points at the
        matching corrupted-data folder ("CIFAR-10-C" or "CIFAR-100-C") inside it.
        """
        root_folder = 'data'
        os.makedirs(flags.logs, exist_ok=True)

        # Shared tensor conversion + normalisation with mean/std 0.5 per channel.
        self.preprocess = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize([0.5] * 3, [0.5] * 3)])
        # Training adds random flip and padded random crop before preprocessing.
        self.train_transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(),
             transforms.RandomCrop(32, padding=4),
             self.preprocess])
        self.test_transform = self.preprocess

        if flags.dataset == 'cifar10':
            self.train_data = datasets.CIFAR10(root_folder, train=True, transform=self.train_transform, download=True)
            self.test_data = datasets.CIFAR10(root_folder, train=False, transform=self.test_transform, download=True)
            self.base_c_path = os.path.join(root_folder, "CIFAR-10-C")
        else:
            self.train_data = datasets.CIFAR100(root_folder, train=True, transform=self.train_transform, download=True)
            self.test_data = datasets.CIFAR100(root_folder, train=False, transform=self.test_transform, download=True)
            self.base_c_path = os.path.join(root_folder, "CIFAR-100-C")

        self.train_loader = torch.utils.data.DataLoader(
            self.train_data,
            batch_size=flags.batch_size,
            shuffle=True,
            num_workers=flags.num_workers,
            pin_memory=True)

    def configure(self, flags):
        """Print the parameter shapes and create optimizer, scheduler and loss."""
        for name, param in self.network.named_parameters():
            print(name, param.size())

        self.optimizer = torch.optim.SGD(
            self.network.parameters(),
            flags.lr,
            momentum=flags.momentum,
            weight_decay=flags.weight_decay,
            nesterov=True)

        # The scheduler is stepped once per batch in train(), so its period is
        # the total number of iterations rather than the number of epochs.
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, len(self.train_loader) * flags.epochs)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def train(self, flags):
        """Train for ``flags.epochs`` epochs, evaluating after each one.

        Every iteration appends the loss to 'loss_log.txt' in ``flags.logs``.
        Progress is printed for the first five epochs and every fifth epoch
        after that.
        """
        self.network.train()
        self.best_accuracy_test = -1

        # The loss log path does not change during training.
        loss_log = os.path.join(flags.logs, 'loss_log.txt')

        for epoch in range(flags.epochs):
            verbose = epoch < 5 or epoch % 5 == 0

            for i, (images_train, labels_train) in enumerate(self.train_loader):
                inputs, labels = images_train.cuda(), labels_train.cuda()

                # The network returns a pair; only the first element is used.
                outputs, _ = self.network(x=inputs)
                loss = self.loss_fn(outputs, labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                # Read the scalar once; each .item() call syncs with the GPU.
                loss_value = loss.item()

                if verbose:
                    # Read the learning rate from the optimizer: calling
                    # scheduler.get_lr() outside step() is deprecated and
                    # does not reliably return the value in use.
                    print(
                        'epoch:', epoch, 'ite', i, 'total loss:', loss_value, 'lr:',
                        self.optimizer.param_groups[0]['lr'])

                write_log(str(loss_value), loss_log)

            self.test_workflow(epoch, flags)

    def test_workflow(self, epoch, flags):
        """Evaluate network on given corrupted dataset.

        Runs ``test`` once per entry of ``CORRUPTIONS`` and averages the
        accuracies. When the mean beats ``self.best_accuracy_test`` it is
        recorded in 'best_test.txt' and the network weights are saved to
        'best_model.tar' under ``flags.model_path``.
        """
        accuracies = []
        for count, corruption in enumerate(CORRUPTIONS):
            # Reference to original data is mutated
            self.test_data.data = np.load(os.path.join(self.base_c_path, corruption + '.npy'))
            self.test_data.targets = torch.LongTensor(np.load(os.path.join(self.base_c_path, 'labels.npy')))

            test_loader = torch.utils.data.DataLoader(
                self.test_data,
                batch_size=flags.batch_size,
                shuffle=False,
                num_workers=flags.num_workers,
                pin_memory=True)

            accuracy_test = self.test(test_loader, epoch, log_dir=flags.logs, log_prefix='test_index_{}'.format(count))
            accuracies.append(accuracy_test)

        mean_acc = np.mean(accuracies)

        if mean_acc > self.best_accuracy_test:
            self.best_accuracy_test = mean_acc

            # `with` closes the file even if the write raises.
            with open(os.path.join(flags.logs, 'best_test.txt'), mode='a') as f:
                f.write('epoch:{}, best test accuracy:{}\n'.format(epoch, self.best_accuracy_test))

            os.makedirs(flags.model_path, exist_ok=True)

            outfile = os.path.join(flags.model_path, 'best_model.tar')
            torch.save({'epoch': epoch, 'state': self.network.state_dict()}, outfile)

    def test(self, test_loader, epoch, log_prefix, log_dir='logs/'):
        """Compute accuracy over ``test_loader`` and append it to a log file.

        Args:
            test_loader: iterable of ``(images, targets)`` batches that also
                exposes ``.dataset`` (its length is the accuracy denominator).
            epoch: epoch number written to the log line.
            log_prefix: log file name without the '.txt' extension.
            log_dir: directory for the log file; created when missing.

        Returns:
            Fraction of correctly classified samples.
        """
        # switch on the network test mode
        self.network.eval()

        total_correct = 0
        with torch.no_grad():
            for images, targets in test_loader:
                images, targets = images.cuda(), targets.cuda()
                logits, _ = self.network(images)
                pred = logits.max(1)[1]
                total_correct += pred.eq(targets).sum().item()

        accuracy = total_correct / len(test_loader.dataset)
        print('----------accuracy test----------:', accuracy)

        os.makedirs(log_dir, exist_ok=True)

        with open(os.path.join(log_dir, '{}.txt'.format(log_prefix)), mode='a') as f:
            f.write('epoch:{}, accuracy:{}\n'.format(epoch, accuracy))

        # switch on the network train mode
        self.network.train()

        return accuracy