Exemplo n.º 1
0
from model import *

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
import torchvision
from torchvision.models import densenet121, wide_resnet50_2, resnet34
from densenet_mod import _DenseBlock
from mish.mish import Mish
import pretrainedmodels
    
# Backbone instances created once at import time. Each constructor is asked
# for pretrained weights (the pretrainedmodels entry requests its 'imagenet'
# weights), so importing this module builds all four networks.
res34 = resnet34(pretrained=True)
# pretrainedmodels exposes its constructors by name through the module dict.
seresxt50 = pretrainedmodels.__dict__['se_resnext50_32x4d'](pretrained='imagenet')
dense121 = densenet121(pretrained=True)
wideres50 = wide_resnet50_2(pretrained=True)

# ------------------------------------------------------------------------
def gem(x, p=3, eps=1e-6):
    """Pool ``x`` over its last two dimensions with a generalized mean.

    Values are clamped from below at ``eps``, raised to the power ``p``,
    averaged over a window covering the full last two dimensions, and the
    ``p``-th root of that average is returned.

    Args:
        x: input tensor; its last two dimensions are pooled away to size 1.
        p: exponent of the generalized mean (number or tensor).
        eps: lower clamp applied before exponentiation.

    Returns:
        Tensor with the last two dimensions reduced to 1 x 1.
    """
    # The pooling window spans the whole spatial extent of the input.
    window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    averaged = F.avg_pool2d(powered, window)
    return averaged.pow(1. / p)


class GeM(nn.Module):
    """Module wrapper around ``gem`` whose exponent ``p`` is a learnable parameter."""

    def __init__(self, p=3, eps=1e-6):
        """Store ``p`` as a one-element Parameter and keep ``eps`` as a plain attribute."""
        super().__init__()
        initial_exponent = torch.ones(1) * p
        self.p = Parameter(initial_exponent)
        self.eps = eps

    def forward(self, x):
        """Apply ``gem`` to ``x`` using this module's exponent and clamp value."""
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        """Return e.g. ``GeM(p=3.0000, eps=1e-06)``."""
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps!s})"
def get(config=None):
    """Build the pretrained model named by ``config.model.name``.

    Reads ``config.model.name``, ``config.classes``, ``config.mode`` and, from
    ``config.model.params``, ``pred_type`` and ``tune_type``.

    The output width is ``1`` when ``pred_type`` is ``'REG'``, ``classes + 1``
    when it is ``'MIX'`` and ``classes`` otherwise.  For the torchvision and
    EfficientNet models, every parameter is frozen first when ``tune_type`` is
    ``'FE'`` and the final layer is then replaced by ``get_default_fc(...)``.
    ``'resnet50'`` and ``'efficientnet-b1'`` additionally get an
    ``nn.AdaptiveMaxPool2d(1)`` pooling layer.  ``'pann-cnn14-attn'`` is never
    frozen here; only its ``att_block`` is replaced.

    Args:
        config: configuration object exposing the attributes listed above.

    Returns:
        The constructed model.

    Raises:
        Exception: if ``config.model.name`` is not a supported model name.
    """
    name = config.model.name
    classes = config.classes
    pred_type = config.model.params.pred_type
    tune_type = config.model.params.tune_type

    adjusted_classes = (
        1 if pred_type == 'REG'
        else classes + 1 if pred_type == 'MIX'
        else classes
    )

    # ===========================================================================
    #                                 Model list
    # ===========================================================================

    # torchvision models whose head lives in ``classifier``.
    densenet_names = ('densenet161', 'densenet201')
    # torchvision models whose head lives in ``fc``.
    resnet_names = (
        'resnet50', 'resnet101', 'resnet152',
        'resnext50_32x4d', 'resnext101_32x8d',
        'wide_resnet50_2', 'wide_resnet101_2',
    )
    efficientnet_names = ('efficientnet-b0', 'efficientnet-b1', 'efficientnet-b5')

    def freeze_for_feature_extraction(network):
        # 'FE' tuning freezes every weight that exists at this point; layers
        # attached afterwards stay trainable.
        if tune_type == 'FE':
            for weight in network.parameters():
                weight.requires_grad = False

    if name in densenet_names or name in resnet_names:
        model = getattr(models, name)(pretrained=True)
        freeze_for_feature_extraction(model)
        if name == 'resnet50':
            # model.avgpool = GeM()
            model.avgpool = nn.AdaptiveMaxPool2d(1)
        head_attr = 'classifier' if name in densenet_names else 'fc'
        num_ftrs = getattr(model, head_attr).in_features
        setattr(model, head_attr,
                get_default_fc(num_ftrs, adjusted_classes, config.model.params))
    elif name in efficientnet_names:
        model = EfficientNet.from_pretrained(name)
        freeze_for_feature_extraction(model)
        if name == 'efficientnet-b1':
            model._avg_pooling = nn.AdaptiveMaxPool2d(1)
        num_ftrs = model._fc.in_features
        model._fc = get_default_fc(num_ftrs, adjusted_classes,
                                   config.model.params)
    elif name == 'pann-cnn14-attn':
        model = Pann_Cnn14_Attn(pretrained=True)
        model.att_block = AttBlock(2048,
                                   adjusted_classes,
                                   activation='sigmoid')
    else:
        raise Exception("model not in list!")

    print("[ Model : {} ]".format(name))
    print("↳ [ Prediction type : {} ]".format(pred_type))
    print("↳ [ Adjusted classes : {} ]".format(adjusted_classes))
    if config.mode != "PRD":
        print("↳ [ Tuning type : {} ]".format(tune_type))
    return model
Exemplo n.º 3
0
def tWRN50_2(n_classes, n_channels):
    """Return a ``wide_resnet50_2`` built with ``num_classes=n_classes``.

    ``n_channels`` is accepted but not used by this function.
    """
    network = wide_resnet50_2(num_classes=n_classes)
    return network
Exemplo n.º 4
0
 },
 #"googlenet": models.googlenet(pretrained=True),
 "shufflenet": {
     "model": models.shufflenet_v2_x1_0(pretrained=True),
     "path": "both"
 },
 "mobilenet_v2": {
     "model": models.mobilenet_v2(pretrained=True),
     "path": "both"
 },
 "resnext50_32x4d": {
     "model": models.resnext50_32x4d(pretrained=True),
     "path": "both"
 },
 "wideresnet50_2": {
     "model": models.wide_resnet50_2(pretrained=True),
     "path": "both"
 },
 "mnasnet": {
     "model": models.mnasnet1_0(pretrained=True),
     "path": "both"
 },
 "resnet18": {
     "model":
     torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True),
     "path":
     "both"
 },
 "resnet50": {
     "model":
     torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
Exemplo n.º 5
0
def wide_resnet50_2():
    """Return ``models.wide_resnet50_2`` constructed with ``pretrained=True``."""
    pretrained_network = models.wide_resnet50_2(pretrained=True)
    return pretrained_network
Exemplo n.º 6
0
 def test_wide_resnet50_2(self):
     """Run ``process_model`` on a default ``wide_resnet50_2`` with this test's image."""
     network = models.wide_resnet50_2()
     reference_forward = _C_tests.forward_wide_resnet50_2
     process_model(network, self.image, reference_forward, "WideResNet50_2")
Exemplo n.º 7
0
 def _get_model(self, key):
     if key == 'vgg11':
         return (torch_models.vgg11(True), flax_models.vgg11(RNG))
     if key == 'vgg11_bn':
         return (torch_models.vgg11_bn(True), flax_models.vgg11_bn(RNG))
     if key == 'vgg13':
         return (torch_models.vgg13(True), flax_models.vgg13(RNG))
     if key == 'vgg13_bn':
         return (torch_models.vgg13_bn(True), flax_models.vgg13_bn(RNG))
     if key == 'vgg16':
         return (torch_models.vgg16(True), flax_models.vgg16(RNG))
     if key == 'vgg16_bn':
         return (torch_models.vgg16_bn(True), flax_models.vgg16_bn(RNG))
     if key == 'vgg19':
         return (torch_models.vgg19(True), flax_models.vgg19(RNG))
     if key == 'vgg19_bn':
         return (torch_models.vgg19_bn(True), flax_models.vgg19_bn(RNG))
     if key == 'resnet18':
         return (torch_models.resnet18(True), flax_models.resnet18(RNG))
     if key == 'resnet34':
         return (torch_models.resnet34(True), flax_models.resnet34(RNG))
     if key == 'resnet50':
         return (torch_models.resnet50(True), flax_models.resnet50(RNG))
     if key == 'resnet101':
         return (torch_models.resnet101(True), flax_models.resnet101(RNG))
     if key == 'resnet152':
         return (torch_models.resnet152(True), flax_models.resnet152(RNG))
     if key == 'resnext50_32x4d':
         return (torch_models.resnext50_32x4d(True),
                 flax_models.resnext50_32x4d(RNG))
     if key == 'resnext101_32x8d':
         return (torch_models.resnext101_32x8d(True),
                 flax_models.resnext101_32x8d(RNG))
     if key == 'wide_resnet50_2':
         return (torch_models.wide_resnet50_2(True),
                 flax_models.wide_resnet50_2(RNG))
     if key == 'wide_resnet101_2':
         return (torch_models.wide_resnet101_2(True),
                 flax_models.wide_resnet101_2(RNG))
     if key == 'inception_v3':
         return (torch_models.inception_v3(True),
                 flax_models.inception_v3(RNG))
     if key == 'densenet121':
         return (torch_models.densenet121(True),
                 flax_models.densenet121(RNG))
     if key == 'densenet161':
         return (torch_models.densenet161(True),
                 flax_models.densenet161(RNG))
     if key == 'densenet169':
         return (torch_models.densenet169(True),
                 flax_models.densenet169(RNG))
     if key == 'densenet201':
         return (torch_models.densenet201(True),
                 flax_models.densenet201(RNG))
     if key == 'fcn_resnet50':
         return (torch_models.segmentation.fcn_resnet50(True),
                 flax_models.fcn_resnet50(RNG))
     if key == 'fcn_resnet101':
         return (torch_models.segmentation.fcn_resnet101(True),
                 flax_models.fcn_resnet101(RNG))
     if key == 'deeplabv3_resnet50':
         return (torch_models.segmentation.deeplabv3_resnet50(True),
                 flax_models.deeplabv3_resnet50(RNG))
     if key == 'deeplabv3_resnet101':
         return (torch_models.segmentation.deeplabv3_resnet101(True),
                 flax_models.deeplabv3_resnet101(RNG))
Exemplo n.º 8
0
def main(gpu, args):
    """Train a BYOL-wrapped ``wide_resnet50_2`` in one worker of a distributed job.

    Args:
        gpu: index of the GPU used by this process; also used, together with
            ``args.nr`` and ``args.gpus``, to compute the global rank.  Only
            the process with ``gpu == 0`` prints, logs to TensorBoard and
            writes checkpoints.
        args: namespace providing ``ckpt_path_save``, ``log_path``,
            ``experiment_name``, ``experiment_id``, ``nr``, ``gpus``,
            ``world_size``, ``image_size``, ``batch_size``, ``num_workers``,
            ``learning_rate``, ``num_epochs`` and ``checkpoint_epochs``.
            It is also passed through to ``adjust_learning_rate``.

    Side effects:
        Creates the checkpoint directory, initialises the ``nccl`` process
        group, writes ``model-<epoch>.pt`` every ``args.checkpoint_epochs``
        epochs plus ``model-final.pt``, and calls ``cleanup()`` at the end.
    """
    ckpt_path_save = os.path.join(
        args.ckpt_path_save,
        args.experiment_name + "_" + str(args.experiment_id))
    # exist_ok avoids a check-then-create race between worker processes.
    os.makedirs(ckpt_path_save, exist_ok=True)
    print("ckpt_path_save:", ckpt_path_save)

    rank = args.nr * args.gpus + gpu
    dist.init_process_group("nccl", rank=rank, world_size=args.world_size)

    torch.manual_seed(0)
    torch.cuda.set_device(gpu)
    normalize = transforms.Normalize(mean=[0.4771, 0.4769, 0.4355],
                                     std=[0.2189, 0.1199, 0.1717])

    # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
    augmentation = [
        transforms.RandomCrop(args.image_size),
        NonLinearColorJitter(),
        # transforms.RandomApply([
        #     transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        # ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([loader.GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        normalize
    ]
    transform = loader.TwoCropsTransform(transforms.Compose(augmentation))
    # dataset
    traindir = [
        "/data/gukedata/train_data/0-10", "/data/gukedata/train_data/11-15",
        "/data/gukedata/train_data/16-20", "/data/gukedata/train_data/21-25",
        "/data/gukedata/train_data/26-45", "/data/gukedata/train_data/46-",
        "/data/gukedata/test_data/0-10", "/data/gukedata/test_data/11-15",
        "/data/gukedata/test_data/16-20", "/data/gukedata/test_data/21-25",
        "/data/gukedata/test_data/26-45", "/data/gukedata/test_data/46-"
    ]

    train_dataset = DegreesData(traindir, transform)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    # model
    model = models.wide_resnet50_2()

    model = BYOL(model, image_size=args.image_size, hidden_layer="avgpool")
    model = model.cuda(gpu)

    # distributed data parallel
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    log_path = os.path.join(
        args.log_path, args.experiment_name + "_" + str(args.experiment_id))

    # TensorBoard writer (only created on the logging process)
    if gpu == 0:
        writer = SummaryWriter(log_path)

    # solver
    global_step = 0
    for epoch in range(args.num_epochs):
        # Without set_epoch the DistributedSampler reuses the same shuffle
        # order in every epoch.
        train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)
        metrics = defaultdict(list)
        for step, ((x_i, x_j), _) in enumerate(train_loader):
            x_i = x_i.cuda(non_blocking=True)
            x_j = x_j.cuda(non_blocking=True)

            loss = model(x_i, x_j)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update moving average of target encoder
            model.module.update_moving_average()

            if gpu == 0:
                loss_value = loss.item()
                print(
                    f"Step [{step}/{len(train_loader)}]:\tLoss: {loss_value}")
                writer.add_scalar("Loss/train_step", loss_value, global_step)
                metrics["Loss/train"].append(loss_value)
                global_step += 1

        if gpu == 0:
            # Report the learning rate of the first parameter group.
            lr = optimizer.param_groups[0]['lr']
            print("Epoch:", epoch, 'Learning_rate:', lr)
            writer.add_scalar('Learning_rate', lr, epoch)

            # write metrics to TensorBoard
            for k, v in metrics.items():
                writer.add_scalar(k, np.array(v).mean(), epoch)

            if epoch % args.checkpoint_epochs == 0:
                print(f"Saving model at epoch {epoch}")
                # f-string so each checkpoint gets its own epoch-numbered file
                # instead of overwriting a literal "model-{epoch}.pt".
                torch.save(model.state_dict(),
                           os.path.join(ckpt_path_save, f"model-{epoch}.pt"))

                # let other workers wait until model is finished
                # dist.barrier()

    # save your improved network
    if gpu == 0:
        torch.save(model.state_dict(),
                   os.path.join(ckpt_path_save, "model-final.pt"))

    cleanup()
