def get_classifier(classes, d, pretrained):
    """Construct and return a classifier network selected by key `d`.

    Args:
        classes: number of output classes for the final layer.
        d: architecture key; one of "vgg", "wideresnet", "wideresnet2",
           "densenet", "resnet", "resnet50", "resnet18", "conv".
        pretrained: load ImageNet-pretrained weights (only honored by the
            torchvision architectures: densenet/resnet variants).

    Returns:
        The constructed model (an nn.Module).

    Raises:
        ValueError: if `d` is not a recognized architecture key.
    """
    # NOTE(review): `global` kept for backward compatibility — callers may
    # read the module-level `cnn_` after this function runs; confirm before
    # removing.
    global cnn_
    # False => fine-tune all layers; set_parameter_requires_grad is a no-op.
    feature_extracting = False

    def set_parameter_requires_grad(model, feature_extracting):
        # Freeze all parameters when feature-extracting so only the newly
        # attached head (replaced below) is trained.
        if feature_extracting:
            for param in model.parameters():
                param.requires_grad = False

    if d == "vgg":
        cnn_ = vgg19_bn(num_classes=classes)
    elif d == "wideresnet":
        cnn_ = WideResNet(depth=28,
                          num_classes=classes,
                          widen_factor=10,
                          dropRate=0.3)
    elif d == "wideresnet2":
        cnn_ = WideResNet2(num_classes=classes)
    elif d == 'densenet':
        cnn_ = torchvision.models.densenet121(pretrained=pretrained)
        # Replace the classification head to match `classes`.
        num_ftrs = cnn_.classifier.in_features
        cnn_.classifier = nn.Linear(num_ftrs, classes)
    elif d == "resnet":
        cnn_ = torchvision.models.resnet101(pretrained=pretrained)
        set_parameter_requires_grad(cnn_, feature_extracting)
        num_ftrs = cnn_.fc.in_features
        cnn_.fc = nn.Linear(num_ftrs, classes)
    elif d == "resnet50":
        cnn_ = torchvision.models.resnet50(pretrained=pretrained)
        set_parameter_requires_grad(cnn_, feature_extracting)
        num_ftrs = cnn_.fc.in_features
        cnn_.fc = nn.Linear(num_ftrs, classes)
    elif d == "resnet18":
        cnn_ = torchvision.models.resnet18(pretrained=pretrained)
        set_parameter_requires_grad(cnn_, feature_extracting)
        num_ftrs = cnn_.fc.in_features
        cnn_.fc = nn.Linear(num_ftrs, classes)
    elif d == "conv":
        cnn_ = cnn(num_classes=classes, im_size=32)  # Pixel space
        cnn_.apply(weights_init)
    else:
        # Previously an unknown key fell through to `return cnn_`, raising a
        # confusing NameError or silently returning a stale global model.
        raise ValueError("Unknown classifier architecture: {!r}".format(d))
    return cnn_
# Example #2
# 0
def main(args):
    """Train and evaluate a WideResNet on CIFAR-10.

    Optionally resumes from a checkpoint (`args.resume`), builds the data
    loaders, then either runs test-only mode or the full train/test cycle.

    Args:
        args: parsed options; fields used here include resume, test_only,
            depth, widen_factor, dropout_rate, num_classes, gpu, augment,
            batch_size, start_epoch.

    Raises:
        FileNotFoundError: if `args.resume` is set but the file is missing.
        NotImplementedError: for unsupported `args.augment` values
            (including the not-yet-implemented "zac" ZCA whitening).
        RuntimeError: if `args.test_only` is set without a checkpoint.
    """
    # Seed everything and force deterministic cuDNN for reproducibility.
    np.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Constructing Model

    checkpoint = None
    if args.resume != "":
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # Restore the options saved in the checkpoint, but keep the
            # current command-line overrides for test_only/resume.
            test_only = args.test_only
            resume = args.resume
            args = checkpoint["opt"]
            args.test_only = test_only
            args.resume = resume
        else:
            # Fail fast: without the checkpoint, `checkpoint["model"]` below
            # would crash later with a misleading TypeError on None.
            raise FileNotFoundError(
                "=> No checkpoint found at '{}'".format(args.resume))

    model = WideResNet(args.depth, args.widen_factor, args.dropout_rate,
                       args.num_classes)

    if torch.cuda.is_available():
        model.cuda()
        model = torch.nn.DataParallel(model, device_ids=args.gpu)

    if args.resume != "":
        model.load_state_dict(checkpoint["model"])
        args.start_epoch = checkpoint["epoch"] + 1
        print("=> Loaded successfully '{}' (epoch {})".format(
            args.resume, checkpoint["epoch"]))
        # Free the checkpoint tensors before training allocates memory.
        del checkpoint
        torch.cuda.empty_cache()
    else:
        model.apply(conv_init)

    # Loading Dataset

    if args.augment == "meanstd":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(Config.CIFAR10_mean, Config.CIFAR10_std),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(Config.CIFAR10_mean, Config.CIFAR10_std),
        ])
    elif args.augment == "zac":
        # ZCA whitening is not implemented; the previous `pass` left
        # transform_train/transform_test undefined and crashed later with a
        # NameError at the CIFAR10 constructors.
        raise NotImplementedError("ZCA whitening is not implemented yet")
    else:
        raise NotImplementedError

    print("| Preparing CIFAR-10 dataset...")
    sys.stdout.write("| ")
    trainset = CIFAR10(root="./data",
                       train=True,
                       download=True,
                       transform=transform_train)
    testset = CIFAR10(root="./data",
                      train=False,
                      download=False,
                      transform=transform_test)

    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2)
    test_loader = DataLoader(testset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=2)

    # Test only

    if args.test_only:
        if args.resume != "":
            test(args, test_loader, model)
            sys.exit(0)
        else:
            print("=> Test only model need to resume from a checkpoint")
            raise RuntimeError

    train(args, train_loader, test_loader, model)
    test(args, test_loader, model)