subdir = 'oe_tune'
        else:
            subdir = 'oe_scratch'

        model_name = os.path.join(
            os.path.join(args.load, subdir),
            args.method_name + '_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break
    if start_epoch == 0:
        assert False, "could not resume"

# Put the network in evaluation mode before any scoring is done.
net.eval()

# With more than one GPU requested, wrap the model so batches are split
# across devices 0 .. ngpu-1.
if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

# With at least one GPU requested, move the model parameters to CUDA.
if args.ngpu > 0:
    net.cuda()
    # torch.cuda.manual_seed(1)

cudnn.benchmark = True  # fire on all cylinders

# /////////////// Calibration Prelims ///////////////

# Number of out-of-distribution examples: one fifth of the test-set size.
ood_num_examples = len(test_data) // 5
# Fraction of OOD examples in the combined (OOD + test) pool.
# NOTE(review): presumably the average precision a chance-level detector
# would score -- confirm against where expected_ap is consumed.
expected_ap = ood_num_examples / (ood_num_examples + len(test_data))
Example #2
0
class ModelBaseline(object):
    """Baseline CIFAR trainer evaluated on the corrupted CIFAR-C test sets.

    Construction builds the network, the datasets/loaders and the optimizer
    from ``flags``.  ``train`` then runs SGD with a per-iteration cosine
    learning-rate schedule and, after every epoch, evaluates the network on
    every corruption listed in the module-level ``CORRUPTIONS`` and keeps the
    checkpoint with the best mean accuracy.

    Relies on helpers defined elsewhere in the file/project
    (``fix_all_seed``, ``write_log``, ``CORRUPTIONS`` and the model
    constructors).
    """

    def __init__(self, flags):
        """Build network, data pipeline and optimizer from ``flags``."""
        self.setup(flags)
        self.setup_path(flags)
        self.configure(flags)

    def setup(self, flags):
        """Seed everything, instantiate the network on GPU and log the flags.

        Raises:
            Exception: if ``flags.model`` is not one of ``'densenet'``,
                ``'wrn'``, ``'allconv'`` or ``'resnext'``.
        """
        torch.backends.cudnn.deterministic = flags.deterministic
        print('torch.backends.cudnn.deterministic:', torch.backends.cudnn.deterministic)
        fix_all_seed(flags.seed)

        # Any dataset name other than 'cifar10' is treated as 100-class.
        if flags.dataset == 'cifar10':
            num_classes = 10
        else:
            num_classes = 100

        if flags.model == 'densenet':
            self.network = densenet(num_classes=num_classes)
        elif flags.model == 'wrn':
            self.network = WideResNet(flags.layers, num_classes, flags.widen_factor, flags.droprate)
        elif flags.model == 'allconv':
            self.network = AllConvNet(num_classes)
        elif flags.model == 'resnext':
            self.network = resnext29(num_classes=num_classes)
        else:
            raise Exception('Unknown model.')
        self.network = self.network.cuda()

        print(self.network)
        print('flags:', flags)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(flags.logs, exist_ok=True)

        flags_log = os.path.join(flags.logs, 'flags_log.txt')
        write_log(flags, flags_log)

    def setup_path(self, flags):
        """Create the transforms, the train/test datasets and the train loader.

        Datasets are stored under (and downloaded to) the ``data`` folder;
        ``self.base_c_path`` points at the matching CIFAR-C directory inside
        that folder.
        """
        root_folder = 'data'
        os.makedirs(flags.logs, exist_ok=True)

        self.preprocess = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize([0.5] * 3, [0.5] * 3)])
        self.train_transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(),
             transforms.RandomCrop(32, padding=4),
             self.preprocess])
        # No augmentation at test time, only tensor conversion + normalization.
        self.test_transform = self.preprocess

        if flags.dataset == 'cifar10':
            self.train_data = datasets.CIFAR10(root_folder, train=True, transform=self.train_transform, download=True)
            self.test_data = datasets.CIFAR10(root_folder, train=False, transform=self.test_transform, download=True)
            self.base_c_path = os.path.join(root_folder, "CIFAR-10-C")
        else:
            self.train_data = datasets.CIFAR100(root_folder, train=True, transform=self.train_transform, download=True)
            self.test_data = datasets.CIFAR100(root_folder, train=False, transform=self.test_transform, download=True)
            self.base_c_path = os.path.join(root_folder, "CIFAR-100-C")

        self.train_loader = torch.utils.data.DataLoader(
            self.train_data,
            batch_size=flags.batch_size,
            shuffle=True,
            num_workers=flags.num_workers,
            pin_memory=True)

    def configure(self, flags):
        """Create the SGD optimizer, the cosine LR scheduler and the loss."""
        for name, param in self.network.named_parameters():
            print(name, param.size())

        self.optimizer = torch.optim.SGD(
            self.network.parameters(),
            flags.lr,
            momentum=flags.momentum,
            weight_decay=flags.weight_decay,
            nesterov=True)

        # The scheduler is stepped once per iteration (see ``train``), so its
        # period is the total number of iterations, not the number of epochs.
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, len(self.train_loader) * flags.epochs)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def train(self, flags):
        """Train for ``flags.epochs`` epochs, evaluating after each one.

        The loss of every iteration is appended to ``loss_log.txt`` in
        ``flags.logs``; progress is printed for the first five epochs and
        every fifth epoch afterwards.
        """
        self.network.train()
        self.best_accuracy_test = -1

        # The log path does not change during training; build it once.
        loss_log = os.path.join(flags.logs, 'loss_log.txt')

        for epoch in range(0, flags.epochs):
            for i, (images_train, labels_train) in enumerate(self.train_loader):
                inputs, labels = images_train.cuda(), labels_train.cuda()

                # The network returns a pair; only the first element (the
                # logits) is used for the loss.
                outputs, _ = self.network(x=inputs)
                loss = self.loss_fn(outputs, labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                # One host sync for both the console print and the log file.
                loss_value = loss.item()

                if epoch < 5 or epoch % 5 == 0:
                    # Read the LR the scheduler just wrote into the optimizer.
                    # scheduler.get_lr() is not meant to be called outside
                    # step() and can report a wrong value there.
                    current_lr = self.optimizer.param_groups[0]['lr']
                    print(
                        'epoch:', epoch, 'ite', i, 'total loss:', loss_value, 'lr:',
                        current_lr)

                write_log(str(loss_value), loss_log)

            self.test_workflow(epoch, flags)

    def test_workflow(self, epoch, flags):
        """Evaluate on every corruption and checkpoint the best mean accuracy.

        For each name in ``CORRUPTIONS`` the images in ``self.test_data`` are
        replaced by the corresponding ``.npy`` file from ``self.base_c_path``
        and ``test`` is run on them.  When the mean accuracy over all
        corruptions beats the best seen so far, it is appended to
        ``best_test.txt`` and the network weights are saved to
        ``best_model.tar`` in ``flags.model_path``.
        """
        # All corruptions share one label file, so load it once.
        corrupted_labels = torch.LongTensor(np.load(os.path.join(self.base_c_path, 'labels.npy')))

        accuracies = []
        for count, corruption in enumerate(CORRUPTIONS):
            # Reference to original data is mutated: after this method
            # self.test_data holds the last corruption, not the clean set.
            self.test_data.data = np.load(os.path.join(self.base_c_path, corruption + '.npy'))
            self.test_data.targets = corrupted_labels

            test_loader = torch.utils.data.DataLoader(
                self.test_data,
                batch_size=flags.batch_size,
                shuffle=False,
                num_workers=flags.num_workers,
                pin_memory=True)

            accuracy_test = self.test(test_loader, epoch, log_dir=flags.logs, log_prefix='test_index_{}'.format(count))
            accuracies.append(accuracy_test)

        mean_acc = np.mean(accuracies)

        if mean_acc > self.best_accuracy_test:
            self.best_accuracy_test = mean_acc

            # ``with`` guarantees the handle is closed even if write() raises.
            with open(os.path.join(flags.logs, 'best_test.txt'), mode='a') as f:
                f.write('epoch:{}, best test accuracy:{}\n'.format(epoch, self.best_accuracy_test))

            os.makedirs(flags.model_path, exist_ok=True)

            outfile = os.path.join(flags.model_path, 'best_model.tar')
            torch.save({'epoch': epoch, 'state': self.network.state_dict()}, outfile)

    def test(self, test_loader, epoch, log_prefix, log_dir='logs/'):
        """Compute top-1 accuracy on ``test_loader`` and append it to a log.

        Args:
            test_loader: iterable of ``(images, targets)`` batches exposing a
                ``dataset`` attribute whose length is the number of examples.
            epoch: epoch number written to the log line.
            log_prefix: base name (without ``.txt``) of the log file.
            log_dir: directory of the log file; created if missing.

        Returns:
            Fraction of correctly classified examples.

        The network is switched to eval mode for the evaluation and is left
        in train mode on return.
        """
        self.network.eval()

        total_correct = 0
        with torch.no_grad():
            for images, targets in test_loader:
                images, targets = images.cuda(), targets.cuda()
                logits, _ = self.network(images)
                # Index of the largest logit along the class dimension.
                pred = logits.max(1)[1]
                total_correct += pred.eq(targets).sum().item()

        accuracy = total_correct / len(test_loader.dataset)
        print('----------accuracy test----------:', accuracy)

        os.makedirs(log_dir, exist_ok=True)

        with open(os.path.join(log_dir, '{}.txt'.format(log_prefix)), mode='a') as f:
            f.write('epoch:{}, accuracy:{}\n'.format(epoch, accuracy))

        # Training continues after evaluation, so restore train mode.
        self.network.train()

        return accuracy