Exemplo n.º 9
0
def training(rank, world_size, backend, config):
    """Train a ``wide_resnet50_2`` on ``RndDataset`` for one epoch on an XLA device.

    Args:
        rank: not used in the body; the ordinal comes from ``xm.get_ordinal()``.
        world_size: not used in the body; the replica count comes from
            ``xm.xrt_world_size()``.
        backend: only printed.
        config: mapping providing ``"nb_samples"``, ``"batch_size"`` (the
            total, divided by the replica count per process) and
            ``"log_interval"``.
    """
    # Specific xla
    print(xm.get_ordinal(), ": run with config:", config, "- backend=", backend)
    device = xm.xla_device()

    # Data preparation
    dataset = RndDataset(nb_samples=config["nb_samples"])

    # Specific xla
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(),
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(config["batch_size"] / xm.xrt_world_size()),
        num_workers=1,
        sampler=train_sampler,
    )

    # Specific xla
    para_loader = pl.MpDeviceLoader(train_loader, device)

    # Model, criterion, optimizer setup
    model = wide_resnet50_2(num_classes=100).to(device)
    criterion = NLLLoss()
    optimizer = SGD(model.parameters(), lr=0.01)

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(batch_idx, data, target):
        """Run one optimisation step and return the loss tensor."""
        optimizer.zero_grad()
        output = model(data)
        # NLLLoss expects log-probabilities normalised over the class
        # dimension (dim=1 of the (batch, classes) output), not softmax
        # probabilities taken across the batch.
        log_probabilities = torch.nn.functional.log_softmax(output, dim=1)

        loss_val = criterion(log_probabilities, target)
        loss_val.backward()
        xm.optimizer_step(optimizer)

        if batch_idx % log_interval == 0:
            # ``epoch`` is read from the enclosing loop below.
            print(
                "Process {}/{} Train Epoch: {} [{}/{}]\tLoss: {}".format(
                    xm.get_ordinal(),
                    xm.xrt_world_size(),
                    epoch,
                    batch_idx * len(data),
                    len(train_sampler),
                    loss_val.item(),
                )
            )
        return loss_val

    # Running _train_step for n_epochs
    n_epochs = 1
    for epoch in range(n_epochs):
        for batch_idx, (data, target) in enumerate(para_loader):
            _train_step(batch_idx, data, target)
Exemplo n.º 10
0
import torch
import cv2
from torchvision import transforms
from PIL import Image
from resizeimage import resizeimage
import numpy as np

# Preprocessing pipeline: resize to 256, center-crop to 224, convert to a
# tensor and normalise with the per-channel mean/std given below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# import os
# os.environ['TORCH_HOME'] = '/home/err_pv/Desktop/Parikh_linux/Deep Learning/openCV/Image-Editor' #setting the environment variable
# NOTE(review): `models` is not imported in the visible part of this script;
# presumably torchvision.models - confirm.
wide_resnet = models.wide_resnet50_2(pretrained=True, progress=True)

# The image is only read, so open it read-only rather than 'r+b'.
with open("trip1.jpg", 'rb') as f:
    with Image.open(f) as image:
        cover = resizeimage.resize_cover(image, [600, 400])
image = cover
# # Convert RGB to BGR
# image = image[:, :, ::-1].copy()
# cv2.imshow("Press any key to continue", image)
# l = cv2.waitKey(0)
image.show()

import json
# Use a context manager so the index file is closed after loading.
with open("imagenet_class_index.json") as class_index_file:
    class_idx = json.load(class_index_file)
# print(class_idx)
# Label list ordered by integer index: entry k is element 1 of class_idx[str(k)].
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
Exemplo n.º 11
0
def wrn50():
    """Return a pretrained ``wide_resnet50_2`` whose ``fc`` is a new ``nn.Linear(2048, 2)``."""
    network = wide_resnet50_2(pretrained=True)
    two_way_head = nn.Linear(2048, 2)
    network.fc = two_way_head
    return network
Exemplo n.º 12
0
    data_dir = args.data_dir
    all_img_names = [
        f for f in os.listdir(data_dir)
        if os.path.splitext(f)[1].lower() in Image.EXTENSION
    ]
    random.shuffle(all_img_names)
    NUM_TRAIN_IMGS = int(args.train_ratio * len(all_img_names))
    NUM_VAL_IMGS = int(args.val_ratio * len(all_img_names))
    img_paths_train = [
        os.path.join(data_dir, f) for f in all_img_names[:NUM_TRAIN_IMGS]
    ]
    img_paths_val = [
        os.path.join(data_dir, f)
        for f in all_img_names[NUM_TRAIN_IMGS:NUM_TRAIN_IMGS + NUM_VAL_IMGS]
    ]
    model = wide_resnet50_2(pretrained=True)
    model.fc = torch.nn.Linear(2048, 1)

    optim = torch.optim.Adam([
        {
            'params': [
                p for n, p in model.named_parameters()
                if not n.startswith('fc.')
            ],
            'lr':
            1e-5
        },
        {
            'params': model.fc.parameters(),
            'lr': 1e-4
        },
Exemplo n.º 13
0
    def get_model(model_id, use_pretrained):
        """Return the ``models`` network that corresponds to ``model_id``.

        Args:
            model_id: value compared against ``PyTorchModelsEnum`` members.
            use_pretrained: forwarded as ``pretrained=`` to the constructor.

        Returns:
            The constructed model, or ``None`` when ``model_id`` matches no
            known member.
        """
        # (PyTorchModelsEnum member name, models constructor name), checked in
        # this order; the first member equal to model_id wins.
        lookup = (
            ('ALEXNET', 'alexnet'),
            ('DENSENET121', 'densenet121'),
            ('DENSENET161', 'densenet161'),
            ('DENSENET169', 'densenet169'),
            ('DENSENET201', 'densenet201'),
            ('GOOGLENET', 'googlenet'),
            ('INCEPTION_V3', 'inception_v3'),
            ('MOBILENET_V2', 'mobilenet_v2'),
            ('MNASNET_0_5', 'mnasnet0_5'),
            ('MNASNET_0_75', 'mnasnet0_75'),  # no pretrained
            ('MNASNET_1_0', 'mnasnet1_0'),
            ('MNASNET_1_3', 'mnasnet1_3'),
            ('RESNET18', 'resnet18'),
            ('RESNET34', 'resnet34'),
            ('RESNET50', 'resnet50'),
            ('RESNET101', 'resnet101'),
            ('RESNET152', 'resnet152'),
            ('RESNEXT50', 'resnext50_32x4d'),
            ('RESNEXT101', 'resnext101_32x8d'),
            ('SHUFFLENET_V2_0_5', 'shufflenet_v2_x0_5'),
            ('SHUFFLENET_V2_1_0', 'shufflenet_v2_x1_0'),
            ('SHUFFLENET_V2_1_5', 'shufflenet_v2_x1_5'),
            ('SHUFFLENET_V2_2_0', 'shufflenet_v2_x2_0'),
            ('SQUEEZENET1_0', 'squeezenet1_0'),
            ('SQUEEZENET1_1', 'squeezenet1_1'),
            ('VGG11', 'vgg11'),
            ('VGG11_BN', 'vgg11_bn'),
            ('VGG13', 'vgg13'),
            ('VGG13_BN', 'vgg13_bn'),
            ('VGG16', 'vgg16'),
            ('VGG16_BN', 'vgg16_bn'),
            ('VGG19', 'vgg19'),
            ('VGG19_BN', 'vgg19_bn'),
            ('WIDE_RESNET50', 'wide_resnet50_2'),
            ('WIDE_RESNET101', 'wide_resnet101_2'),
        )

        for member_name, constructor_name in lookup:
            # Members are resolved lazily, one per comparison.
            if model_id == getattr(PyTorchModelsEnum, member_name):
                constructor = getattr(models, constructor_name)
                return constructor(pretrained=use_pretrained)

        return None
Exemplo n.º 14
0
def instantiate_model (dataset='cifar10',
                       num_classes=10, 
                       input_quant='FP', 
                       arch='resnet',
                       dorefa=False, 
                       abit=32, 
                       wbit=32,
                       qin=False, 
                       qout=False,
                       suffix='', 
                       load=False,
                       torch_weights=False,
                       device='cpu'):
    """Initializes/load network with random weight/saved and return auto generated model name 'dataset_arch_suffix.ckpt'
    
    Args:
        dataset         : mnists/cifar10/cifar100/imagenet/tinyimagenet/simple dataset the netwoek is trained on. Used in model name 
        num_classes     : number of classes in dataset. 
        arch            : resnet/vgg/lenet5/basicnet/slpconv model architecture the network to be instantiated with 
        suffix          : str appended to the model name 
        load            : boolean variable to indicate load pretrained model from ./pretrained/dataset/
        torch_weights   : boolean variable to indicate load weight from torchvision for imagenet dataset
    Returns:
        model           : models with desired weight (pretrained / random )
        model_name      : str 'dataset_arch_suffix.ckpt' used to save/load model in ./pretrained/dataset
    """
    #Select the input transformation
    if input_quant==None:
        input_quant=''
        Q=PreProcess()
    elif input_quant.lower()=='q1':
        Q = Quantise2d(n_bits=1).to(device)
    elif input_quant.lower()=='q2':
        Q = Quantise2d(n_bits=2).to(device)
    elif input_quant.lower()=='q4':
        Q = Quantise2d(n_bits=4).to(device)
    elif input_quant.lower()=='q6':
        Q = Quantise2d(n_bits=6).to(device)
    elif input_quant.lower()=='q8':
        Q = Quantise2d(n_bits=8).to(device)
    elif input_quant.lower()=='fp':
        Q = Quantise2d(n_bits=1,quantise=False).to(device)
    else:    
        raise ValueError

    # Instantiate model1
    # RESNET IMAGENET
    if(arch == 'torch_resnet18'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet18(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_resnet34'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet34(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_resnet50'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet50(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_resnet101'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet101(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_resnet152'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet152(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_resnet34'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet34(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_resnext50_32x4d'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnext50_32x4d(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_resnext101_32x8d'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnext101_32x8d(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_wide_resnet50_2'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.wide_resnet50_2(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_wide_resnet101_2'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.wide_resnet101_2(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    #VGG IMAGENET
    elif(arch == 'torch_vgg11'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg11(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_vgg11bn'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg11_bn(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_vgg13'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg13(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_vgg13bn'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg13_bn(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_vgg16'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg16(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_vgg16bn'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg16_bn(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_vgg19'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg19(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_vgg19bn'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg19_bn(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    #MOBILENET IMAGENET   
    elif(arch == 'torch_mobnet'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.mobilenet_v2(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    #DENSENET IMAGENET
    elif(arch == 'torch_densenet121'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.densenet121(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_densenet169'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.densenet169(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_densenet201'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.densenet201(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    elif(arch == 'torch_densenet161'):
        if dorefa:
            raise ValueError ("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.densenet161(pretrained=torch_weights)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    #RESNET CIFAR   
    elif(arch[0:6] == 'resnet'):
        cfg = arch[6:]
        if dorefa:
            model = ResNet_Dorefa_(cfg=cfg, num_classes=num_classes, a_bit=abit, w_bit=wbit)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch +"_a" + str(abit) + 'w'+ str(wbit) + suffix
            
        else:   
            model = ResNet_(cfg=cfg, num_classes=num_classes)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    
    #VGG CIFAR
    elif(arch[0:3] == 'vgg'):
        len_arch = len(arch)
        if arch[len_arch-2:len_arch]=='bn' and arch[len_arch-4:len_arch-2]=='bn':
            batch_norm_conv=True
            batch_norm_linear=True
            cfg= arch[3:len_arch-4]
        elif arch [len_arch-2: len_arch]=='bn':
            batch_norm_conv=True
            batch_norm_linear=False
            cfg= arch[3:len_arch-2]
        else:
            batch_norm_conv=False
            batch_norm_linear=False
            cfg= arch[3:len_arch]
        if dorefa:
            model = vgg_Dorefa(cfg=cfg, batch_norm_conv=batch_norm_conv, batch_norm_linear=batch_norm_linear ,num_classes=num_classes, a_bit=abit, w_bit=wbit)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch +"_a" + str(abit) + 'w'+ str(wbit) + suffix
            
        else:   
            model = vgg(cfg=cfg, batch_norm_conv=batch_norm_conv, batch_norm_linear=batch_norm_linear ,num_classes=num_classes)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    # LENET MNIST
    elif (arch == 'lenet5'):
        if dorefa:
            model = LeNet5_Dorefa(num_classes=num_classes, abit=abit, wbit=wbit)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch +"_a" + str(abit) + 'w'+ str(wbit) + suffix
        else:
            model = LeNet5(num_classes=num_classes)
            model_name = dataset.lower()+ "_" + input_quant + "_" + arch + suffix
    else:
        # Right way to handle exception in python see https://stackoverflow.com/questions/2052390/manually-raising-throwing-an-exception-in-python
        # Explains all the traps of using exception, does a good job!! I mean the link :)
        raise ValueError("Unsupported neural net architecture")
    model = model.to(device)
    
    if load == True and torch_weights == False :
        print(" Using Model: " + arch)
        if model_name[-4:]=='_tfr':
            model_path = os.path.join('./pretrained/', dataset.lower(),  model_name + '.tfr')
        else:
            model_path = os.path.join('./pretrained/', dataset.lower(),  model_name + '.ckpt')
        model.load_state_dict(torch.load(model_path, map_location='cuda:0'))
        print(' Loaded trained model from :' + model_path)
        print(' {}'.format(Q))
    
    else:
        if model_name[-4:]=='_tfr':
            model_path = os.path.join('./pretrained/', dataset.lower(),  model_name + '.tfr')
        else:
            model_path = os.path.join('./pretrained/', dataset.lower(),  model_name + '.ckpt')
        print(' Training model save at:' + model_path)
    print('')
    return model, model_name, Q
Exemplo n.º 15
0
from PIL import Image
import numpy as np

import torch
import torchvision.models as models

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Keyword arguments intended for a DataLoader; not used in this part of the script.
kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

# Per-channel normalisation constants applied to the [0, 1]-scaled image
# (the values torchvision documents for its pretrained classification models).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

model = models.wide_resnet50_2(pretrained=True).to(device)
# Switch to evaluation mode: a freshly constructed module is in training mode,
# in which BatchNorm would normalise with the statistics of this single image
# (and update its running averages) instead of using the pretrained ones.
model.eval()

# Force three channels so the fixed (224, 224, 3) reshape below also works for
# PNG files that carry an alpha channel.
im_frame = Image.open("images/" + 'panda.png').convert("RGB")
# The image is expected to be exactly 224x224; reshape raises otherwise.
np_frame = np.array(im_frame.getdata()).reshape(224, 224, 3) / 255

np_frame = (np_frame - mean) / std

# HWC float64 array -> 1x3x224x224 float32 tensor on the selected device.
img = torch.from_numpy(np_frame).float().to(device).permute(2, 0, 1).view(
    1, 3, 224, 224)

# Gradients are left enabled on purpose: nothing visible here says they are
# unwanted by whatever consumes `out`.
out = model(img)
def _cifar_transforms(mean, std):
    """Return ``(train_transform, test_transform)`` for 32x32 CIFAR images.

    The training pipeline adds a padded random crop and a random horizontal
    flip before tensor conversion; both pipelines end with the same
    normalisation.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transform, test_transform


def _load_task_datasets(task):
    """Return ``(train_data, test_data)`` for ``task`` (case-insensitive).

    Recognised tasks are "CIFAR100", "IMAGENET" and "FASHIONMNIST"; any other
    value falls back to CIFAR10.  Nothing is downloaded.
    """
    task_key = task.upper()

    if task_key == "CIFAR100":
        # NOTE(review): these look like the commonly quoted CIFAR-10
        # statistics, while the CIFAR10 fallback below uses what look like the
        # CIFAR-100 ones -- confirm whether the two are swapped.  Kept as they
        # were so results stay comparable with earlier runs.
        transform_train, transform_test = _cifar_transforms(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_data = CIFAR100("data/", transform=transform_train, download=False)
        test_data = CIFAR100("data/", train=False, transform=transform_test, download=False)
    elif task_key == "IMAGENET":
        # NOTE(review): no transform is attached, so items are whatever the
        # dataset yields by default; the later torch.stack presumably needs
        # tensors -- confirm this branch is actually usable.
        train_data = ImageNet('data/imagenet', split='train', download=False)
        test_data = ImageNet('data/imagenet', split='val', download=False)
    elif task_key == "FASHIONMNIST":
        # Images are converted to RGB so the 3-channel ResNet stem accepts them.
        transform = transforms.Compose([transforms.Lambda(lambda image: image.convert('RGB')),
                                        transforms.ToTensor()
                             ])
        train_data = FashionMNIST('data/fashionmnist', transform=transform, train=True, download=False)
        test_data = FashionMNIST('data/fashionmnist', transform=transform, train=False, download=False)
    else:
        transform_train, transform_test = _cifar_transforms(
            (0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
            (0.2673342858792401, 0.2564384629170883, 0.27615047132568404))
        train_data = CIFAR10("data/", transform=transform_train, download=False)
        test_data = CIFAR10("data/", train=False, transform=transform_test, download=False)

    return train_data, test_data


def _materialize(dataset):
    """Fetch every item of ``dataset`` exactly once, in index order."""
    return [dataset[index] for index in range(len(dataset))]


def _select_classes(samples, wanted):
    """Return ``(inputs, labels)`` of the samples whose label is in ``wanted``."""
    inputs, labels = [], []
    for sample in samples:
        if sample[1] in wanted:
            inputs.append(sample[0])
            labels.append(sample[1])
    return inputs, labels


def _halve_per_class(inputs, labels, classes):
    """Keep a random half of the examples of each class in ``classes``."""
    kept_inputs, kept_labels = [], []
    for current in classes:
        members = [i for i, label in enumerate(labels) if label == current]
        chosen = random.sample(members, int(len(members) * .5))
        kept_inputs += [inputs[i] for i in chosen]
        kept_labels += [labels[i] for i in chosen]
    return kept_inputs, kept_labels


def _partition_shared_class(pool, percent_shared_data):
    """Randomly split one shared class into ``(shared, only_model1, only_model2)``.

    Each model ends up with ``len(pool) // 2`` examples of the class, of which
    ``percent_shared_data`` percent are common to both models and the rest are
    disjoint.  ``pool`` is shuffled in place.
    """
    random.shuffle(pool)
    max_examples = len(pool) // 2
    num_shared = int(max_examples * percent_shared_data // 100)
    num_shared = max(0, min(num_shared, len(pool)))
    num_disjoint = max(max_examples - num_shared, 0)
    shared = pool[:num_shared]
    only_first = pool[num_shared:num_shared + num_disjoint]
    only_second = pool[num_shared + num_disjoint:num_shared + 2 * num_disjoint]
    return shared, only_first, only_second


def _relabel_first_model(labels):
    """Map raw labels to 0, 1, 2, ... in first-seen order.

    Returns ``(new_labels, mapping)`` where ``mapping`` is raw -> new.
    """
    mapping = {}
    for label in labels:
        if label not in mapping:
            mapping[label] = len(mapping)
    return [mapping[label] for label in labels], mapping


def _relabel_second_model(labels, first_mapping):
    """Relabel model 2's data consistently with model 1's mapping.

    Classes already present in ``first_mapping`` (the shared classes) keep the
    label model 1 uses for them.  Every other class gets the smallest label
    that no shared class occupies, so two different classes never end up with
    the same label.

    Returns ``(new_labels, mapping)`` where ``mapping`` covers only the classes
    that are not in ``first_mapping``.
    """
    reserved = {first_mapping[label] for label in labels if label in first_mapping}
    mapping = {}
    next_free = 0
    for label in labels:
        if label in first_mapping or label in mapping:
            continue
        while next_free in reserved:
            next_free += 1
        mapping[label] = next_free
        next_free += 1
    return _translate_labels(labels, first_mapping, mapping), mapping


def _translate_labels(labels, primary, secondary):
    """Translate raw labels with ``primary``, falling back to ``secondary``."""
    return [primary[label] if label in primary else secondary[label]
            for label in labels]


def _build_classifier(task, num_outputs):
    """Create an untrained backbone for ``task`` with a ``num_outputs``-way head."""
    task_key = task.upper()
    if task_key in ("CIFAR100", "IMAGENET"):
        model = models.wide_resnet50_2()
    elif task_key == "FASHIONMNIST":
        model = models.resnet18()
    else:
        model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, num_outputs)
    return model


def _train_classifier(model, loader, optimizer, scheduler, criterion, device, n_epochs):
    """Train ``model`` on ``loader`` for ``n_epochs`` epochs."""
    for epoch in tqdm(range(n_epochs), desc="Epoch"):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader, 0):
            # Move the batch to the device the model lives on; a plain .cuda()
            # fails when the model was placed on the CPU.
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
        # Advance the learning-rate schedule once per epoch; without this call
        # the milestones never take effect.
        scheduler.step()


def _report_accuracy(tag, output, targets, shared_output, adv_output, shared_targets, top5):
    """Print clean, shared-class and adversarial accuracy for one model."""
    print("{}_acc:".format(tag), accuracy(output, targets))
    if top5:
        print("{}_acc_5:".format(tag), accuracy_n(output, targets, 5))

    print("{}_acc_shared:".format(tag), accuracy(shared_output, shared_targets))
    if top5:
        print("{}_acc_5_shared:".format(tag), accuracy_n(shared_output, shared_targets, 5))

    print("{}_adv_acc_shared:".format(tag), accuracy(adv_output, shared_targets))
    if top5:
        print("{}_adv_acc_5_shared:".format(tag), accuracy_n(adv_output, shared_targets, 5))


def experiment(num_shared_classes, percent_shared_data, n_epochs=200,batch_size=128, eps=.3, adv_steps=100, learning_rate=.0004, gpu_num=1,adv_training=False,task="CIFAR100"):
    """Train two classifiers on overlapping class sets and test attack transfer.

    The classes of the chosen dataset are divided into ``num_shared_classes``
    classes seen by both models and two equally sized groups seen by one model
    each (one class is dropped at random if the remainder is odd).  For the
    shared classes, ``percent_shared_data`` controls how much of each model's
    training data is the very same examples.  After training, an untargeted
    L-inf PGD attack is crafted against model 1 on the shared-class test
    examples and both models are scored on clean and adversarial inputs.  The
    two models and the attack object are saved under ``./models``.

    Args:
        num_shared_classes: number of classes both models are trained on.
        percent_shared_data: percentage (0-100) of each model's shared-class
            examples that are common to both models.  When below 100, only a
            random half of every non-shared class is kept as well.
        n_epochs: training epochs per model.
        batch_size: mini-batch size of both training loaders.
        eps: maximum perturbation of the PGD attack.
        adv_steps: number of PGD iterations.
        learning_rate: initial AdamW learning rate.
        gpu_num: index of the CUDA device to train on; the CPU is used when the
            index is not available.
        adv_training: accepted for compatibility; adversarial training is not
            implemented and the flag has no effect.
        task: "CIFAR100", "IMAGENET", "FASHIONMNIST" (case-insensitive);
            anything else selects CIFAR10.

    Returns:
        None.  Results are printed.
    """
    import os  # local import: only needed to create the output directory

    print("epochs,batch_size,eps,adv_steps,learning_rate,task")
    print(n_epochs,batch_size,eps,adv_steps,learning_rate,task)

    train_data, test_data = _load_task_datasets(task)

    # Fetch every item once.  Indexing a dataset re-runs its (random)
    # transforms, so repeated lookups are both slow and inconsistent.
    train_samples = _materialize(train_data)
    test_samples = _materialize(test_data)

    # ---- class split -------------------------------------------------------
    # random.sample needs a sequence (sets are rejected since Python 3.11);
    # sorting also makes the split reproducible under a fixed seed.
    all_classes = sorted({sample[1] for sample in train_samples})
    shared_classes = random.sample(all_classes, num_shared_classes)
    shared_lookup = set(shared_classes)
    split_classes = [c for c in all_classes if c not in shared_lookup]  # classes not shared

    # With an odd count, drop one class at random so both models get the same
    # number of exclusive classes.
    if len(split_classes) % 2 == 1:
        split_classes.pop(random.randint(0, len(split_classes) - 1))

    model1_classes = sorted(random.sample(split_classes, len(split_classes) // 2))
    model1_lookup = set(model1_classes)
    model2_classes = sorted(c for c in split_classes if c not in model1_lookup)
    model2_lookup = set(model2_classes)

    # DEBUG:
    print("shared classes: {}".format(shared_classes))
    print("model1 classes: {}".format(model1_classes))
    print("model2 classes: {}".format(model2_classes))

    # ---- training data -----------------------------------------------------
    model1_x_train, model1_y_train = _select_classes(train_samples, model1_lookup)
    model2_x_train, model2_y_train = _select_classes(train_samples, model2_lookup)

    if percent_shared_data < 100:
        model1_x_train, model1_y_train = _halve_per_class(
            model1_x_train, model1_y_train, model1_classes)
        model2_x_train, model2_y_train = _halve_per_class(
            model2_x_train, model2_y_train, model2_classes)

    for shared_class in shared_classes:
        pool = [sample[0] for sample in train_samples if sample[1] == shared_class]
        common, only1, only2 = _partition_shared_class(pool, percent_shared_data)

        # Common examples go in front, exclusive ones at the back.
        model1_x_train = common + model1_x_train + only1
        model1_y_train = [shared_class] * len(common) + model1_y_train + [shared_class] * len(only1)

        model2_x_train = common + model2_x_train + only2
        model2_y_train = [shared_class] * len(common) + model2_y_train + [shared_class] * len(only2)

    # ---- label remapping ---------------------------------------------------
    # Shared classes must carry the same label in both models so that the
    # shared test labels (and the attack built from them) apply to either one.
    model1_y_train, model1_class_mapping = _relabel_first_model(model1_y_train)
    model2_y_train, model2_class_mapping = _relabel_second_model(
        model2_y_train, model1_class_mapping)

    # ---- test data ---------------------------------------------------------
    model1_x_test, model1_y_test = _select_classes(test_samples, model1_lookup)
    model2_x_test, model2_y_test = _select_classes(test_samples, model2_lookup)
    shared_x_test, shared_y_test = _select_classes(test_samples, shared_lookup)

    model1_x_test = model1_x_test + shared_x_test
    model2_x_test = model2_x_test + shared_x_test

    model1_y_test = [model1_class_mapping[label] for label in model1_y_test + shared_y_test]
    model2_y_test = _translate_labels(
        model2_y_test + shared_y_test, model1_class_mapping, model2_class_mapping)
    shared_y_test = _translate_labels(
        shared_y_test, model1_class_mapping, model2_class_mapping)

    # ---- models ------------------------------------------------------------
    # Size each head from the largest label so every label is a valid index.
    model1_classes_len = max(model1_y_train) + 1 if model1_y_train else 0
    model2_classes_len = max(model2_y_train) + 1 if model2_y_train else 0

    model1 = _build_classifier(task, model1_classes_len)
    model2 = _build_classifier(task, model2_classes_len)

    if torch.cuda.is_available() and gpu_num in range(torch.cuda.device_count()):
        device = torch.device('cuda:' + str(gpu_num))
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')
    print("device:", device)

    model1 = model1.to(device)
    model2 = model2.to(device)

    criterion1 = nn.CrossEntropyLoss()
    optimizer1 = optim.AdamW(model1.parameters(), lr=learning_rate)
    scheduler1 = optim.lr_scheduler.MultiStepLR(optimizer1,milestones=[60, 120, 160], gamma=.2) #learning rate decay

    criterion2 = nn.CrossEntropyLoss()
    optimizer2 = optim.AdamW(model2.parameters(), lr=learning_rate)
    scheduler2 = optim.lr_scheduler.MultiStepLR(optimizer2,milestones=[60, 120, 160], gamma=.2) #learning rate decay

    trainloader_1 = torch.utils.data.DataLoader(
        list(zip(model1_x_train, model1_y_train)), batch_size=batch_size,
        shuffle=True, num_workers=2)
    trainloader_2 = torch.utils.data.DataLoader(
        list(zip(model2_x_train, model2_y_train)), batch_size=batch_size,
        shuffle=True, num_workers=2)

    # ---- training ----------------------------------------------------------
    _train_classifier(model1, trainloader_1, optimizer1, scheduler1, criterion1, device, n_epochs)
    print('Finished Training model1')

    _train_classifier(model2, trainloader_2, optimizer2, scheduler2, criterion2, device, n_epochs)
    print('Finished Training model2')

    # The attack and the evaluation run on the CPU.
    model1 = model1.to("cpu")
    model2 = model2.to("cpu")

    shared_y_test = torch.Tensor(shared_y_test).long()
    model1_x_test = torch.stack(model1_x_test)
    model2_x_test = torch.stack(model2_x_test)
    shared_x_test = torch.stack(shared_x_test)

    # ---- adversarial examples crafted against model 1 ----------------------
    model1.eval()

    # NOTE(review): clip_min/clip_max assume inputs in [0, 1], but the CIFAR
    # pipelines normalise their inputs -- confirm the intended range.
    adversary = LinfPGDAttack(
        model1, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=eps,
        nb_iter=adv_steps, eps_iter=0.01, rand_init=True, clip_min=0.0, clip_max=1.0,
        targeted=False)

    adv_untargeted = adversary.perturb(shared_x_test, shared_y_test)

    # ---- persistence -------------------------------------------------------
    timestr = time.strftime("%Y%m%d_%H%M%S")

    print("saving models at", timestr)

    model1_name = './models/{}_{}_{}_model1_{}.pickle'.format(task,num_shared_classes, percent_shared_data,timestr)
    model2_name = './models/{}_{}_{}_model2_{}.pickle'.format(task,num_shared_classes, percent_shared_data,timestr)
    adv_name = './models/{}_{}_{}_adv_{}.pickle'.format(task,num_shared_classes, percent_shared_data,timestr)

    # Make sure a missing output directory cannot throw away a finished run.
    os.makedirs('./models', exist_ok=True)

    torch.save(model1, model1_name)
    torch.save(model2, model2_name)
    torch.save(adversary, adv_name)

    # ---- evaluation --------------------------------------------------------
    with torch.no_grad():
        model1.eval()
        model2.eval()

        output1 = model1(model1_x_test)
        shared_output1 = model1(shared_x_test)
        adv_output1 = model1(adv_untargeted)

        output2 = model2(model2_x_test)
        shared_output2 = model2(shared_x_test)
        adv_output2 = model2(adv_untargeted)

        # Top-5 accuracy is only reported for CIFAR100.
        top5 = task.upper() == "CIFAR100"

        _report_accuracy("model1", output1, model1_y_test,
                         shared_output1, adv_output1, shared_y_test, top5)
        print()
        _report_accuracy("model2", output2, model2_y_test,
                         shared_output2, adv_output2, shared_y_test, top5)
# -*- coding: utf-8 -*-

import torch
from torchvision import models
from torchsummary import summary

# Pick the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Untrained wide_resnet50_2, moved to the selected device.
model = models.wide_resnet50_2()
model = model.to(device)

# Print a per-layer summary for a 3x224x224 input.
summary(model, (3, 224, 224))
Exemplo n.º 18
0
# model_name -> (torchvision constructor name, classifier-head family)
_BACKBONE_SPECS = {
    'vgg16': ('vgg16', 'vgg'),
    'vgg19': ('vgg19', 'vgg'),
    'vgg16_bn': ('vgg16_bn', 'vgg'),
    'vgg19_bn': ('vgg19_bn', 'vgg'),
    'resnet50': ('resnet50', 'resnet'),
    'resnet101': ('resnet101', 'resnet'),
    'resnet152': ('resnet152', 'resnet'),
    'resnext50_32x4d': ('resnext50_32x4d', 'resnet'),
    # Legacy key kept for backward compatibility: it has always loaded the
    # 32x4d network.
    'resnext50_32x8d': ('resnext50_32x4d', 'resnet'),
    'resnext101_32x8d': ('resnext101_32x8d', 'resnet'),
    'wide_resnet50_2': ('wide_resnet50_2', 'resnet'),
    'wide_resnet101_2': ('wide_resnet101_2', 'resnet'),
    'inception_v3': ('inception_v3', 'inception'),
    'densenet121': ('densenet121', 'densenet'),
    'densenet161': ('densenet161', 'densenet'),
    'densenet169': ('densenet169', 'densenet'),
    'densenet201': ('densenet201', 'densenet'),
}


def get_models(config):
    """Build a frozen pretrained backbone with a fresh classification head.

    All pretrained parameters get ``requires_grad = False``; the replacement
    head is created afterwards, so it is the only trainable part.

    Args:
        config: mapping with ``'model_name'`` (a key of ``_BACKBONE_SPECS``)
            and ``'num_used_classes'`` (output size of the new head).

    Returns:
        ``(model, input_size)`` where ``input_size`` is 299 for
        ``inception_v3`` and 224 for every other model.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    model_name = config['model_name']
    spec = _BACKBONE_SPECS.get(model_name)
    if spec is None:
        # Fail with a clear message instead of reaching an unbound `model`.
        raise ValueError(
            "Unsupported model_name {!r}; expected one of: {}".format(
                model_name, ', '.join(sorted(_BACKBONE_SPECS))))
    constructor_name, family = spec

    model = getattr(models, constructor_name)(pretrained=True)
    for p in model.parameters():
        p.requires_grad = False

    num_classes = config['num_used_classes']
    input_size = 224

    # Head input sizes are read from the layer being replaced rather than
    # hard-coded.
    if family == 'vgg':
        model.classifier[6] = nn.Linear(
            in_features=model.classifier[6].in_features,
            out_features=num_classes)
    elif family == 'resnet':
        model.fc = nn.Linear(in_features=model.fc.in_features,
                             out_features=num_classes)
    elif family == 'inception':
        # Replace the auxiliary head as well as the primary one.
        model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features,
                                       num_classes)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        input_size = 299
    else:  # densenet
        model.classifier = nn.Linear(model.classifier.in_features,
                                     num_classes)

    return model, input_size
Exemplo n.º 19
0
# Echo the run configuration (max_iteration, epochs and bs are defined earlier
# in the script, outside this excerpt).
print('Max iteration:', max_iteration)
print('Epochs:', epochs)
print('Batch size:', bs)

# Build the full-size datasets. `Dataloader` here is the project's own dataset
# wrapper, not torch.utils.data.DataLoader (imported below under another name).
train_dataset = Dataloader(x_train, y_train, data_transforms['train'])
valid_dataset = Dataloader(x_val, y_val, data_transforms['val'])
# Build the "mini" datasets: same arrays, but with the mini_data_transforms.
mini_train = Dataloader(x_train, y_train, mini_data_transforms['train'])
mini_valid = Dataloader(x_val, y_val, mini_data_transforms['val'])

# Previous backbone, kept for reference:
# densenet201
# cnn = models.densenet201(pretrained=True)
# cnn.classifier = nn.Linear(1920, 8)
# Current backbone: pretrained wide_resnet50_2 with the final layer replaced
# by an 8-output linear head.
cnn = models.wide_resnet50_2(pretrained=True)
cnn.fc = nn.Linear(2048, 8)
cnn.to(device)


# Imported under an alias so it does not clash with the project's `Dataloader`.
from torch.utils.data import DataLoader as torch_DataLoader
from sampler import ImbalancedDatasetSampler

# Define training and validation data loaders. The training loader draws its
# indices from an ImbalancedDatasetSampler, so shuffle is left off.
train_dataloader = torch_DataLoader(train_dataset, batch_size=bs, sampler=ImbalancedDatasetSampler(train_dataset),
                                    shuffle=False)
valid_dataloader = torch_DataLoader(valid_dataset, batch_size=bs, shuffle=False)

# Define the mini training data loader.
# NOTE(review): the sampler is built from train_dataset rather than mini_train;
# both wrap the same x_train/y_train, so this is presumably equivalent --
# confirm it is intended.
mini_train_dataloader = torch_DataLoader(mini_train, batch_size=bs, sampler=ImbalancedDatasetSampler(train_dataset),
                                    shuffle=False)
    def __init__(self, backbone, num_classes, pretrained=False):
        """Assemble the encoder, decoder, output head and localization net.

        Args:
            backbone: name of the torchvision ResNet variant that supplies the
                encoder stages ('resnet-18', 'resnet-34', 'resnet-50',
                'resnet-101', 'resnet-152', 'resnext50_32x4d',
                'resnext101_32x8d', 'wide_resnet50_2' or 'wide_resnet101_2').
            num_classes: number of output channels of the final 1x1
                convolution.
            pretrained: forwarded to the torchvision constructor.

        Raises:
            NotImplementedError: if ``backbone`` is not one of the names above.
        """
        super().__init__()

        # Channel widths at the stem and after each of the four stages.
        narrow_widths = [64, 64, 128, 256, 512]
        wide_widths = [64, 256, 512, 1024, 2048]
        supported = (
            ('resnet-18', 'resnet18', narrow_widths),
            ('resnet-34', 'resnet34', narrow_widths),
            ('resnet-50', 'resnet50', wide_widths),
            ('resnet-101', 'resnet101', wide_widths),
            ('resnet-152', 'resnet152', wide_widths),
            ('resnext50_32x4d', 'resnext50_32x4d', wide_widths),
            ('resnext101_32x8d', 'resnext101_32x8d', wide_widths),
            ('wide_resnet50_2', 'wide_resnet50_2', wide_widths),
            ('wide_resnet101_2', 'wide_resnet101_2', wide_widths),
        )
        for candidate, constructor_name, widths in supported:
            if backbone == candidate:
                resnet = getattr(models, constructor_name)(pretrained=pretrained)
                filters = widths
                break
        else:
            raise NotImplementedError

        stem, stage1, stage2, stage3, stage4 = filters

        # Encoder: a new stride-1 7x7 stem convolution is created here instead
        # of taking the backbone's conv1; the remaining stem layers and the
        # four stages come from the backbone.
        self.conv = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3,
                              bias=False)
        self.bn = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        # Decoder: decoder1-3 take twice the channel count of their stage
        # (presumably a skip concatenation done in forward() -- TODO confirm);
        # decoder4 takes the deepest stage as-is.
        self.decoder1 = Decoder(in_channels=stage1 * 2,
                                mid_channels=128,
                                out_channels=stem)
        self.decoder2 = Decoder(in_channels=stage2 * 2,
                                mid_channels=256,
                                out_channels=stage1)
        self.decoder3 = Decoder(in_channels=stage3 * 2,
                                mid_channels=512,
                                out_channels=stage2)
        self.decoder4 = Decoder(in_channels=stage4,
                                mid_channels=1024,
                                out_channels=stage3)

        # Output head: two 3x3 conv + BN + ReLU groups, then a 1x1 projection
        # to num_classes channels followed by a sigmoid.
        head_layers = [
            nn.Conv2d(stem * 2, stem, kernel_size=3, padding=1),
            nn.BatchNorm2d(stem),
            nn.ReLU(inplace=True),
            nn.Conv2d(stem, stem, kernel_size=3, padding=1),
            nn.BatchNorm2d(stem),
            nn.ReLU(inplace=True),
            nn.Conv2d(stem, num_classes, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.final = nn.Sequential(*head_layers)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )

        # Regressor for the 3 * 2 affine matrix. The 10 * 52 * 52 input size
        # matches the localization output for a 224x224 image -- assumes the
        # network is fed such inputs, TODO confirm against forward().
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 52 * 52, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2),
        )

        # Start the regressor at the identity transformation: zero weights and
        # a bias equal to the flattened 2x3 identity affine matrix.
        affine_out = self.fc_loc[2]
        identity_affine = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
        affine_out.weight.data.zero_()
        affine_out.bias.data.copy_(identity_affine)

        self._init_weight()
Exemplo n.º 21
0
            stats.append(stat)
        experiments.printStats(stats, metadata)

    except Exception as ex:
        experiments.printException(ex, types)

    '''
    #####################
    types = ('predefModel', 'CIFAR10', 'movingMean')
    modelName = "wide_resnet"
    try:
        stats = []
        rootFolder = sf.Output.getTimeStr() + ''.join(x + "_"
                                                      for x in types) + "set"
        for r in range(loop):
            obj = models.wide_resnet50_2()
            metadata.resetOutput()
            dataMetadata = dc.DefaultData_Metadata(pin_memoryTest=True,
                                                   pin_memoryTrain=True,
                                                   epoch=15,
                                                   fromGrayToRGB=False)

            smoothingMetadata.movingAvgParam = 0.15
            modelMetadata = dc.DefaultModel_Metadata()

            stat = dc.run(numbOfRepetition=1,
                          modelType=types[0],
                          dataType=types[1],
                          smoothingType=types[2],
                          metadataObj=metadata,
                          modelMetadata=modelMetadata,
Exemplo n.º 22
0
def main():
    """Fit and evaluate a per-location Gaussian anomaly detector on one class.

    Steps, in order:
      1. Load a pretrained backbone chosen by ``args.arch`` and hook the
         outputs of its ``layer1``, ``layer2`` and ``layer3``.
      2. Build the per-location mean/covariance of the training embeddings,
         or load them from the pickle cache under ``args.save_path``.
      3. Score each test image with the Mahalanobis distance to that
         distribution, upsample, Gaussian-smooth and min-max normalise.
      4. Print image-level and pixel-level ROC AUC, draw both ROC curves,
         and save the figures (``roc_curve.png`` plus ``plot_fig`` output).

    Raises:
        ValueError: if ``args.arch`` is not a supported backbone.
    """
    args = parse_args()

    # load model
    # t_d is the number of channels available after concatenating the three
    # hooked layers; d is how many of them are randomly kept.
    if args.arch == 'resnet18':
        model = resnet18(pretrained=True, progress=True)
        t_d = 448
        d = 100
    elif args.arch == 'wide_resnet50_2':
        model = wide_resnet50_2(pretrained=True, progress=True)
        t_d = 1792
        d = 550
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on `model`; fail early with a clear message.
        raise ValueError('unsupported arch: %s' % args.arch)
    model.to(device)
    model.eval()

    # Fixed seeds so the random channel subset `idx` is reproducible; the
    # cached training statistics are only valid for the same subset.
    random.seed(1024)
    torch.manual_seed(1024)
    if use_cuda:
        torch.cuda.manual_seed_all(1024)

    idx = torch.tensor(sample(range(0, t_d), d))

    # set model's intermediate outputs
    outputs = []

    def hook(module, input, output):
        # Collect the hooked layer's output for the current forward pass.
        outputs.append(output)

    model.layer1[-1].register_forward_hook(hook)
    model.layer2[-1].register_forward_hook(hook)
    model.layer3[-1].register_forward_hook(hook)

    os.makedirs(os.path.join(args.save_path, 'temp_%s' % args.arch),
                exist_ok=True)
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    fig_img_rocauc = ax[0]
    fig_pixel_rocauc = ax[1]

    total_roc_auc = []
    total_pixel_roc_auc = []

    class_name = args.class_name

    train_dataset = mvtec.MVTecDataset(args.data_path,
                                       class_name=class_name,
                                       is_train=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=32,
                                  pin_memory=True)
    test_dataset = mvtec.MVTecDataset(args.data_path,
                                      class_name=class_name,
                                      is_train=False)
    test_dataloader = DataLoader(test_dataset, batch_size=32, pin_memory=True)

    train_outputs = OrderedDict([('layer1', []), ('layer2', []),
                                 ('layer3', [])])
    test_outputs = OrderedDict([('layer1', []), ('layer2', []),
                                ('layer3', [])])

    # extract train set features
    train_feature_filepath = os.path.join(args.save_path,
                                          'temp_%s' % args.arch,
                                          'train_%s.pkl' % class_name)
    if not os.path.exists(train_feature_filepath):
        for (x, _,
             _) in tqdm(train_dataloader,
                        '| feature extraction | train | %s |' % class_name):
            # model prediction (only the hooked activations are used)
            with torch.no_grad():
                _ = model(x.to(device))
            # get intermediate layer outputs; hooks fire in layer order, so
            # zipping with the ordered keys pairs each output with its layer
            for k, v in zip(train_outputs.keys(), outputs):
                train_outputs[k].append(v.cpu().detach())
            # reset hook outputs for the next batch
            outputs.clear()
        for k, v in train_outputs.items():
            train_outputs[k] = torch.cat(v, 0)

        # Embedding concat
        embedding_vectors = train_outputs['layer1']
        for layer_name in ['layer2', 'layer3']:
            embedding_vectors = embedding_concat(embedding_vectors,
                                                 train_outputs[layer_name])

        # randomly select d dimension
        embedding_vectors = torch.index_select(embedding_vectors, 1, idx)
        # calculate multivariate Gaussian distribution
        B, C, H, W = embedding_vectors.size()
        embedding_vectors = embedding_vectors.view(B, C, H * W)
        mean = torch.mean(embedding_vectors, dim=0).numpy()
        cov = torch.zeros(C, C, H * W).numpy()
        I = np.identity(C)
        for i in range(H * W):
            # 0.01 * I regularises the covariance so it stays invertible.
            cov[:, :, i] = np.cov(embedding_vectors[:, :, i].numpy(),
                                  rowvar=False) + 0.01 * I
        # save learned distribution
        train_outputs = [mean, cov]
        with open(train_feature_filepath, 'wb') as f:
            pickle.dump(train_outputs, f)
    else:
        print('load train set feature from: %s' % train_feature_filepath)
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # args.save_path at a cache directory you trust.
        with open(train_feature_filepath, 'rb') as f:
            train_outputs = pickle.load(f)

    gt_list = []
    gt_mask_list = []
    test_imgs = []

    # extract test set features
    for (x, y,
         mask) in tqdm(test_dataloader,
                       '| feature extraction | test | %s |' % class_name):
        test_imgs.extend(x.cpu().detach().numpy())
        gt_list.extend(y.cpu().detach().numpy())
        gt_mask_list.extend(mask.cpu().detach().numpy())
        # model prediction
        with torch.no_grad():
            _ = model(x.to(device))
        # get intermediate layer outputs
        for k, v in zip(test_outputs.keys(), outputs):
            test_outputs[k].append(v.cpu().detach())
        # reset hook outputs for the next batch
        outputs.clear()
    for k, v in test_outputs.items():
        test_outputs[k] = torch.cat(v, 0)

    # Embedding concat
    embedding_vectors = test_outputs['layer1']
    for layer_name in ['layer2', 'layer3']:
        embedding_vectors = embedding_concat(embedding_vectors,
                                             test_outputs[layer_name])

    # randomly select d dimension
    embedding_vectors = torch.index_select(embedding_vectors, 1, idx)

    # calculate distance matrix
    B, C, H, W = embedding_vectors.size()
    embedding_vectors = embedding_vectors.view(B, C, H * W).numpy()
    dist_list = []
    for i in range(H * W):
        mean = train_outputs[0][:, i]
        conv_inv = np.linalg.inv(train_outputs[1][:, :, i])
        # `vec` (not `sample`) so the imported sample() is not shadowed.
        dist = [
            mahalanobis(vec[:, i], mean, conv_inv)
            for vec in embedding_vectors
        ]
        dist_list.append(dist)

    dist_list = np.array(dist_list).transpose(1, 0).reshape(B, H, W)

    # upsample to the spatial size of the last test batch
    dist_list = torch.tensor(dist_list)
    # squeeze(1) drops only the channel axis added by unsqueeze(1); a bare
    # squeeze() would also drop the batch axis when there is one test image.
    score_map = F.interpolate(dist_list.unsqueeze(1),
                              size=x.size(2),
                              mode='bilinear',
                              align_corners=False).squeeze(1).numpy()

    # apply gaussian smoothing on the score map
    for i in range(score_map.shape[0]):
        score_map[i] = gaussian_filter(score_map[i], sigma=4)

    # Normalization (min-max over the whole test set)
    max_score = score_map.max()
    min_score = score_map.min()
    scores = (score_map - min_score) / (max_score - min_score)

    # calculate image-level ROC AUC score: an image's score is its max pixel
    img_scores = scores.reshape(scores.shape[0], -1).max(axis=1)
    gt_list = np.asarray(gt_list)
    fpr, tpr, _ = roc_curve(gt_list, img_scores)
    img_roc_auc = roc_auc_score(gt_list, img_scores)
    total_roc_auc.append(img_roc_auc)
    print('image ROCAUC: %.3f' % (img_roc_auc))
    fig_img_rocauc.plot(fpr,
                        tpr,
                        label='%s img_ROCAUC: %.3f' %
                        (class_name, img_roc_auc))

    # get optimal threshold: the one maximising pixel-level F1
    gt_mask = np.asarray(gt_mask_list)
    precision, recall, thresholds = precision_recall_curve(
        gt_mask.flatten(), scores.flatten())
    a = 2 * precision * recall
    b = precision + recall
    # np.divide with `where` avoids 0/0 when precision + recall == 0.
    f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
    threshold = thresholds[np.argmax(f1)]

    # calculate per-pixel level ROCAUC
    fpr, tpr, _ = roc_curve(gt_mask.flatten(), scores.flatten())
    per_pixel_rocauc = roc_auc_score(gt_mask.flatten(), scores.flatten())
    total_pixel_roc_auc.append(per_pixel_rocauc)
    print('pixel ROCAUC: %.3f' % (per_pixel_rocauc))

    fig_pixel_rocauc.plot(fpr,
                          tpr,
                          label='%s ROCAUC: %.3f' %
                          (class_name, per_pixel_rocauc))
    save_dir = args.save_path + '/' + f'pictures_{args.arch}'
    os.makedirs(save_dir, exist_ok=True)
    plot_fig(test_imgs, scores, gt_mask_list, threshold, save_dir, class_name)

    print('Average ROCAUC: %.3f' % np.mean(total_roc_auc))
    fig_img_rocauc.title.set_text('Average image ROCAUC: %.3f' %
                                  np.mean(total_roc_auc))
    fig_img_rocauc.legend(loc="lower right")

    print('Average pixel ROCUAC: %.3f' % np.mean(total_pixel_roc_auc))
    fig_pixel_rocauc.title.set_text('Average pixel ROCAUC: %.3f' %
                                    np.mean(total_pixel_roc_auc))
    fig_pixel_rocauc.legend(loc="lower right")

    fig.tight_layout()
    fig.savefig(os.path.join(args.save_path, 'roc_curve.png'), dpi=100)
Exemplo n.º 23
0
    def __init__(self,
                 input_channels,
                 num_classes,
                 pretrained=True,
                 skip=True,
                 hidden_classes=None):
        """Assemble an FCN head on top of a Wide-ResNet-50-2 backbone.

        Args:
            input_channels: Channels of the input image. Only consulted when
                ``pretrained`` is falsy; otherwise the backbone's own first
                convolution is reused as-is.
            num_classes: Total number of classes.
            pretrained: Forwarded to ``models.wide_resnet50_2``; also selects
                which parts get passed to ``initialize_weights``.
            skip: When truthy the classifier takes ``2048 + 256`` input
                channels instead of ``2048`` (presumably a concatenated
                skip connection -- confirm against ``forward``).
            hidden_classes: Optional collection; when given, the final layer
                predicts ``num_classes - len(hidden_classes)`` channels.
        """
        super(FCNWideResNet50, self).__init__()

        self.skip = skip

        # ResNet-50 with Skip Connections (adapted from FCN-8s).
        backbone = models.wide_resnet50_2(pretrained=pretrained,
                                          progress=False)

        # Stem: keep the backbone's first conv when pretrained, otherwise
        # swap in a fresh 3x3 conv that accepts `input_channels`.
        if pretrained:
            stem_conv = backbone.conv1
        else:
            stem_conv = nn.Conv2d(input_channels,
                                  64,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1,
                                  bias=False)
        self.init = nn.Sequential(stem_conv, backbone.bn1, backbone.relu,
                                  backbone.maxpool)

        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # The head differs only in its input width (skip or not) and its
        # output width (hidden classes removed or not).
        head_in = 2048 + 256 if self.skip else 2048
        if hidden_classes is None:
            head_out = num_classes
        else:
            head_out = num_classes - len(hidden_classes)

        self.classifier1 = nn.Sequential(
            nn.Conv2d(head_in, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(0.5),
        )
        self.final = nn.Conv2d(64, head_out, kernel_size=3, padding=1)

        # With pretrained weights only the newly created head is passed to
        # initialize_weights; otherwise the whole network is.
        if pretrained:
            initialize_weights(self.classifier1)
            initialize_weights(self.final)
        else:
            initialize_weights(self)
Exemplo n.º 24
0
def initialize_model(model_name,
                     num_classes,
                     feature_extract,
                     use_pretrained=True):
    """Build a torchvision model and replace its head for ``num_classes``.

    Args:
        model_name: Short architecture key (e.g. ``"resnet"``, ``"vgg"``,
            ``"inception"``); expected to be a string.
        num_classes: Output size of the replacement classification layer.
        feature_extract: Forwarded to ``set_parameter_requires_grad`` before
            the new head is attached.
        use_pretrained: Forwarded as ``pretrained=`` to the constructor.

    Returns:
        The configured model, or ``None`` (after printing a message) when
        ``model_name`` is not recognised.
    """
    # Architectures whose classifier is the `fc` attribute.
    fc_archs = {
        "resnet": "resnet18",
        "wide_resnet50": "wide_resnet50_2",
        "wide_resnet101": "wide_resnet101_2",
        "resnext50": "resnext50_32x4d",
        "resnext101": "resnext101_32x8d",
    }
    # Architectures whose last linear layer sits at `classifier[6]`.
    classifier6_archs = {
        "alexnet": "alexnet",
        "vgg": "vgg11_bn",
    }

    def build(constructor_name):
        # Look the constructor up lazily so only the requested one is touched.
        net = getattr(models, constructor_name)(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        return net

    if model_name in fc_archs:
        model_ft = build(fc_archs[model_name])
        model_ft.fc = nn.Linear(model_ft.fc.in_features, num_classes)
        return model_ft

    if model_name in classifier6_archs:
        model_ft = build(classifier6_archs[model_name])
        model_ft.classifier[6] = nn.Linear(model_ft.classifier[6].in_features,
                                           num_classes)
        return model_ft

    if model_name == "squeezenet":
        # Squeezenet classifies with a 1x1 convolution instead of a Linear.
        model_ft = build("squeezenet1_0")
        model_ft.classifier[1] = nn.Conv2d(512,
                                           num_classes,
                                           kernel_size=(1, 1),
                                           stride=(1, 1))
        model_ft.num_classes = num_classes
        return model_ft

    if model_name == "densenet":
        model_ft = build("densenet121")
        model_ft.classifier = nn.Linear(model_ft.classifier.in_features,
                                        num_classes)
        return model_ft

    if model_name == "inception":
        # Inception v3: expects (299, 299) images and has an auxiliary
        # output, so both the auxiliary and the primary heads are replaced.
        model_ft = build("inception_v3")
        model_ft.AuxLogits.fc = nn.Linear(model_ft.AuxLogits.fc.in_features,
                                          num_classes)
        model_ft.fc = nn.Linear(model_ft.fc.in_features, num_classes)
        return model_ft

    # Unknown name: report it and hand back None (the function does not exit).
    print("Invalid model name, exiting...")
    return None
Exemplo n.º 25
0
    def __init__(self,
                 num_classes,
                 trunk='resnet-101',
                 criterion=None,
                 criterion_aux=None,
                 variant='D',
                 skip='m1',
                 skip_num=48,
                 args=None):
        """Build the backbone, ASPP, decoder and optional HANet blocks.

        Args:
            num_classes: Number of output channels of the final classifier.
            trunk: Backbone key: ``'shufflenetv2'``, ``'mnasnet_05'``,
                ``'mnasnet_10'``, ``'mobilenetv2'``, or one of the
                ResNet-family keys handled in the last branch.
            criterion: Stored on the instance as ``self.criterion``.
            criterion_aux: Stored on the instance as ``self.criterion_aux``.
            variant: Dilation variant (``'D'``, ``'D4'``, ``'D16'``; anything
                else means no dilation and an output stride of 32).
            skip: Not used in this constructor.
            skip_num: Not used in this constructor.
            args: Option namespace. Despite the ``None`` default it is
                required: ``hanet``, ``hanet_set``, ``hanet_pos``,
                ``pos_rfactor``, ``pooling``, ``dropout``, ``pos_noise`` and
                ``aux_loss`` are read from it.

        Raises:
            ValueError: if ``trunk`` is not a recognised architecture.
        """
        super(DeepV3PlusHANet, self).__init__()
        self.criterion = criterion
        self.criterion_aux = criterion_aux
        self.variant = variant
        self.args = args
        self.trunk = trunk
        # Count how many of the five HANet slots are enabled.
        self.num_attention_layer = sum(1 for slot in range(5)
                                       if args.hanet[slot] > 0)

        print("#### HANet layers", self.num_attention_layer)

        def dilate_strided_convs(stage, rate):
            # Turn every stride-2 conv in `stage` into a dilated stride-1 conv.
            for _, mod in stage.named_modules():
                if isinstance(mod, nn.Conv2d) and mod.stride == (2, 2):
                    mod.dilation, mod.padding, mod.stride = \
                        (rate, rate), (rate, rate), (1, 1)

        def dilate_residual_stage(stage, rate):
            # Dilate modules named conv2 and un-stride the downsample conv.
            for mod_name, mod in stage.named_modules():
                if 'conv2' in mod_name:
                    mod.dilation, mod.padding, mod.stride = \
                        (rate, rate), (rate, rate), (1, 1)
                elif 'downsample.0' in mod_name:
                    mod.stride = (1, 1)

        # ------------------------------------------------------------------
        # Backbone selection. Each branch fixes the three channel counts the
        # decoder needs: channel_3rd (bot_fine input), prev_final_channel
        # (aux head / hanet0 input) and final_channel (ASPP input).
        # ------------------------------------------------------------------
        residual_trunk = False
        if trunk == 'shufflenetv2':
            channel_3rd, prev_final_channel, final_channel = 116, 464, 1024
            resnet = models.shufflenet_v2_x1_0(pretrained=True)
            self.layer0 = nn.Sequential(resnet.conv1, resnet.maxpool)
            self.layer1 = resnet.stage2
            self.layer2 = resnet.stage3
            self.layer3 = resnet.stage4
            self.layer4 = resnet.conv5
        elif trunk == 'mnasnet_05' or trunk == 'mnasnet_10':
            if trunk == 'mnasnet_05':
                resnet = models.mnasnet0_5(pretrained=True)
                channel_3rd, prev_final_channel = 24, 160
            else:
                resnet = models.mnasnet1_0(pretrained=True)
                channel_3rd, prev_final_channel = 40, 320
            final_channel = 1280

            print("# of layers", len(resnet.layers))
            blocks = resnet.layers
            self.layer0 = nn.Sequential(*[blocks[i] for i in range(0, 8)])
            self.layer1 = nn.Sequential(blocks[8], blocks[9])
            self.layer2 = nn.Sequential(blocks[10], blocks[11])
            self.layer3 = nn.Sequential(blocks[12], blocks[13])
            self.layer4 = nn.Sequential(blocks[14], blocks[15], blocks[16])
        elif trunk == 'mobilenetv2':
            channel_3rd, prev_final_channel, final_channel = 32, 320, 1280
            resnet = models.mobilenet_v2(pretrained=True)
            feats = resnet.features
            self.layer0 = nn.Sequential(feats[0], feats[1])
            self.layer1 = nn.Sequential(*[feats[i] for i in range(2, 7)])
            self.layer2 = nn.Sequential(*[feats[i] for i in range(7, 11)])
            self.layer3 = nn.Sequential(*[feats[i] for i in range(11, 18)])
            self.layer4 = nn.Sequential(feats[18])
        else:
            residual_trunk = True
            channel_3rd, prev_final_channel, final_channel = 256, 1024, 2048

            # torchvision constructors for the keys not served by `Resnet`.
            torchvision_trunks = {
                'resnext-50': 'resnext50_32x4d',
                'resnext-101': 'resnext101_32x8d',
                'wide_resnet-50': 'wide_resnet50_2',
                'wide_resnet-101': 'wide_resnet101_2',
            }
            if trunk == 'resnet-101':
                # This variant's stem has three conv/bn/relu triples.
                resnet = Resnet.resnet101()
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                              resnet.relu1, resnet.conv2,
                                              resnet.bn2, resnet.relu2,
                                              resnet.conv3, resnet.bn3,
                                              resnet.relu3, resnet.maxpool)
            else:
                if trunk == 'resnet-50':
                    resnet = Resnet.resnet50()
                elif trunk == 'resnet-152':
                    resnet = Resnet.resnet152()
                elif trunk in torchvision_trunks:
                    resnet = getattr(
                        models, torchvision_trunks[trunk])(pretrained=True)
                else:
                    raise ValueError("Not a valid network arch")
                resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                              resnet.relu, resnet.maxpool)

            self.layer0 = resnet.layer0
            self.layer1 = resnet.layer1
            self.layer2 = resnet.layer2
            self.layer3 = resnet.layer3
            self.layer4 = resnet.layer4

        # ------------------------------------------------------------------
        # Dilation. Residual trunks are patched by module name and support
        # 'D4'; the lightweight trunks are patched by stride and do not.
        # ------------------------------------------------------------------
        if residual_trunk:
            if self.variant == 'D':
                dilate_residual_stage(self.layer3, 2)
                dilate_residual_stage(self.layer4, 4)
            elif self.variant == 'D4':
                dilate_residual_stage(self.layer2, 2)
                dilate_residual_stage(self.layer3, 4)
                dilate_residual_stage(self.layer4, 8)
            elif self.variant == 'D16':
                dilate_residual_stage(self.layer4, 2)
            else:
                print("Not using Dilation ")
        else:
            if self.variant == 'D':
                dilate_strided_convs(self.layer2, 2)
                dilate_strided_convs(self.layer3, 4)
            elif self.variant == 'D16':
                dilate_strided_convs(self.layer3, 2)
            else:
                print("Not using Dilation ")

        # Output stride handed to ASPP, keyed by variant (default 32).
        output_stride = {'D': 8, 'D4': 4, 'D16': 16}.get(self.variant, 32)

        self.aspp = _AtrousSpatialPyramidPoolingModule(
            final_channel, 256, output_stride=output_stride)

        self.bot_fine = nn.Sequential(
            nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False),
            Norm2d(48),
            nn.ReLU(inplace=True))

        self.bot_aspp = nn.Sequential(
            nn.Conv2d(1280, 256, kernel_size=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True))

        self.final1 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True))

        self.final2 = nn.Sequential(
            nn.Conv2d(256, num_classes, kernel_size=1, bias=True))

        # Optional auxiliary head, built only when aux_loss is exactly True.
        if self.args.aux_loss is True:
            self.dsn = nn.Sequential(
                nn.Conv2d(prev_final_channel,
                          512,
                          kernel_size=3,
                          stride=1,
                          padding=1),
                Norm2d(512),
                nn.ReLU(inplace=True),
                nn.Dropout2d(0.1),
                nn.Conv2d(512,
                          num_classes,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=True))
            initialize_weights(self.dsn)

        # (in_channels, out_channels) for hanet0 .. hanet4. The last block
        # always uses 'max' pooling; the others follow args.pooling.
        hanet_channels = (
            (prev_final_channel, final_channel),
            (final_channel, 1280),
            (1280, 256),
            (304, 256),
            (256, num_classes),
        )
        for slot, (in_ch, out_ch) in enumerate(hanet_channels):
            if self.args.hanet[slot] == 1:
                attention = HANet_Conv(
                    in_ch,
                    out_ch,
                    self.args.hanet_set[0],
                    self.args.hanet_set[1],
                    self.args.hanet_set[2],
                    self.args.hanet_pos[0],
                    self.args.hanet_pos[1],
                    pos_rfactor=self.args.pos_rfactor,
                    pooling=('max' if slot == 4 else self.args.pooling),
                    dropout_prob=self.args.dropout,
                    pos_noise=self.args.pos_noise)
                setattr(self, 'hanet%d' % slot, attention)
                initialize_weights(attention)

        # Initialise the freshly created decoder modules.
        for fresh in (self.aspp, self.bot_aspp, self.bot_fine, self.final1,
                      self.final2):
            initialize_weights(fresh)
 def __init__(self):
     """Wrap a pretrained wide_resnet50_2 with a 10-way linear head."""
     super(WideResNet50V2E, self).__init__()
     backbone = models.wide_resnet50_2(pretrained=True)
     # Swap the backbone's classifier for a 2048 -> 10 linear layer.
     backbone.fc = nn.Linear(2048, 10)
     self.model = backbone
Exemplo n.º 27
0
def create_model(model_type, num_classes, feature_extract, pretrained):
    """
    Creates a torchvision model whose classification head has ``num_classes`` outputs.

    The architecture is built with ``models.<model_type>(pretrained=pretrained)``,
    passed through ``set_parameter_requires_grad(model, feature_extract)`` and only
    then given a freshly constructed head, so the new head is created after that
    call has run.

    :param model_type: Architecture name (e.g. ``'resnet18'``, ``'vgg16_bn'``,
        ``'densenet121'``, ``'mobilenet_v2'``). Any name that is not one of the
        architectures listed in this function falls back to ``inception_v3``.
    :param num_classes: Number of classes.
    :param feature_extract: A boolean indicating if we are extracting features;
        forwarded unchanged to ``set_parameter_requires_grad``.
    :param pretrained: A boolean indicating if pretrained weights should be used.
    :return: Model, moved to the device returned by ``get_device()``.
    """
    # Architectures grouped by where their final classification layer lives.
    # ResNeXt models are torchvision ``ResNet`` instances, so their head is
    # ``fc`` (they have no ``classifier`` attribute).
    fc_head = (
        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
        'googlenet',
        'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
        'resnext50_32x4d', 'resnext101_32x8d',
        'wide_resnet50_2', 'wide_resnet101_2',
    )
    # Head is the Linear layer at ``classifier[6]``.
    classifier_6_head = (
        'alexnet',
        'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
        'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
    )
    # Head is a 1x1 convolution at ``classifier[1]``.
    conv_head = ('squeezenet1_0', 'squeezenet1_1')
    # Head is a single Linear layer stored directly as ``classifier``.
    classifier_head = (
        'densenet121', 'densenet161', 'densenet169', 'densenet201',
    )
    # Head is the Linear layer at ``classifier[1]``.
    classifier_1_head = ('mobilenet_v2', 'mnasnet0_5', 'mnasnet1_0')

    known = (fc_head + classifier_6_head + conv_head + classifier_head
             + classifier_1_head)

    device = get_device()

    # Unknown names deliberately fall back to Inception v3.
    arch = model_type if model_type in known else 'inception_v3'
    model = getattr(models, arch)(pretrained=pretrained)
    set_parameter_requires_grad(model, feature_extract)

    if arch in fc_head:
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif arch in classifier_6_head:
        model.classifier[6] = nn.Linear(model.classifier[6].in_features,
                                        num_classes)
    elif arch in conv_head:
        model.classifier[1] = nn.Conv2d(512,
                                        num_classes,
                                        kernel_size=(1, 1),
                                        stride=(1, 1))
        model.num_classes = num_classes
    elif arch in classifier_head:
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif arch in classifier_1_head:
        model.classifier[1] = nn.Linear(model.classifier[1].in_features,
                                        num_classes)
    else:
        # Inception v3 has an auxiliary classifier in addition to the main one.
        model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features,
                                       num_classes)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)
Exemplo n.º 28
0
# Model types that map directly onto a torchvision constructor:
# model_type -> (constructor name in ``models``, value passed as ``pretrained``).
_TORCHVISION_BACKBONES = {
    "resnext101": ("resnext101_32x8d", True),
    "st_resnet": ("resnet50", True),
    "resnet101": ("resnet101", True),
    "wide_resnet101": ("wide_resnet101_2", True),
    "wide_resnet50": ("wide_resnet50_2", True),
    "untrained_resnet50": ("resnet50", False),
    "untrained_resnet101": ("resnet101", False),
    "untrained_wrn50": ("wide_resnet50_2", False),
    "untrained_wrn101": ("wide_resnet101_2", False),
    "st_alexnet": ("alexnet", True),
}

# Checkpoints whose backbone weights are stored under 'module.encoder_q.'.
_MOCO_CHECKPOINTS = {
    "moco":
    '/content/gdrive/MyDrive/model_checkpoints/moco_v1_200ep_pretrain.pth.tar',
    "mocov2":
    '/content/gdrive/MyDrive/moco/moco_v2_200ep_pretrain.pth.tar',
}

# Checkpoints loaded directly with torch.load whose weights are stored under
# 'module.attacker.model.' inside the 'model' entry.
_ATTACKER_CHECKPOINTS = {
    'linf_8': '/content/gdrive/MyDrive/model_checkpoints/imagenet_linf_8.pt',
    'linf_4': '/content/gdrive/MyDrive/model_checkpoints/imagenet_linf_4.pt',
    'l2_3': '/content/gdrive/MyDrive/model_checkpoints/imagenet_l2_3_0.pt',
}

# Checkpoints restored through ``make_and_restore_model``; the file name is
# derived from the model type.
_L2_EPS_MODEL_TYPES = (
    'resnet50_l2_eps1',
    'resnet50_l2_eps0.01',
    'resnet50_l2_eps0.03',
    'resnet50_l2_eps0.5',
    'resnet50_l2_eps0.25',
    'resnet50_l2_eps3',
    'resnet50_l2_eps5',
)


def _strip_prefix(state_dict, prefix, skip_prefix=None):
    """Return a new dict with the entries under ``prefix``, with ``prefix`` removed.

    Keys that do not start with ``prefix`` are dropped, as are keys starting
    with ``skip_prefix`` when one is given. The input dict is not modified.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
        and not (skip_prefix is not None and key.startswith(skip_prefix))
    }


def _resnet50_without_fc(path, prefix, skip_prefix):
    """Load a ResNet-50 backbone from ``path``; the fc layer must stay unloaded."""
    state_dict = torch.load(path, map_location='cpu')['state_dict']
    resnet = models.resnet50(pretrained=False)
    msg = resnet.load_state_dict(_strip_prefix(state_dict, prefix,
                                               skip_prefix),
                                 strict=False)
    # Everything except the classification layer has to be present.
    assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
    return resnet


def _resnet50_from_attacker_state(state_dict):
    """Load a ResNet-50 from a state dict keyed under 'module.attacker.model.'."""
    resnet = models.resnet50(pretrained=False)
    resnet.load_state_dict(_strip_prefix(state_dict,
                                         'module.attacker.model.'))
    return resnet


def load_model(model_type):
    """Build the network identified by ``model_type`` and load its weights.

    Supported names: the keys of ``_TORCHVISION_BACKBONES`` (plain torchvision
    constructors), ``"clip"``, ``"simclr"``, ``"simclr_v2_0"``, ``"moco"``,
    ``"mocov2"``, ``"InsDis"``, ``"place365_rn50"``, ``"wsl_resnext101"``,
    ``'linf_8'``, ``'linf_4'``, ``'l2_3'`` and the ``'resnet50_l2_eps*'``
    names in ``_L2_EPS_MODEL_TYPES``.

    Checkpoints are read from hard-coded paths under ``/content/gdrive`` and
    are always mapped to the CPU: the weights are copied into a freshly
    constructed CPU model anyway, so no GPU is needed to load them. No
    preprocessing transform is returned; callers build their own.

    :param model_type: Name of the model to load.
    :return: The model, or ``None`` when ``model_type`` is not recognised.
    """
    if model_type in _TORCHVISION_BACKBONES:
        arch, pretrained = _TORCHVISION_BACKBONES[model_type]
        return getattr(models, arch)(pretrained=pretrained)

    if model_type == "clip":
        import clip
        # clip.load also returns its own preprocessing transform; only the
        # model is handed back to the caller.
        model, _ = clip.load("RN50")
        return model

    if model_type == "simclr":
        checkpoint = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/resnet50-1x.pth',
            map_location='cpu')
        resnet = models.resnet50(pretrained=False)
        resnet.load_state_dict(checkpoint['state_dict'])
        return resnet

    if model_type == "simclr_v2_0":
        checkpoint = torch.load('/content/gdrive/MyDrive/r50_1x_sk0.pth',
                                map_location='cpu')
        resnet = models.resnet50(pretrained=False)
        resnet.load_state_dict(checkpoint['resnet'])
        return resnet

    if model_type in _MOCO_CHECKPOINTS:
        return _resnet50_without_fc(_MOCO_CHECKPOINTS[model_type],
                                    'module.encoder_q.',
                                    'module.encoder_q.fc')

    if model_type == "InsDis":
        return _resnet50_without_fc(
            '/content/gdrive/MyDrive/model_checkpoints/lemniscate_resnet50_update.pth',
            'module.', 'module.fc')

    if model_type == "place365_rn50":
        resnet = models.resnet50(pretrained=False)
        state_dict = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/resnet50_places365.pth.tar',
            map_location='cpu')['state_dict']
        # NOTE(review): keys are loaded as-is with strict=False, so any key
        # that does not match (e.g. one carrying a 'module.' prefix) is
        # silently ignored -- confirm the checkpoint keys match the model.
        resnet.load_state_dict(state_dict, strict=False)
        return resnet

    if model_type == "wsl_resnext101":
        resnet = models.resnext101_32x8d(pretrained=False)
        checkpoint = torch.load(
            "/content/gdrive/MyDrive/model_checkpoints/ig_resnext101_32x8-c38310e5.pth",
            map_location='cpu')
        resnet.load_state_dict(checkpoint)
        return resnet

    if model_type in _ATTACKER_CHECKPOINTS:
        checkpoint = torch.load(_ATTACKER_CHECKPOINTS[model_type],
                                map_location='cpu')
        return _resnet50_from_attacker_state(checkpoint['model'])

    if model_type in _L2_EPS_MODEL_TYPES:
        ds = ImageNet('/tmp')
        # Only the checkpoint returned by make_and_restore_model is used; the
        # wrapped model it returns is discarded.
        _, checkpoint = make_and_restore_model(
            arch='resnet50',
            dataset=ds,
            resume_path=
            f'/content/gdrive/MyDrive/model_checkpoints/{model_type}.ckpt')
        return _resnet50_from_attacker_state(checkpoint['model'])

    # Unrecognised model types yield None.
    return None
Exemplo n.º 29
0
def training(rank, world_size, backend, config):
    """Train a Wide-ResNet-50-2 for one epoch inside one distributed process.

    Joins the process group at ``tcp://0.0.0.0:2233``, binds this process to
    CUDA device ``rank``, trains a ``DistributedDataParallel`` model on an
    ``RndDataset`` with SGD, then tears the process group down.

    :param rank: Rank of this process; also used as its CUDA device index.
    :param world_size: Total number of participating processes.
    :param backend: Backend name passed to ``dist.init_process_group``.
    :param config: Mapping read for the keys ``"nb_samples"``,
        ``"batch_size"`` (global batch size, divided by ``world_size``) and
        ``"log_interval"``.
    """
    # Specific torch.distributed
    dist.init_process_group(backend,
                            init_method="tcp://0.0.0.0:2233",
                            world_size=world_size,
                            rank=rank)
    print(dist.get_rank(), ": run with config:", config, " - backend=",
          backend)

    torch.cuda.set_device(rank)

    # Data preparation
    dataset = RndDataset(nb_samples=config["nb_samples"])

    # Specific torch.distributed: each process sees its own shard of the data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(config["batch_size"] / world_size),
        num_workers=1,
        sampler=train_sampler,
    )

    # Model, criterion, optimizer setup
    model = wide_resnet50_2(num_classes=100).cuda()
    criterion = NLLLoss()
    optimizer = SGD(model.parameters(), lr=0.01)

    # Specific torch.distributed
    model = DDP(model, device_ids=[rank])

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(epoch, batch_idx, data, target):
        """Run one optimisation step and return the loss tensor."""
        data = data.cuda()
        target = target.cuda()

        optimizer.zero_grad()
        output = model(data)
        # NLLLoss expects log-probabilities, normalised over the class
        # dimension (dim=1), not plain probabilities over the batch dimension.
        log_probabilities = torch.nn.functional.log_softmax(output, dim=1)

        loss_val = criterion(log_probabilities, target)
        loss_val.backward()
        optimizer.step()

        if (batch_idx + 1) % log_interval == 0:
            print("Process {}/{} Train Epoch: {} [{}/{}]\tLoss: {}".format(
                dist.get_rank(),
                dist.get_world_size(),
                epoch,
                (batch_idx + 1) * len(data),
                len(train_sampler),
                loss_val.item(),
            ))
        return loss_val

    # Running _train_step for n_epochs
    n_epochs = 1
    for epoch in range(n_epochs):
        # Tell the sampler which epoch this is so its shuffling is derived
        # from the epoch number.
        train_sampler.set_epoch(epoch)
        for batch_idx, (data, target) in enumerate(train_loader):
            _train_step(epoch, batch_idx, data, target)

    # Specific torch.distributed
    dist.destroy_process_group()
Exemplo n.º 30
0
def main():
    """Fine-tune a pretrained Wide-ResNet-50-2 with 4 outputs on OpenSurfacesSmall.

    All settings are read from the module-level ``config`` object
    (``no_cuda``, ``gpu_id``, ``crop_size``, ``batch_size``,
    ``test_batch_size``, ``num_workers``, ``optimizer``, ``epochs``). The
    model is watched by Weights & Biases, and after the last epoch its state
    dict is written to disk and uploaded with ``wandb.save``.

    :raises ValueError: if ``config.optimizer`` is neither ``'Adam'`` nor
        ``'SGD'``.
    """
    use_cuda = not config.no_cuda and torch.cuda.is_available()
    if use_cuda:
        # Selecting a CUDA device only makes sense when CUDA is actually used.
        torch.cuda.set_device(config.gpu_id)
    device = torch.device("cuda" if use_cuda else "cpu")
    data_root = "../datasets/sample_" + str(config.crop_size)
    # Number of images passed to the dataset as ``n`` for each crop size.
    crop_size_to_image_number = {128: 5530, 224: 1184}
    image_count = crop_size_to_image_number[config.crop_size]

    # Normalisation shared by the training and evaluation pipelines.
    normalize = transforms.Normalize(
        [0.48485812, 0.40377784, 0.32280155],
        [0.37216536, 0.349832, 0.37452201])
    train_trans = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        normalize,
    ])
    eval_trans = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        normalize,
    ])

    # First 80% of the data for training, the remaining 20% for validation.
    train_set_os = OpenSurfacesSmall(root_dir=data_root,
                                     n=image_count,
                                     split=(0, 0.8),
                                     transform=train_trans)
    val_set_os = OpenSurfacesSmall(root_dir=data_root,
                                   n=image_count,
                                   split=(0.8, 1),
                                   transform=eval_trans)

    train_loader = torch.utils.data.DataLoader(dataset=train_set_os,
                                               batch_size=config.batch_size,
                                               num_workers=config.num_workers,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=val_set_os,
        batch_size=config.test_batch_size,
        num_workers=config.num_workers,
        shuffle=False)

    # Pretrained backbone with its final layer replaced by a 4-way classifier.
    model = models.wide_resnet50_2(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 4)
    model.to(device)

    # Set optimizer; only SGD is paired with a plateau-based LR scheduler.
    scheduler = None
    if config.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=0.001)
    elif config.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=0.1)
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   'min',
                                                   verbose=True)
    else:
        raise ValueError(
            "Unsupported optimizer: {!r} (expected 'Adam' or 'SGD')".format(
                config.optimizer))

    criterion = nn.CrossEntropyLoss()

    # WandB: log gradients and, with log="all", parameter histograms as well.
    wandb.watch(model, log="all")

    for epoch in range(1, config.epochs + 1):
        train(config, model, device, train_loader, optimizer, epoch, criterion)
        test(config, model, device, test_loader, criterion)
        if scheduler is not None:
            # NOTE(review): test_loss is not assigned anywhere in this
            # function; it must be provided at module scope (presumably by
            # test()) -- confirm, otherwise this line raises NameError.
            scheduler.step(test_loss)

    # WandB: save the model checkpoint and associate it with the current run.
    checkpoint_name = ("model_" + str(config.crop_size) + "_" +
                       config.optimizer + "_new " + ".h5")
    torch.save(model.state_dict(), checkpoint_name)
    wandb.save(checkpoint_name)