Exemplo n.º 1
0
def get_model():
    """Build the network selected by the global ``args.model``.

    Only ``'vgg16_bn'`` and ``'vgg16_bn_x'`` are recognised; any other value
    raises ``NotImplementedError``.  When the global ``use_cuda`` is truthy
    the result of ``.cuda()`` on the network is returned instead of the
    network itself.
    """
    print('=> Building model..')
    if args.model == 'vgg16_bn':
        builder = vgg16_bn
    elif args.model == 'vgg16_bn_x':
        builder = vgg16_bn_x
    else:
        raise NotImplementedError

    model = builder()
    if use_cuda:
        return model.cuda()
    return model
Exemplo n.º 2
0
def get_config():
    """Return ``(net, ratios, policy)`` for the model named by ``args.model``.

    ``'vgg16_bn'`` is paired with the globals ``vgg16_ratios`` and
    ``vgg16_pruning_policy``; ``'vgg11_bn'`` is paired with the literal
    ratios ``{20: 0.5}`` and the same ``vgg16_pruning_policy``.  Any other
    name prints a message and raises ``NotImplementedError``.
    """
    print('=> Building model..')

    if args.model == 'vgg16_bn':
        from models.vgg import vgg16_bn
        return vgg16_bn(), vgg16_ratios, vgg16_pruning_policy

    if args.model == 'vgg11_bn':
        from models.vgg import vgg11_bn
        return vgg11_bn(), {20: 0.5}, vgg16_pruning_policy

    print("Not support model {}".format(args.model))
    raise NotImplementedError
Exemplo n.º 3
0
def get_model(_model_name, _num_classes):
    """Build the classifier named ``_model_name``.

    Args:
        _model_name: One of ``'resnet34'``, ``'alexnet'``, ``'densenet121'``,
            ``'vgg16_bn'`` or ``'shufflenetv2_x1_0'``.
        _num_classes: Forwarded to the constructor as ``num_classes``.

    Returns:
        Whatever the matching constructor returns when called with
        ``pretrained=settings.isPretrain`` and ``num_classes=_num_classes``.

    Raises:
        SystemExit: With status -1, after logging an error, when
            ``_model_name`` is not one of the supported names.
    """
    import sys

    if _model_name == 'resnet34':
        factory = resnet34
    elif _model_name == 'alexnet':
        factory = alexnet
    elif _model_name == 'densenet121':
        factory = densenet121
    elif _model_name == 'vgg16_bn':
        factory = vgg16_bn
    elif _model_name == 'shufflenetv2_x1_0':
        factory = shufflenetv2_x1_0
    else:
        log.logger.error("model_name error!")
        # sys.exit is always available; the bare exit() helper is only
        # installed by the site module.
        sys.exit(-1)

    # Use the caller-supplied class count rather than any module-level
    # ``num_classes`` that happens to exist.
    return factory(pretrained=settings.isPretrain, num_classes=_num_classes)
Exemplo n.º 4
0
def get_model(cfg, pretrained=False, load_param_from_ours=False):
    """Create the network described by ``cfg`` and wrap it in ``DataParallel``.

    Args:
        cfg: Configuration object.  ``num_classes`` and ``model`` are always
            read; ``patch_size`` is read for ``'custom'``, ``model_info`` for
            ``'vgg'``/``'resnet'``, ``init_model_file`` when loading our own
            checkpoint, and ``gpu_id`` for the ``DataParallel`` wrapper.
        pretrained: Forwarded to the googlenet/vgg/resnet constructors.
            Forced to ``False`` when ``load_param_from_ours`` is set.
        load_param_from_ours: When truthy, load ``checkpoint['model_param']``
            from ``cfg.init_model_file`` into the model.

    Returns:
        ``torch.nn.DataParallel`` wrapping the model after ``model.cuda()``.

    Raises:
        SystemExit: With status -1 when no model could be built for ``cfg``.
    """
    # Our own checkpoint replaces any pretrained weights.
    use_pretrained = False if load_param_from_ours else pretrained

    n_classes = cfg.num_classes
    model = None
    constructor = None  # set for every family that takes pretrained/num_classes

    if cfg.model == 'custom':
        from models import custom_net
        if cfg.patch_size == 64:
            model = custom_net.net_64(num_classes=n_classes)
        elif cfg.patch_size == 32:
            model = custom_net.net_32(num_classes=n_classes)
        else:
            # No model is built here, so the generic check below exits.
            print('Do not support present patch size %s'%cfg.patch_size)
    elif cfg.model == 'googlenet':
        from models import inception_v3
        constructor = inception_v3.inception_v3
    elif cfg.model == 'vgg':
        from models import vgg
        depth = cfg.model_info
        if depth == 19:
            constructor = vgg.vgg19_bn
        elif depth == 16:
            constructor = vgg.vgg16_bn
    elif cfg.model == 'resnet':
        from models import resnet
        depth = cfg.model_info
        if depth == 18:
            constructor = resnet.resnet18
        elif depth == 34:
            constructor = resnet.resnet34
        elif depth == 50:
            constructor = resnet.resnet50
        elif depth == 101:
            constructor = resnet.resnet101

    if constructor is not None:
        model = constructor(pretrained=use_pretrained, num_classes=n_classes)

    if model is None:
        print('not support :' + cfg.model)
        sys.exit(-1)

    if load_param_from_ours:
        print('loading pretrained model from {0}'.format(cfg.init_model_file))
        saved = torch.load(cfg.init_model_file)
        model.load_state_dict(saved['model_param'])

    model.cuda()
    print('shift model to parallel!')
    return torch.nn.DataParallel(model, device_ids=cfg.gpu_id)
Exemplo n.º 5
0
def get_network(args,cfg):
    """Return the network named by ``args.net`` after calling ``.cuda()`` on it.

    Args:
        args: Namespace providing ``net`` (the network name), ``pretrain``
            (forwarded as ``pretrained`` to the AlexNet/VGG/ResNet factories)
            and ``gpuid`` (used only for ``resnet18`` and ``resnet50``).
        cfg: Configuration providing ``PARA.train.num_classes`` for the
            AlexNet/VGG/ResNet factories.

    Returns:
        The constructed network, already passed through ``.cuda()``.

    Raises:
        SystemExit: With status 1 when ``args.net`` is not supported.
    """
    def configured(factory, pin_to_gpuid=False):
        # Shared construction for every factory taking pretrained/num_classes.
        net = factory(pretrained=args.pretrain,
                      num_classes=cfg.PARA.train.num_classes)
        return net.cuda(args.gpuid) if pin_to_gpuid else net.cuda()

    name = args.net

    # Networks built with their default arguments.
    if name == 'lenet5':
        return LeNet5().cuda()
    if name == 'inceptionv3':
        return inception_v3().cuda()
    if name == 'squeezenet':
        return squeezenet1_0().cuda()

    # Networks configured from args/cfg.
    if name == 'alexnet':
        return configured(alexnet)
    if name == 'vgg11':
        return configured(vgg11)
    if name == 'vgg13':
        return configured(vgg13)
    if name == 'vgg16':
        return configured(vgg16)
    if name == 'vgg19':
        return configured(vgg19)
    if name == 'vgg11_bn':
        return configured(vgg11_bn)
    if name == 'vgg13_bn':
        return configured(vgg13_bn)
    if name == 'vgg16_bn':
        return configured(vgg16_bn)
    if name == 'vgg19_bn':
        return configured(vgg19_bn)
    # NOTE(review): only resnet18 and resnet50 pass args.gpuid to .cuda();
    # confirm whether the other networks should do the same.
    if name == 'resnet18':
        return configured(resnet18, pin_to_gpuid=True)
    if name == 'resnet34':
        return configured(resnet34)
    if name == 'resnet50':
        return configured(resnet50, pin_to_gpuid=True)
    if name == 'resnet101':
        return configured(resnet101)
    if name == 'resnet152':
        return configured(resnet152)

    print('the network name you have entered is not supported yet')
    # Non-zero status so that callers and shell scripts see the failure.
    sys.exit(1)
Exemplo n.º 6
0
def get_model(model, dataset, classify=True):
    """Instantiate the network registered under the name ``model``.

    VGG and CyVGG variants are built with ``dataset`` and ``classify``;
    ResNet and CyResNet variants are built with ``dataset`` only.  A value
    that matches none of the known names is returned unchanged.
    """
    # Every entry is a thunk so that a module is only touched when its name
    # is actually requested.
    builders = {
        # VGG Models
        'vgg11': lambda: vgg.vgg11_bn(dataset=dataset, classify=classify),
        'vgg13': lambda: vgg.vgg13_bn(dataset=dataset, classify=classify),
        'vgg16': lambda: vgg.vgg16_bn(dataset=dataset, classify=classify),
        'vgg19': lambda: vgg.vgg19_bn(dataset=dataset, classify=classify),
        # CyVGG Models
        'cyvgg11': lambda: cyvgg.cyvgg11_bn(dataset=dataset, classify=classify),
        'cyvgg13': lambda: cyvgg.cyvgg13_bn(dataset=dataset, classify=classify),
        'cyvgg16': lambda: cyvgg.cyvgg16_bn(dataset=dataset, classify=classify),
        'cyvgg19': lambda: cyvgg.cyvgg19_bn(dataset=dataset, classify=classify),
        # Resnet Models
        'resnet20': lambda: resnet.resnet20(dataset=dataset),
        'resnet32': lambda: resnet.resnet32(dataset=dataset),
        'resnet44': lambda: resnet.resnet44(dataset=dataset),
        'resnet56': lambda: resnet.resnet56(dataset=dataset),
        # CyResnet Models
        'cyresnet20': lambda: cyresnet.cyresnet20(dataset=dataset),
        'cyresnet32': lambda: cyresnet.cyresnet32(dataset=dataset),
        'cyresnet44': lambda: cyresnet.cyresnet44(dataset=dataset),
        'cyresnet56': lambda: cyresnet.cyresnet56(dataset=dataset),
    }

    try:
        build = builders[model]
    except (KeyError, TypeError):
        # Unknown name (or an unhashable value): hand the input back as-is.
        return model
    return build()
Exemplo n.º 7
0
def get_model(n_class):
    """Build the network chosen by ``args.model`` and load ``args.checkpoint``.

    ``n_class`` is accepted but not used by this function.  Only
    ``'vgg16_bn'`` and ``'vgg16_bn_x'`` are recognised; any other value
    raises ``NotImplementedError``.  If the loaded checkpoint contains a
    ``'state_dict'`` entry, that entry is used as the weights; every
    ``'module.'`` substring is removed from the keys before loading.
    """
    print('=> Building model {}...'.format(args.model))
    if args.model == 'vgg16_bn':
        net = vgg16_bn()
    elif args.model == 'vgg16_bn_x':
        net = vgg16_bn_x()
    else:
        raise NotImplementedError

    print('=> Loading checkpoints..')
    loaded = torch.load(args.checkpoint)
    weights = loaded['state_dict'] if 'state_dict' in loaded else loaded

    # remove .module
    renamed = {}
    for key, value in weights.items():
        renamed[key.replace('module.', '')] = value
    net.load_state_dict(renamed)

    return net
Exemplo n.º 8
0
def get_model(cfg, pretrained=True, load_param_from_folder=False):
    """Create the network described by ``cfg`` and wrap it in ``DataParallel``.

    Args:
        cfg: Configuration object.  ``num_classes`` and ``model`` are always
            read; ``model_info`` is read for ``'vgg'``/``'resnet'``,
            ``init_model_file`` when loading from a folder, and ``gpu_id``
            for the ``DataParallel`` wrapper.
        pretrained: Forwarded to the model constructor.  Forced to ``False``
            when ``load_param_from_folder`` is set.
        load_param_from_folder: When truthy, load
            ``checkpoint['model_param']`` from ``cfg.init_model_file``.

    Returns:
        ``torch.nn.DataParallel`` wrapping the model.

    Raises:
        SystemExit: With status -1 when no model could be built for ``cfg``.
    """
    # A checkpoint from disk supersedes pretrained weights.
    if load_param_from_folder:
        pretrained = False

    n_classes = cfg.num_classes
    constructor = None

    if cfg.model == 'googlenet':
        from models import inception_v3
        constructor = inception_v3.inception_v3
    elif cfg.model == 'vgg':
        from models import vgg
        depth = cfg.model_info
        if depth == 19:
            constructor = vgg.vgg19_bn
        elif depth == 16:
            constructor = vgg.vgg16_bn
    elif cfg.model == 'resnet':
        from models import resnet
        depth = cfg.model_info
        if depth == 18:
            constructor = resnet.resnet18
        elif depth == 34:
            constructor = resnet.resnet34
        elif depth == 50:
            constructor = resnet.resnet50
        elif depth == 101:
            constructor = resnet.resnet101

    model = None
    if constructor is not None:
        model = constructor(pretrained=pretrained, num_classes=n_classes)

    if model is None:
        print('not support :' + cfg.model)
        sys.exit(-1)

    if load_param_from_folder:
        print('loading pretrained model from {0}'.format(cfg.init_model_file))
        saved = torch.load(cfg.init_model_file)
        model.load_state_dict(saved['model_param'])

    print('shift model to parallel!')
    return torch.nn.DataParallel(model, device_ids=cfg.gpu_id)
Exemplo n.º 9
0
    def __init__(self,
                 fc_hidden1=512,
                 fc_hidden2=512,
                 drop_p=0.3,
                 CNN_embed_dim=300):
        """Set up a VGG16-BN backbone followed by three linear layers.

        Args:
            fc_hidden1: Output width of ``fc1`` (and size of ``bn1``).
            fc_hidden2: Output width of ``fc2`` (and size of ``bn2``).
            drop_p: Stored on the instance as ``self.drop_p``.
            CNN_embed_dim: Output width of ``fc3``.
        """
        super(ResCNNEncoder, self).__init__()

        self.fc_hidden1 = fc_hidden1
        self.fc_hidden2 = fc_hidden2
        self.drop_p = drop_p

        # NOTE(review): the attention flag selects the plain ``vgg`` module
        # (without pretrained=True) while the non-attention path uses
        # ``vgg_att`` — confirm the two branches are not swapped.
        if not opt.attention:
            backbone = vgg_att.vgg16_bn(pretrained=True, progress=True)
        else:
            backbone = vgg.vgg16_bn(progress=True)

        # Keep every child module of the backbone except the last one.
        kept_children = list(backbone.children())[:-1]
        self.net = nn.Sequential(*kept_children)

        # fc1 takes the input width of the backbone's first classifier layer.
        flat_features = backbone.classifier[0].in_features
        self.fc1 = nn.Linear(flat_features, fc_hidden1)
        self.bn1 = nn.BatchNorm1d(fc_hidden1, momentum=0.01)
        self.fc2 = nn.Linear(fc_hidden1, fc_hidden2)
        self.bn2 = nn.BatchNorm1d(fc_hidden2, momentum=0.01)
        self.fc3 = nn.Linear(fc_hidden2, CNN_embed_dim)
Exemplo n.º 10
0
def get_network(args, use_gpu=False):
    """Return the batch-normalised VGG network selected by ``args.net``.

    Args:
        args: Namespace whose ``net`` attribute is one of ``'vgg11'``,
            ``'vgg13'``, ``'vgg16'`` or ``'vgg19'``.
        use_gpu: When truthy, ``.cuda()`` is called on the network.

    Returns:
        The constructed network.

    Raises:
        SystemExit: With status 1 when ``args.net`` is not supported.
    """
    requested = args.net

    # Import only the factory that is actually needed.
    if requested == 'vgg16':
        from models.vgg import vgg16_bn as build
    elif requested == 'vgg13':
        from models.vgg import vgg13_bn as build
    elif requested == 'vgg11':
        from models.vgg import vgg11_bn as build
    elif requested == 'vgg19':
        from models.vgg import vgg19_bn as build
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so that callers and shell scripts see the failure.
        sys.exit(1)

    net = build()
    return net.cuda() if use_gpu else net
Exemplo n.º 11
0
def main():
    """Restore a saved generator and run a single validation pass.

    Builds the validation dataset and loader from ``config``, loads the
    checkpoint ``model_logs/<config.MODEL_RESTORE>``, constructs the
    generator selected by ``config.NETWORK_TYPE`` together with a
    discriminator and a pretrained VGG16-BN, assembles the losses and
    optimizers, and hands everything to ``validate``.
    """
    logger_init()
    dataset_type = config.DATASET
    # Read from the config but not used below: the loader uses batch_size=1.
    batch_size = config.BATCH_SIZE

    # Dataset setting
    logger.info("Initialize the dataset...")
    image_flist = config.DATA_FLIST[dataset_type][1]
    mask_flists = {'val': config.DATA_FLIST[config.MASKDATASET]['val'][1]}
    val_dataset = InpaintDataset(
        image_flist,
        mask_flists,
        resize_shape=tuple(config.IMG_SHAPES),
        random_bbox_shape=config.RANDOM_BBOX_SHAPE,
        random_bbox_margin=config.RANDOM_BBOX_MARGIN,
        random_ff_setting=config.RANDOM_FF_SETTING,
    )
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    checkpoint = torch.load('model_logs/{}'.format(config.MODEL_RESTORE))
    netG_state_dict = checkpoint['netG_state_dict']
    # Extracted from the checkpoint but never loaded into netD below.
    netD_state_dict = checkpoint['netD_state_dict']

    # NOTE(review): any other NETWORK_TYPE leaves netG unbound, which only
    # surfaces when the generator optimizer is created further down.
    if config.NETWORK_TYPE == "l2h_unet":
        netG = InpaintRUNNet(n_in_channel=config.N_CHANNEL)
        netG.load_state_dict(netG_state_dict)
    elif config.NETWORK_TYPE == 'sa_gated':
        netG = InpaintSANet()
        load_consistent_state_dict(netG_state_dict, netG)

    netD = InpaintSADirciminator()
    netVGG = vgg16_bn(pretrained=True)

    logger.info("Loading pretrained models from {} ...".format(
        config.MODEL_RESTORE))

    # Define loss
    recon_loss = ReconLoss(*(config.L1_LOSS_ALPHA))
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    perc_loss = PerceptualLoss(weight=config.PERC_LOSS_ALPHA,
                               feat_extractors=netVGG.to(cuda1))
    style_loss = StyleLoss(weight=config.STYLE_LOSS_ALPHA,
                           feat_extractors=netVGG.to(cuda1))
    dis_loss = SNDisLoss()

    # The discriminator optimizer uses four times the generator's rate.
    lr = config.LEARNING_RATE
    decay = config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    nets = {"netG": netG, "netD": netD, "vgg": netVGG}
    losses = {
        "GANLoss": gan_loss,
        "ReconLoss": recon_loss,
        "StyleLoss": style_loss,
        "DLoss": dis_loss,
        "PercLoss": perc_loss,
    }
    opts = {"optG": optG, "optD": optD}
    logger.info("Finish Define the Network Structure and Losses")

    # Start validation
    logger.info("Start Validation")
    validate(nets, losses, opts, val_loader, 0, config.NETWORK_TYPE,
             devices=(cuda0, cuda1))
# Shuffled training loader over ``train_datasets`` (defined outside this
# excerpt), batched by ``args.batch_size``.
train_loader = torch.utils.data.DataLoader(dataset=train_datasets,
                                           batch_size=args.batch_size,
                                           shuffle=True)

# MNIST test split stored under ``args.input_path``; downloaded if missing.
test_datasets = torchvision.datasets.MNIST(root=args.input_path,
                                           transform=transform_test,
                                           download=True,
                                           train=False)

# NOTE(review): the test loader is shuffled as well — confirm this is wanted.
test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                          batch_size=args.batch_size,
                                          shuffle=True)

# Define Network
model = vgg.vgg16_bn(pretrained=False)

# Load pre-trained weights
model_dict = model.state_dict()
pre_dict = torch.load(args.checkpoint_path)

# Parallel key/value lists for the checkpoint and for the model.  Only the
# checkpoint lists are filled in this excerpt.
pretrained_cla_weight_key = []
pretrained_cla_weight_value = []
model_weight_key = []
model_weight_value = []

# Split the loaded checkpoint into its keys and values, preserving order.
for k, v in pre_dict.items():
    # print(k)
    pretrained_cla_weight_key.append(k)
    pretrained_cla_weight_value.append(v)
Exemplo n.º 13
0
def get_network(args):
    """Return the network named by ``args.net``.

    The factory is imported lazily from the matching ``models.*`` module, so
    only the requested architecture's module is loaded.

    Args:
        args: Namespace providing ``net`` (the network name) and ``gpu``
            (when truthy, ``.cuda()`` is called on the network).  The
            ``hyper_resnet*`` networks additionally read ``dict_size``.

    Returns:
        The constructed network.

    Raises:
        SystemExit: With status 1 when ``args.net`` is not supported.
    """
    import importlib

    # Modules whose factory function is named exactly like the option value.
    same_name = {
        'models.densenet': ('densenet121', 'densenet161', 'densenet169',
                            'densenet201'),
        'models.googlenet': ('googlenet',),
        'models.inceptionv3': ('inceptionv3',),
        'models.inceptionv4': ('inceptionv4',),
        'models.xception': ('xception',),
        'models.resnet': ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                          'resnet152'),
        'models.preactresnet': ('preactresnet18', 'preactresnet34',
                                'preactresnet50', 'preactresnet101',
                                'preactresnet152'),
        'models.resnext': ('resnext50', 'resnext101', 'resnext152'),
        'models.shufflenet': ('shufflenet',),
        'models.shufflenetv2': ('shufflenetv2',),
        'models.squeezenet': ('squeezenet',),
        'models.mobilenet': ('mobilenet',),
        'models.mobilenetv2': ('mobilenetv2',),
        'models.nasnet': ('nasnet',),
        'models.attention': ('attention56', 'attention92'),
        'models.senet': ('seresnet18', 'seresnet34', 'seresnet50',
                         'seresnet101', 'seresnet152'),
        'models.wideresidual': ('wideresnet',),
    }
    factories = {}
    for module_name, names in same_name.items():
        for name in names:
            factories[name] = (module_name, name)

    # Options whose factory has a different name than the option itself.
    factories.update({
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
        'normal_resnet': ('models.normal_resnet', 'resnet18'),
        'normal_resnet_wo_bn': ('models.normal_resnet_wo_bn', 'resnet18'),
    })

    # The hypernetwork variants share one class and differ in the encoder.
    hyper_encoders = {
        'hyper_resnet': 'resnet18',
        'hyper_resnet_wo_bn': 'resnet18_wobn',
    }

    name = args.net
    if name in hyper_encoders:
        from models.hypernet_main import Hypernet_Main
        net = Hypernet_Main(
            encoder=hyper_encoders[name],
            hypernet_params={'vqvae_dict_size': args.dict_size})
    elif name in factories:
        module_name, factory_name = factories[name]
        factory = getattr(importlib.import_module(module_name), factory_name)
        net = factory()
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so that callers and shell scripts see the failure.
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
Exemplo n.º 14
0
def get_model(args):
    """Build the model described by ``args`` and return it with its parameters.

    Args:
        args: Namespace providing ``model`` (one of ``'derpnet'``,
            ``'alexnet'``, ``'resnet'``, ``'vgg'``, ``'vgg_attn'``,
            ``'inception'``), ``pretrained``, ``pretrain_path`` and
            ``n_classes``.  ``n_channels`` is read for alexnet/derpnet,
            ``model_depth`` for vgg/vgg_attn/resnet, and ``arch`` when a
            pretrain file is loaded.

    Returns:
        A ``(model, parameters)`` tuple where ``parameters`` is
        ``model.parameters()``.

    Raises:
        AssertionError: If ``args.model`` or ``args.model_depth`` is not one
            of the accepted values, or if the pretrain file's ``'arch'``
            differs from ``args.arch``.
    """
    assert args.model in [
        'derpnet', 'alexnet', 'resnet', 'vgg', 'vgg_attn', 'inception'
    ]

    if args.model == 'alexnet':
        model = alexnet.alexnet(pretrained=args.pretrained,
                                n_channels=args.n_channels,
                                num_classes=args.n_classes)
    elif args.model == 'inception':
        model = inception.inception_v3(pretrained=args.pretrained,
                                       aux_logits=False,
                                       progress=True,
                                       num_classes=args.n_classes)
    elif args.model == 'vgg':
        assert args.model_depth in [11, 13, 16, 19]

        # NOTE(review): depth 19 builds the plain ``vgg19`` while the other
        # depths build the ``_bn`` variant — confirm this is intended.
        if args.model_depth == 19:
            factory = vgg.vgg19
        else:
            factory = getattr(vgg, 'vgg%d_bn' % args.model_depth)
        model = factory(pretrained=args.pretrained,
                        progress=True,
                        num_classes=args.n_classes)

    elif args.model == 'vgg_attn':
        assert args.model_depth in [11, 13, 16, 19]

        # Select the factory matching the requested depth instead of always
        # building vgg11_bn.  Assumes vgg_attn exposes the same vggNN_bn
        # factories as vgg — TODO confirm.
        factory = getattr(vgg_attn, 'vgg%d_bn' % args.model_depth)
        model = factory(pretrained=args.pretrained,
                        progress=True,
                        num_classes=args.n_classes)

    elif args.model == 'derpnet':
        model = derp_net.Net(n_channels=args.n_channels,
                             num_classes=args.n_classes)

    elif args.model == 'resnet':
        assert args.model_depth in [10, 18, 34, 50, 101, 152, 200]

        factory = getattr(resnet, 'resnet%d' % args.model_depth)
        model = factory(pretrained=args.pretrained,
                        num_classes=args.n_classes)

    # alexnet, vgg and resnet never take weights from args.pretrain_path.
    loads_pretrain_file = (
        args.pretrained
        and args.pretrain_path
        and args.model not in ('alexnet', 'vgg', 'resnet')
    )
    if loads_pretrain_file:
        print('loading pretrained model {}'.format(args.pretrain_path))
        pretrain = torch.load(args.pretrain_path)
        assert args.arch == pretrain['arch']

        # Drop the first 7 characters of every key (presumably the
        # 'module.' prefix — verify against the checkpoint) and skip the
        # entries whose remaining name starts with 'fc'.
        from collections import OrderedDict
        pretrain_dict = OrderedDict(
            (key[7:], value)
            for key, value in pretrain['state_dict'].items()
            if key[7:9] != 'fc'
        )

        # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance
        import types
        model.load_state_dict = types.MethodType(load_my_state_dict, model)
        model.load_state_dict(pretrain_dict)

        num_features = model.fc.in_features
        if args.model == 'densenet':
            # NOTE(review): 'densenet' is rejected by the assert at the top,
            # so this branch is only reachable when asserts are disabled.
            model.classifier = nn.Linear(num_features, args.n_classes)
        else:
            model.fc = nn.Linear(num_features, args.n_classes)

        # Every parameter is fine-tuned, not only the new fc layer.
        return model, model.parameters()

    return model, model.parameters()
Exemplo n.º 15
0
def get_network(args, use_gpu=True):
    """Return the network named by ``args.net``.

    The factory is imported lazily from the matching ``models.*`` module, so
    only the requested architecture's module is loaded.

    Args:
        args: Namespace whose ``net`` attribute names the network.
        use_gpu: When truthy, ``.cuda()`` is called on the network.

    Returns:
        The constructed network.

    Raises:
        SystemExit: With status 1 when ``args.net`` is not supported.
    """
    import importlib

    # Option value -> (module, factory name).  Families whose factory is
    # named like the option are listed per module and expanded below.
    families = (
        ('models.densenet', ('densenet121', 'densenet161', 'densenet169',
                             'densenet201')),
        ('models.googlenet', ('googlenet',)),
        ('models.inceptionv3', ('inceptionv3',)),
        ('models.inceptionv4', ('inceptionv4',)),
        ('models.xception', ('xception',)),
        ('models.resnet', ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                           'resnet152')),
        ('models.preactresnet', ('preactresnet18', 'preactresnet34',
                                 'preactresnet50', 'preactresnet101',
                                 'preactresnet152')),
        ('models.resnext', ('resnext50', 'resnext101', 'resnext152')),
        ('models.shufflenet', ('shufflenet',)),
        ('models.shufflenetv2', ('shufflenetv2',)),
        ('models.squeezenet', ('squeezenet',)),
        ('models.mobilenet', ('mobilenet',)),
        ('models.mobilenetv2', ('mobilenetv2',)),
        ('models.nasnet', ('nasnet',)),
        ('models.attention', ('attention56', 'attention92')),
        ('models.senet', ('seresnet18', 'seresnet34', 'seresnet50',
                          'seresnet101', 'seresnet152')),
    )
    registry = {
        # The vgg options map to the batch-normalised factories.
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
    }
    for module_name, option_names in families:
        for option in option_names:
            registry[option] = (module_name, option)

    entry = registry.get(args.net)
    if entry is None:
        print('the network name you have entered is not supported yet')
        # Non-zero status so that callers and shell scripts see the failure.
        sys.exit(1)

    module_name, factory_name = entry
    net = getattr(importlib.import_module(module_name), factory_name)()

    if use_gpu:
        net = net.cuda()

    return net
Exemplo n.º 16
0
def get_model(args, model_path=None):
    """
    Build a classification network and resize its output layer.

    The architecture is chosen by ``args.model``. ResNet variants get a new
    ``fc`` layer; VGG variants and AlexNet get a new ``classifier[6]`` layer.
    In both cases the new layer has ``args.num_classes`` outputs.

    :param args: super arguments. Reads ``args.scratch``, ``args.model`` and
        ``args.num_classes``; when ``args.scratch`` is false it also reads
        ``args.root_path`` and ``args.pretrained_models_path``.
    :param model_path: if not None, load already trained model parameters
        from the ``'model'`` entry of this checkpoint file.
    :return: model
    :raises ValueError: if ``args.model`` is not a supported architecture.
    """
    if args.scratch:  # train model from scratch
        pretrained = False
        model_dir = None
        print("=> Loading model '{}' from scratch...".format(args.model))
    else:  # train model with pretrained model
        pretrained = True
        model_dir = os.path.join(args.root_path, args.pretrained_models_path)
        print("=> Loading pretrained model '{}'...".format(args.model))

    name = args.model
    # An explicit chain keeps constructor names lazily resolved: only the
    # selected architecture has to be importable in the enclosing module.
    if name == 'resnet18':
        model = resnet18(pretrained=pretrained, model_dir=model_dir)
    elif name == 'resnet34':
        model = resnet34(pretrained=pretrained, model_dir=model_dir)
    elif name == 'resnet50':
        model = resnet50(pretrained=pretrained, model_dir=model_dir)
    elif name == 'resnet101':
        model = resnet101(pretrained=pretrained, model_dir=model_dir)
    elif name == 'resnet152':
        model = resnet152(pretrained=pretrained, model_dir=model_dir)
    elif name == 'vgg11':
        model = vgg11(pretrained=pretrained, model_dir=model_dir)
    elif name == 'vgg11_bn':
        model = vgg11_bn(pretrained=pretrained, model_dir=model_dir)
    elif name == 'vgg13':
        model = vgg13(pretrained=pretrained, model_dir=model_dir)
    elif name == 'vgg13_bn':
        model = vgg13_bn(pretrained=pretrained, model_dir=model_dir)
    elif name == 'vgg16':
        model = vgg16(pretrained=pretrained, model_dir=model_dir)
    elif name == 'vgg16_bn':
        model = vgg16_bn(pretrained=pretrained, model_dir=model_dir)
    elif name == 'vgg19':
        model = vgg19(pretrained=pretrained, model_dir=model_dir)
    elif name == 'vgg19_bn':
        model = vgg19_bn(pretrained=pretrained, model_dir=model_dir)
    elif name == 'alexnet':
        model = alexnet(pretrained=pretrained, model_dir=model_dir)
    else:
        # Previously an unknown name fell through and surfaced later as an
        # UnboundLocalError on ``model``; fail early with a clear message.
        raise ValueError("Unsupported model: '{}'".format(name))

    # Replace the final layer so the output size matches the target task.
    if name.startswith('resnet'):
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)
    else:  # VGG variants and AlexNet share the classifier[6] head layout
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, args.num_classes)

    # Load already trained model parameters and go on training
    if model_path is not None:
        # map_location='cpu' lets checkpoints saved on a GPU be restored on
        # hosts without that device; load_state_dict copies the values into
        # the model's own parameters, wherever they live.
        checkpoint = torch.load(model_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    return model
Exemplo n.º 17
0
def get_network(netname, use_gpu=True):
    """Return a freshly constructed network selected by ``netname``.

    The module that defines the requested network is imported only when that
    network is selected, so unrelated model files need not be importable.

    Args:
        netname: key identifying the network (see ``registry`` below).
        use_gpu: when true, the network is moved to the GPU with ``.cuda()``.

    Returns:
        The constructed network.

    Raises:
        SystemExit: with status 1 if ``netname`` is not supported.
    """
    # Function-scope import: the module's import header is not visible here.
    import importlib

    # netname -> (module path, factory name, factory keyword arguments).
    # Note that the plain 'vgg16' / 'vgg11' keys build the *_bn variants.
    registry = {
        'vgg16': ('models.vgg', 'vgg16_bn', {}),
        'vgg16_cbn': ('models.vgg_nobn', 'vgg16_cbn', {}),
        'vgg11': ('models.vgg', 'vgg11_bn', {}),
        'vgg11_cbn': ('models.vgg_nobn', 'vgg11_cbn', {}),
        'vgg11_nobn': ('models.vgg_nobn', 'vgg11_nobn', {}),
        'vgg16_nobn': ('models.vgg_nobn', 'vgg16_nobn', {}),
        'resnet18': ('models.resnet', 'resnet18', {}),
        'resnet50': ('models.resnet', 'resnet50', {}),
        'resnet101': ('models.resnet', 'resnet101', {}),
        'resnet18_nobn': ('models.resnet_nobn', 'resnet18_nobn', {}),
        'resnet18_fixup': ('models.resnet_fixup', 'resnet18', {}),
        'resnet50_fixup': ('models.resnet_fixup', 'resnet50', {}),
        'resnet18_cbn': ('models.resnet_nobn', 'resnet18_cbn', {}),
        'resnet50_cbn': ('models.resnet_nobn', 'resnet50_cbn', {}),
        'resnet50_nobn': ('models.resnet_nobn', 'resnet50_nobn', {}),
        'resnet101_cbn': ('models.resnet_nobn', 'resnet101_cbn', {}),
        'densenet121': ('models.densenet', 'densenet121', {}),
        'densenet121_cbn': ('models.densenet_nobn', 'densenet121', {}),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2', {}),
        'shufflenetv2_cbn': ('models.shufflenetv2_nobn', 'shufflenetv2_cbn', {}),
        'shufflenetv2_nobn': ('models.shufflenetv2_nobn', 'shufflenetv2_nobn', {}),
        'squeezenet': ('models.squeezenet', 'squeezenet', {}),
        'squeezenet_nobn': ('models.squeezenet_nobn', 'squeezenet_nobn', {}),
        'squeezenet_cbn': ('models.squeezenet_nobn', 'squeezenet_cbn', {}),
        'seresnet18': ('models.senet', 'seresnet18', {}),
        'seresnet50': ('models.senet', 'seresnet50', {}),
        'seresnet18_cbn': ('models.senet_nobn', 'seresnet18', {}),
        'seresnet50_cbn': ('models.senet_nobn', 'seresnet50', {}),
        'fixup_cbn': ('models.fixup_resnet_cifar', 'fixup_resnet56', {'cbn': True}),
        'fixup': ('models.fixup_resnet_cifar', 'fixup_resnet56', {}),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2', {}),
        'mobilenetv2_cbn': ('models.mobilenetv2_nobn', 'mobilenetv2', {}),
    }

    if netname not in registry:
        print(netname)
        print('the network name you have entered is not supported yet')
        # A bare sys.exit() reports success (status 0); signal the failure.
        sys.exit(1)

    module_path, factory_name, factory_kwargs = registry[netname]
    factory = getattr(importlib.import_module(module_path), factory_name)
    net = factory(**factory_kwargs)

    if use_gpu:
        #  net = torch.nn.parallel.DataParallel(net)
        net = net.cuda()

    return net
Exemplo n.º 18
0
                                            transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=2)
    num_classes = 100
else:
    raise "only support dataset CIFAR10 or CIFAR100"

# Instantiate the network selected by --model. Every constructor receives
# num_classes; all constructors except LeNet also receive the --pretrain flag.
if args.model == "lenet":
    net = lenet.LeNet(num_classes=num_classes)

elif args.model == "vgg16":
    net = vgg.vgg16(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "vgg16_bn":
    net = vgg.vgg16_bn(num_classes=num_classes, pretrained=args.pretrain)

elif args.model == "resnet18":
    net = resnet.resnet18(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnet34":
    # Previously built resnet18 here (copy-paste slip), so asking for a
    # 34-layer network silently produced an 18-layer one.
    # NOTE(review): assumes the resnet module exposes resnet34 — confirm.
    net = resnet.resnet34(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnet50":
    net = resnet.resnet50(num_classes=num_classes, pretrained=args.pretrain)

elif args.model == "resnetv2_18":
    net = resnet_v2.resnet18(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnetv2_34":
    # Same copy-paste slip as above: this used to build resnet_v2.resnet18.
    # NOTE(review): assumes the resnet_v2 module exposes resnet34 — confirm.
    net = resnet_v2.resnet34(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnetv2_50":
    net = resnet_v2.resnet50(num_classes=num_classes, pretrained=args.pretrain)
Exemplo n.º 19
0
def get_network(key, num_cls=2, use_gpu=False):
    """Return the network selected by ``key``.

    VGG keys build the batch-norm variants with ``num_cls`` as their single
    argument. EfficientNet keys call ``make_model(key, num_cls)``. ResNeXt
    and ResNet keys call their module's ``make_model(key)`` and do not
    receive ``num_cls``.

    Args:
        key: name of the network to build.
        num_cls: number of classes, forwarded to VGG and EfficientNet builders.
        use_gpu: when true, the network is moved to the GPU with ``.cuda()``.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: for the placeholder key ``'resnext'``.
        SystemExit: with status 1 if ``key`` is not supported.
    """
    efficientnet_keys = tuple('efficientNetb{}'.format(i) for i in range(8))
    resnext_keys = ('resnext50_32x8d', 'resnext101_32x8d')
    resnet_keys = ('resnet50', 'resnet18', 'resnet34', 'resnet101')

    if key == 'vgg16':
        from models.vgg import vgg16_bn
        net = vgg16_bn(num_cls)
    elif key == 'vgg13':
        from models.vgg import vgg13_bn
        net = vgg13_bn(num_cls)
    elif key == 'vgg11':
        from models.vgg import vgg11_bn
        net = vgg11_bn(num_cls)
    elif key == 'vgg19':
        from models.vgg import vgg19_bn
        net = vgg19_bn(num_cls)
    elif key == 'resnext':
        # This branch used to print 'we will continue' without building
        # anything, so the function then failed with UnboundLocalError on
        # ``net``. Report the missing implementation explicitly instead.
        raise NotImplementedError(
            "'resnext' has no implementation yet; use one of {}".format(resnext_keys))
    elif key in efficientnet_keys:
        from models.torchefficient import make_model
        net = make_model(key, num_cls)
    elif key in resnext_keys:
        from models.resnext import make_model
        net = make_model(key)
    elif key in resnet_keys:
        from models.resnet import make_model
        net = make_model(key)
    else:
        print('the network name you have entered is not supported yet')
        # A bare sys.exit() reports success (status 0); signal the failure.
        sys.exit(1)

    if use_gpu:
        net = net.cuda()
    return net
Exemplo n.º 20
0
def two_class_test(imgPath, xmlPath, weight_path, target_name1, target_name2, netname):
    """
    Re-classify first-stage detections between two cell types.

    Every ``object`` in the first-stage XML files whose ``name`` equals
    ``target_name1`` or ``target_name2`` is cropped from the matching image,
    classified by the network and relabelled in place. An XML file that
    contained at least one such object is written back with ``writeXml``.

    Args:
        imgPath: directory of test images, organised in sub-directories
        xmlPath: directory of first-stage result XML files, with the same
            sub-directory layout as imgPath
        weight_path: path of the weights file loaded into the network
        target_name1: cell name assigned when the network predicts class 0
        target_name2: cell name assigned for any other prediction
        netname: network type; only 'vgg16' is supported
    Returns:
        None. The number of re-classified objects is printed.
    Raises:
        ValueError: if netname is not supported.
    """
    if netname == 'vgg16':
        from models.vgg import vgg16_bn
        net = vgg16_bn()
    else:
        # Previously any other name left ``net`` unbound and crashed later.
        raise ValueError(
            "unsupported netname: {!r} (only 'vgg16' is available)".format(netname))

    transform_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(cell_train_mean, cell_train_std)
    ])
    count = 0
    change = 0
    # The network stays on the CPU here, so map the weights there as well;
    # this also lets GPU-saved weights load on hosts without a GPU.
    net.load_state_dict(torch.load(weight_path, map_location='cpu'))
    net.eval()
    for subdir in os.listdir(xmlPath):
        xml_dir = os.path.join(xmlPath, subdir)
        for xmlfile in os.listdir(xml_dir):
            # Reset per file so only files that actually contain a target
            # object are rewritten (the flag used to be reset per directory).
            isUpdated = False
            srcfile = os.path.join(xml_dir, xmlfile)              # XML file path
            image_pre, _ = os.path.splitext(xmlfile)
            imgfile = os.path.join(imgPath, subdir, image_pre + ".jpg")   # source image path
            img = cv2.imread(imgfile)     # source image
            DOMTree = xml.dom.minidom.parse(srcfile)  # open the XML document
            collection = DOMTree.documentElement  # document root element
            objectlist = collection.getElementsByTagName("object")  # all <object> entries
            for objects in objectlist:
                namelist = objects.getElementsByTagName('name')
                name = namelist[0].childNodes[0].data
                if name != target_name1 and name != target_name2:  # only the two target classes are re-classified
                    continue
                pos0 = cellName.index(target_name1)
                pos1 = cellName.index(target_name2)
                isUpdated = True  # this XML file needs to be rewritten
                bndbox = objects.getElementsByTagName('bndbox')
                if not bndbox:
                    # Without a box there is nothing to crop. Skipping avoids
                    # silently reusing the coordinates of a previous object.
                    print("error: object without bndbox in", srcfile)
                    continue
                for box in bndbox:
                    x1_list = box.getElementsByTagName('xmin')
                    x1 = int(x1_list[0].childNodes[0].data)
                    y1_list = box.getElementsByTagName('ymin')
                    y1 = int(y1_list[0].childNodes[0].data)
                    x2_list = box.getElementsByTagName('xmax')  # mind the coordinate convention; conversion may be needed
                    x2 = int(x2_list[0].childNodes[0].data)
                    y2_list = box.getElementsByTagName('ymax')
                    y2 = int(y2_list[0].childNodes[0].data)
                try:
                    subimg = img[y1:y2, x1:x2, :]  # crop of the cell
                    subimg = transform_test(subimg)
                    subimg = subimg.unsqueeze(0)
                    # Inference only: no autograd graph is needed.
                    with torch.no_grad():
                        output = net(subimg)        # predict the class of the cell crop

                    _, preds = output.max(1)
                    preds = preds.item()
                    if preds == 0:
                        preds = pos0
                    else:
                        preds = pos1
                    newname = cellName[preds]      # cell name for the predicted class
                    namelist[0].childNodes[0].data = newname
                    count += 1
                    if name != newname:
                        change += 1
                        print("change=", change,',newname = ', newname)
                except Exception as exc:
                    # Best effort: report the failing object and keep going.
                    # Unlike the former bare ``except:``, this no longer
                    # swallows KeyboardInterrupt / SystemExit.
                    print("error", exc)
            if isUpdated:
                writeXml(DOMTree, srcfile)  # write the updated XML back
    print("count = ", count)
Exemplo n.º 21
0
def main():
    """Build the model chosen by ``args.arch``, optionally resume, then train.

    Reads its configuration from the module-level ``args``. When
    ``args.evaluate`` is set, only one validation pass runs. Otherwise each
    epoch trains, validates and saves a checkpoint under ``args.save_dir``.
    The module-level ``best_prec1`` is updated as training proceeds.

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture.
    """
    global args, best_prec1

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    if args.arch == 'vgg':
        model = vgg16_bn()
    elif args.arch in ('alexnet', 'wide_resnet', 'resnet'):
        # Anything other than 'cifar10' is treated as CIFAR-100.
        is_cifar10 = args.dataset == 'cifar10'
        num_classes = 10 if is_cifar10 else 100
        dataset_name = 'cifar10' if is_cifar10 else 'cifar100'
        if args.arch == 'alexnet':
            model = alexnet(num_classes=num_classes)
        elif args.arch == 'wide_resnet':
            model = wide_WResNet(num_classes=num_classes, depth=16,
                                 dataset=dataset_name)
        else:
            model = resnet(num_classes=num_classes, dataset=dataset_name)
    else:
        # Previously an unknown arch left ``model`` unbound and crashed on
        # the ``model.cuda()`` call below with an UnboundLocalError.
        raise ValueError("unsupported arch: '{}'".format(args.arch))

    model.cuda()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # Report the checkpoint path; this used to print args.evaluate.
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        cubic_train(model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir,
                                  'checkpoint_{}.tar'.format(epoch)))
                                                      transform=transform)
    testloader_normalized = torch.utils.data.DataLoader(testset_normalized,
                                                        batch_size=1,
                                                        shuffle=False,
                                                        num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data',
                                           train=False,
                                           download=True,
                                           transform=transforms.ToTensor())
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2)
elif net == 'cif100':
    model = vgg16_bn()
    model.load_state_dict(torch.load('./models/vgg_cif100.pth'))
    print('Load CIFAR-100 test set')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])

    testset_normalized = torchvision.datasets.CIFAR100(root='./data',
                                                       train=False,
                                                       download=True,
                                                       transform=transform)
    testloader_normalized = torch.utils.data.DataLoader(testset_normalized,
                                                        batch_size=1,
                                                        shuffle=False,
Exemplo n.º 23
0
def main():
    """Parse arguments, build the data pipeline and model, then train.

    Steps, in order: parse the command line into the module-level ``args``,
    configure TensorBoard if requested, select CUDA devices, build the
    dataset-specific transforms and loaders, create the model named by
    ``args.model`` (and ``args.depth`` where applicable), optionally resume
    from ``args.resume``, then train and validate once per epoch. The
    module-level ``best_err1`` tracks the lowest validation error seen.

    Raises:
        Exception: for an unknown dataset, model, or unsupported depth.
    """
    global args, best_err1
    args = parser.parse_args()

    # TensorBoard configure
    if args.tensorboard:
        configure('%s_checkpoints/%s'%(args.dataset, args.expname))

    # CUDA
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_ids)
    if torch.cuda.is_available():
        cudnn.benchmark = True  # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        kwargs = {'num_workers': 2, 'pin_memory': True}
    else:
        kwargs = {'num_workers': 2}

    # Data loading code: per-dataset channel statistics for normalization
    if args.dataset == 'cifar10':
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
    elif args.dataset == 'cifar100':
        normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                         std=[0.2634, 0.2528, 0.2719])
    elif args.dataset == 'cub':
        normalize = transforms.Normalize(mean=[0.4862, 0.4973, 0.4293],
                                         std=[0.2230, 0.2185, 0.2472])
    elif args.dataset == 'webvision':
        normalize = transforms.Normalize(mean=[0.49274242, 0.46481857, 0.41779366],
                                         std=[0.26831809, 0.26145372, 0.27042758])
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Transforms: --augment only adds the random horizontal flip
    if args.augment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.train_image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.train_image_size),
            transforms.ToTensor(),
            normalize,
        ])
    val_transform = transforms.Compose([
        transforms.Resize(args.test_image_size),
        transforms.CenterCrop(args.test_crop_image_size),
        transforms.ToTensor(),
        normalize
    ])

    # Datasets
    num_classes = 10    # default 10 classes
    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR10('./data/', train=False, download=True, transform=val_transform)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR100('./data/', train=False, download=True, transform=val_transform)
        num_classes = 100
    elif args.dataset == 'cub':
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/train/',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/test/',
                                           transform=val_transform)
        num_classes = 200
    elif args.dataset == 'webvision':
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/train',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/val',
                                           transform=val_transform)
        num_classes = 1000
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Data Loader
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, **kwargs)

    # Create model
    if args.model == 'AlexNet':
        model = alexnet(pretrained=False, num_classes=num_classes)
    elif args.model == 'VGG':
        use_batch_normalization = True  # default use Batch Normalization
        if use_batch_normalization:
            if args.depth == 11:
                model = vgg11_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19_bn(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupport VGG detph: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
        else:
            if args.depth == 11:
                model = vgg11(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupport VGG detph: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
    elif args.model == 'Inception':
        model = inception_v3(pretrained=False, num_classes=num_classes)
    elif args.model == 'ResNet':
        if args.depth == 18:
            model = resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupport ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    elif args.model == 'MPN-COV-ResNet':
        if args.depth == 18:
            model = mpn_cov_resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = mpn_cov_resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = mpn_cov_resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = mpn_cov_resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = mpn_cov_resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupport MPN-COV-ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    else:
        # The message used to lack a placeholder, so .format() silently
        # dropped the offending model name.
        raise Exception('Unsupport model: {}'.format(args.model))

    # Get the number of model parameters
    print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    if torch.cuda.is_available():
        model = model.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_err1 = checkpoint['best_err1']
            model.load_state_dict(checkpoint['state_dict'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # Evaluate on validation set
        err1 = validate(val_loader, model, criterion, epoch)

        # Remember best err1 and save checkpoint (lower error is better)
        is_best = (err1 <= best_err1)
        best_err1 = min(err1, best_err1)
        print("Current best accuracy (error):", best_err1)
        save_checkpoint({
            'epoch': epoch+1,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best)

    print("Best accuracy (error):", best_err1)
Exemplo n.º 24
0
def train_cla(path):
    """Train a VGG16-BN classifier and return its best-scoring epoch.

    Trains for ``args.epochs`` epochs with SGD on a fixed step schedule:
    lr 0.1 through epoch 100, 0.01 through 200, 0.001 through 250, then
    0.0001. Per-batch statistics are printed and appended to
    ``args.log_file_path``. Every 50 epochs the model is evaluated on the
    test loader, its accuracy is appended to ``args.acc_file_path`` and its
    weights are saved under ``args.model_path``. The best accuracy so far is
    written to ``args.best_acc_file_path``.

    Args:
        path: currently unused; the statement that loaded weights from it
            is disabled.

    Returns:
        The 1-based epoch with the highest test accuracy, or 0 if no
        evaluation improved on the initial accuracy of 0.
    """
    print(f"Train numbers:{len(train_datasets)}")
    print(f"Test numbers:{len(test_datasets)}")

    model = vgg.vgg16_bn(pretrained=False)
    # model.load_state_dict(torch.load(path))
    model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    # Build the optimizer once. Re-creating it every epoch, as before, threw
    # away the SGD momentum buffers at each epoch boundary. The learning rate
    # is updated in place through param_groups instead.
    optimizer = optim.SGD(model.parameters(),
                          lr=0.1,
                          momentum=0.9,
                          weight_decay=5e-4)

    length = len(train_loader)  # number of iterations per epoch
    best_acc = 0  # best test accuracy so far
    best_acc_epoch = 0
    # The banner used to announce "Resnet-18" although the model is VGG16-BN.
    print("Start Training, VGG16-BN!")
    with open(args.acc_file_path, "w") as f1, \
            open(args.log_file_path, "w") as f2:
        for epoch in range(0, args.epochs):
            if epoch + 1 <= 100:
                args.lr = 0.1
            elif 100 < epoch + 1 <= 200:
                args.lr = 0.01
            elif 200 < epoch + 1 <= 250:
                args.lr = 0.001
            else:
                args.lr = 0.0001
            for group in optimizer.param_groups:
                group["lr"] = args.lr

            print("Epoch: %d" % (epoch + 1))
            sum_loss = 0.0
            correct = 0.0
            total = 0.0
            # Evaluation below switches to eval mode, so restore train mode
            # at the start of every epoch.
            model.train()
            for i, data in enumerate(train_loader, 0):
                start = time.time()

                # prepare the batch
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # forward + backward
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # report running loss and accuracy after every batch
                sum_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += predicted.eq(labels.data).cpu().sum().item()

                end = time.time()

                print(
                    "[Epoch:%d/%d] | [Batch:%d/%d] | Loss: %.03f | Acc: %.2f%% | Lr: %.04f | Time: %.03fs"
                    % (epoch + 1, args.epochs, i + 1, length, sum_loss /
                       (i + 1), correct / total * 100, args.lr,
                       (end - start)))
                f2.write(
                    "[Epoch:%d/%d] | [Batch:%d/%d] | Loss: %.03f | Acc: %.2f%% | Lr: %.4f | Time: %.3fs"
                    % (epoch + 1, args.epochs, i + 1, length, sum_loss /
                       (i + 1), correct / total * 100, args.lr,
                       (end - start)))
                f2.write("\n")
                f2.flush()

            # evaluate on the test set every 50 epochs
            if (epoch + 1) % 50 == 0:
                print("Waiting for Testing!")
                model.eval()
                with torch.no_grad():
                    correct = 0
                    total = 0
                    for data in test_loader:
                        images, labels = data
                        images, labels = images.to(device), labels.to(
                            device)
                        outputs = model(images)
                        # pick the class with the highest score
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum().item()
                    print("Test Set Accuracy:%.2f%%" %
                          (correct / total * 100))
                    acc = correct / total * 100
                    # append the test accuracy to the accuracy file
                    f1.write("Epoch=%03d,Accuracy= %.2f%%" %
                             (epoch + 1, acc))
                    f1.write("\n")
                    f1.flush()
                    print("Saving model!")
                    torch.save(
                        model.state_dict(),
                        "%s/model_%d.pth" % (args.model_path, epoch + 1))
                    print("Model saved!")
                    # record the best test accuracy in its own file
                    if acc > best_acc:
                        with open(args.best_acc_file_path, "w") as f3:
                            f3.write("Epoch=%d,best_acc= %.2f%%" %
                                     (epoch + 1, acc))
                        best_acc = acc
                        best_acc_epoch = epoch + 1
        print(
            "Training Finished, TotalEpoch = %d, Best Accuracy(Epoch) = %.2f%%(%d)"
            % (args.epochs, best_acc, best_acc_epoch))

    return best_acc_epoch
def FGSM(best_cla_model_path, device_used):
    """Craft FGSM adversarial examples for CIFAR-10 and save them to disk.

    A ``vgg.vgg16_bn`` classifier is built, its weights are filled from the
    checkpoint at ``best_cla_model_path``, and a single-step FGSM attack
    (``x + epsilon * sign(d loss / d x)``, clamped to ``[0, 1]``) is run over
    the CIFAR-10 train and test splits. Clean and adversarial accuracies are
    printed and written to ``--output_path_acc``; the per-sample results are
    pickled to ``--output_path_train`` and ``--output_path_test``.

    Paths, epsilon and batch size come from the process command line via
    ``argparse`` (defaults are listed below).

    Args:
        best_cla_model_path: Path handed to ``torch.load``. The file must
            hold a mapping of tensors with at least as many entries as the
            classifier's ``state_dict``; entries are matched to the model's
            entries by position, not by key name.
        device_used: Device string used when CUDA is available; otherwise
            the function falls back to ``"cpu"``.

    Raises:
        IndexError: If the checkpoint has fewer entries than the model's
            ``state_dict``.
    """
    device = torch.device(device_used if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser("Adversarial Examples")
    parser.add_argument(
        "--input_path",
        type=str,
        default="C:/Users/WenqingLiu/cifar/cifar10/cifar-10-batches-py/",
        help="data set dir path")
    parser.add_argument(
        "--output_path_train",
        type=str,
        default=
        "D:/python_workplace/resnet-AE/checkpoint/Joint_Training/VGG/cifar10/data/train/train.pkl",
        help="Output directory with train images.")
    parser.add_argument(
        "--output_path_test",
        type=str,
        default=
        "D:/python_workplace/resnet-AE/checkpoint/Joint_Training/VGG/cifar10/data/test/test.pkl",
        help="Output directory with test images.")
    parser.add_argument("--epsilon", type=float, default=0.1, help="Epsilon")
    parser.add_argument("--L_F", type=int, default=5, help="L_F")
    parser.add_argument("--image_size",
                        type=int,
                        default=32,
                        help="Width of each input images.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=200,
                        help="How many images process at one time.")
    parser.add_argument("--num_classes",
                        type=int,
                        default=10,
                        help="num classes")
    parser.add_argument(
        "--output_path_acc",
        type=str,
        default=
        "D:/python_workplace/resnet-AE/checkpoint/Joint_Training/VGG/cifar10/data/acc.txt",
        help="Output directory with acc file.")

    args = parser.parse_args()
    epsilon = args.epsilon

    # No augmentation and no normalisation: the images stay exactly as
    # ToTensor produces them, and the clamp to [0, 1] below is applied in
    # that same space.
    transform_train = transforms.Compose([transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    train_datasets = torchvision.datasets.CIFAR10(root=args.input_path,
                                                  transform=transform_train,
                                                  download=True,
                                                  train=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_datasets,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    test_datasets = torchvision.datasets.CIFAR10(root=args.input_path,
                                                 transform=transform_test,
                                                 download=True,
                                                 train=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                              batch_size=args.batch_size,
                                              shuffle=True)

    model = vgg.vgg16_bn(pretrained=False)

    # The checkpoint's key names are ignored; its tensors are assigned to
    # the model's state-dict entries purely by position. map_location lets a
    # checkpoint saved on a GPU be loaded when running on the CPU fallback.
    model_dict = model.state_dict()
    checkpoint = torch.load(best_cla_model_path, map_location=device)
    checkpoint_values = list(checkpoint.values())
    if len(checkpoint_values) < len(model_dict):
        raise IndexError(
            "checkpoint %r has %d entries but the model needs %d" %
            (best_cla_model_path, len(checkpoint_values), len(model_dict)))
    model_dict.update(
        {key: checkpoint_values[i] for i, key in enumerate(model_dict)})
    model.load_state_dict(model_dict)
    model.to(device)
    model.eval()
    # Only the gradient with respect to the input is needed, so skip the
    # parameter gradients that backward() would otherwise accumulate.
    for param in model.parameters():
        param.requires_grad_(False)

    print("Weights Loaded!")

    criterion = nn.CrossEntropyLoss().to(device)

    def _attack(loader):
        """Attack every batch of ``loader``.

        Returns ``(records, correct_clean, correct_adv)`` where ``records``
        is the dict that gets pickled and the two counters are the number of
        correctly classified clean / adversarial samples.
        """
        images_clean, images_adv, labels = [], [], []
        preds_clean, preds_adv, noises = [], [], []
        correct_clean = 0
        correct_adv = 0

        for x_input, y_true in loader:
            x_input, y_true = x_input.to(device), y_true.to(device)
            x_input.requires_grad_()

            outputs = model(x_input)
            loss = criterion(outputs, y_true)
            loss.backward()  # fills x_input.grad

            _, y_pred = torch.max(outputs.detach(), 1)
            correct_clean += y_pred.eq(y_true).sum().item()

            # Single FGSM step, kept inside the [0, 1] pixel range.
            x_clean = x_input.detach()
            x_adversarial = torch.clamp(
                x_clean + epsilon * torch.sign(x_input.grad), 0, 1)
            noise = x_adversarial - x_clean

            # Evaluation only: no autograd graph is needed here.
            with torch.no_grad():
                _, y_pred_adversarial = torch.max(model(x_adversarial), 1)
            correct_adv += y_pred_adversarial.eq(y_true).sum().item()

            # extend() over a batch array appends one entry per sample.
            preds_clean.extend(y_pred.cpu().numpy())
            preds_adv.extend(y_pred_adversarial.cpu().numpy())
            noises.extend(noise.cpu().numpy())
            images_adv.extend(x_adversarial.cpu().numpy())
            images_clean.extend(x_clean.cpu().numpy())
            labels.extend(y_true.cpu().numpy())

        records = {
            "images_clean": images_clean,
            "images_adv": images_adv,
            "labels": labels,
            "y_preds": preds_clean,
            "noises": noises,
            "y_preds_adversarial": preds_adv,
        }
        return records, correct_clean, correct_adv

    def _dump(path, records):
        """Pickle ``records`` to ``path``."""
        with open(path, "wb") as handle:
            pickle.dump(records, handle)

    train_records, y_correct_train, y_correct_train_adversarial = _attack(
        train_loader)
    test_records, y_correct_test, y_correct_test_adversarial = _attack(
        test_loader)

    total_images_test = len(test_datasets)
    acc_test_clean = y_correct_test / total_images_test * 100
    acc_test_adv = y_correct_test_adversarial / total_images_test * 100
    total_images_train = len(train_datasets)
    acc_train_clean = y_correct_train / total_images_train * 100
    acc_train_adv = y_correct_train_adversarial / total_images_train * 100

    train_summary = (
        "Train Set Accuracy Before: %.2f%% | Accuracy After: %.2f%%" %
        (acc_train_clean, acc_train_adv))
    test_summary = (
        "Test Set Accuracy Before: %.2f%% | Accuracy After: %.2f%%" %
        (acc_test_clean, acc_test_adv))

    print(train_summary)
    print("Train Set Total misclassification: %d" %
          (total_images_train - y_correct_train_adversarial))
    print(test_summary)
    print("Test Set Total misclassification: %d" %
          (total_images_test - y_correct_test_adversarial))

    with open(args.output_path_acc, "w") as f1:
        f1.write(train_summary)
        f1.write("\n")
        f1.write(test_summary)

    _dump(args.output_path_train, train_records)
    _dump(args.output_path_test, test_records)
Exemplo n.º 26
0
def load_model(opt):
    """Return the model selected by ``opt``.

    When ``opt.pretrained_file`` is a non-empty string, the object stored in
    that file is loaded with ``torch.load`` and returned exactly as stored
    (``opt.cuda`` is not applied to it). Otherwise a fresh network is built
    according to ``opt.model_def`` and moved to the GPU when ``opt.cuda`` is
    truthy.

    Args:
        opt: Options object exposing ``pretrained_file``, ``model_def``,
            ``nClasses`` and ``cuda``.

    Returns:
        The loaded or newly constructed model.

    Raises:
        ValueError: If ``opt.model_def`` is not one of the supported names.
            (Previously this case fell through to ``return model`` with
            ``model`` unbound and surfaced as an ``UnboundLocalError``.)
    """
    if opt.pretrained_file != "":
        return torch.load(opt.pretrained_file)

    # Builders are lambdas so that a model module is only touched when its
    # entry is actually selected.
    builders = {
        'alexnet': lambda: alexnet.Net(opt.nClasses),
        'bincifar': lambda: bincifar.Net(opt.nClasses),
        'bincifarfbin': lambda: bincifarfbin.Net(opt.nClasses),
        'densenet': lambda: densenet.DenseNet3(32, 10),
        'alexnetfbin': lambda: alexnetfbin.Net(opt.nClasses),
        'alexnethybrid': lambda: alexnethybrid.Net(opt.nClasses),
        'alexnethybridv2': lambda: alexnethybridv2.Net(opt.nClasses),
        'alexnetwbin': lambda: alexnetwbin.Net(opt.nClasses),
        'googlenet': lambda: googlenet.Net(opt.nClasses),
        'googlenetfbin': lambda: googlenetfbin.Net(opt.nClasses),
        'googlenetwbin': lambda: googlenetwbin.Net(opt.nClasses),
        'mobilenet': lambda: mobilenet.Net(opt.nClasses),
        'nin': lambda: nin.Net(),
        'resnet18': lambda: resnet.resnet18(opt.nClasses),
        'resnetfbin18': lambda: resnetfbin.resnet18(opt.nClasses),
        'resnethybrid18': lambda: resnethybrid.resnet18(opt.nClasses),
        'resnethybridv218': lambda: resnethybridv2.resnet18(opt.nClasses),
        'resnethybridv318': lambda: resnethybridv3.resnet18(opt.nClasses),
        'resnetwbin18': lambda: resnetwbin.resnet18(opt.nClasses),
        'sketchanet': lambda: sketchanet.Net(opt.nClasses),
        'sketchanetfbin': lambda: sketchanetfbin.Net(opt.nClasses),
        'sketchanethybrid': lambda: sketchanethybrid.Net(opt.nClasses),
        'sketchanethybridv2': lambda: sketchanethybridv2.Net(opt.nClasses),
        'sketchanetwbin': lambda: sketchanetwbin.Net(opt.nClasses),
        'squeezenet': lambda: squeezenet.Net(opt.nClasses),
        'squeezenetfbin': lambda: squeezenetfbin.Net(opt.nClasses),
        'squeezenethybrid': lambda: squeezenethybrid.Net(opt.nClasses),
        'squeezenethybridv2': lambda: squeezenethybridv2.Net(opt.nClasses),
        'squeezenethybridv3': lambda: squeezenethybridv3.Net(opt.nClasses),
        'squeezenetwbin': lambda: squeezenetwbin.Net(opt.nClasses),
        'vgg16_bncifar': lambda: vgg.vgg16_bn(),
    }

    build = builders.get(opt.model_def)
    if build is None:
        raise ValueError(
            "unsupported model_def %r; expected one of: %s" %
            (opt.model_def, ", ".join(sorted(builders))))

    model = build()
    if opt.cuda:
        model = model.cuda()
    return model
Exemplo n.º 27
0
def get_network(args):
    """Instantiate the architecture named by ``args.net``.

    The matching constructor is imported on demand from the ``models``
    package and called without arguments. An unrecognised name prints a
    message and terminates the process through ``sys.exit()``. When
    ``args.gpu`` is truthy the network is moved to CUDA before it is
    returned.
    """

    def _instantiate(name):
        # Every import lives in its own branch so that only the module of
        # the requested architecture gets loaded.
        if name == 'vgg11':
            from models.vgg import vgg11_bn
            return vgg11_bn()
        if name == 'vgg13':
            from models.vgg import vgg13_bn
            return vgg13_bn()
        if name == 'vgg16':
            from models.vgg import vgg16_bn
            return vgg16_bn()
        if name == 'vgg19':
            from models.vgg import vgg19_bn
            return vgg19_bn()
        if name == 'googlenet':
            from models.googLeNet import GoogLeNet
            return GoogLeNet()
        if name == 'inceptionv3':
            from models.inceptionv3 import Inceptionv3
            return Inceptionv3()
        if name == 'resnet18':
            from models.resnet import resnet18
            return resnet18()
        if name == 'resnet34':
            from models.resnet import resnet34
            return resnet34()
        if name == 'resnet50':
            from models.resnet import resnet50
            return resnet50()
        if name == 'resnet101':
            from models.resnet import resnet101
            return resnet101()
        if name == 'resnet152':
            from models.resnet import resnet152
            return resnet152()
        if name == 'wrn':
            from models.wideresnet import wideresnet
            return wideresnet()
        if name == 'densenet121':
            from models.densenet import densenet121
            return densenet121()
        if name == 'densenet161':
            from models.densenet import densenet161
            return densenet161()
        if name == 'densenet169':
            from models.densenet import densenet169
            return densenet169()
        if name == 'densenet201':
            from models.densenet import densenet201
            return densenet201()

        # Nothing matched: report and stop the program.
        print('the network name you have entered is not supported yet')
        sys.exit()

    network = _instantiate(args.net)

    if not args.gpu:
        return network

    print("use gpu")
    return network.cuda()
Exemplo n.º 28
0
# Disabled alternative setup that used the `models.*` constructors
# (presumably torchvision.models -- confirm); kept for reference only.
# alexnet = models.alexnet(pretrained=True)
# resnet18 = models.resnet18(pretrained=True).to(device)
# resnet18.device = device
# vgg16 = models.vgg16(pretrained=True)
# densenet = models.densenet161(pretrained=True)
# squeezenet = models.squeezenet1_0(pretrained=True)

# Build three pretrained classifiers, move each one to `device`, and tag the
# instance with that device and a short name as plain attributes.
resnet18 = resnet.resnet18(pretrained=True).to(device)
resnet18.device = device
resnet18.name = "resnet18"

# NOTE(review): this rebinds the name `densenet` to the model instance, so
# whatever `densenet` referred to before (presumably the model module) is no
# longer reachable under that name after this line.
densenet = densenet.densenet121(pretrained=True).to(device)
densenet.device = device
densenet.name = "densenet"

vgg16 = vgg.vgg16_bn(pretrained=True).to(device)
vgg16.device = device
vgg16.name = "vgg16"

# All models collected in one list.
net_list = [resnet18, densenet, vgg16]

# Per-channel mean/std normalisation applied after ToTensor
# (values are presumably the CIFAR-10 channel statistics -- confirm).
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                 std=[0.2023, 0.1994, 0.2010])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# CIFAR-10 test split (train=False), downloaded into ./data when missing.
testset = datasets.CIFAR10(root="./data",
                           train=False,
                           download=True,
                           transform=transform)
testloader = torch.utils.data.DataLoader(testset,
Exemplo n.º 29
0
def FGSM(best_cla_model_path, device_used):
    """Craft FGSM adversarial examples for MNIST and save them to disk.

    A ``vgg.vgg16_bn`` classifier is built and loaded from the state dict at
    ``best_cla_model_path``. A single-step FGSM attack
    (``x + epsilon * sign(d loss / d x)``, clamped to ``[0, 1]``) is run over
    the MNIST train and test splits. Clean and attacked images are rounded
    to the 0..255 integer grid; the stored noise and adversarial image are
    those integers divided by 255, and the adversarial accuracy is measured
    on that quantised image. Accuracies are printed and written to
    ``--output_path_acc``; per-sample results are pickled to
    ``--output_path_train`` and ``--output_path_test``.

    Paths, epsilon and batch size come from the process command line via
    ``argparse`` (defaults are listed below).

    Args:
        best_cla_model_path: Path of the state dict passed to ``torch.load``
            and then to ``model.load_state_dict``.
        device_used: Device string used when CUDA is available; otherwise
            the function falls back to ``"cpu"``.
    """
    device = torch.device(device_used if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser("Adversarial Examples")
    parser.add_argument(
        "--input_path",
        type=str,
        default="D:/python_workplace/resnet-AE/inputData/mnist/",
        help="data set dir path")
    parser.add_argument(
        "--input_path_train_pkl",
        type=str,
        default="D:/python_workplace/resnet-AE/inputData/mnist/mnist_train.pkl",
        help="data set dir path")
    parser.add_argument(
        "--input_path_test_pkl",
        type=str,
        default="D:/python_workplace/resnet-AE/inputData/mnist/mnist_test.pkl",
        help="data set dir path")
    parser.add_argument(
        "--output_path_train",
        type=str,
        default=
        "D:/python_workplace/resnet-AE/outputData/FGSM/vgg/mnist/train/train.pkl",
        help="Output directory with train images.")
    parser.add_argument(
        "--output_path_test",
        type=str,
        default=
        "D:/python_workplace/resnet-AE/outputData/FGSM/vgg/mnist/test/test.pkl",
        help="Output directory with test images.")
    parser.add_argument("--epsilon", type=float, default=0.4, help="Epsilon")
    parser.add_argument("--L_F", type=int, default=5, help="L_F")
    parser.add_argument("--image_size",
                        type=int,
                        default=28,
                        help="Width of each input images.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=1000,
                        help="How many images process at one time.")
    parser.add_argument("--num_classes",
                        type=int,
                        default=10,
                        help="num classes")
    parser.add_argument(
        "--output_path_acc",
        type=str,
        default=
        "D:/python_workplace/resnet-AE/outputData/FGSM/vgg/mnist/acc.txt",
        help="Output directory with acc file.")

    args = parser.parse_args()
    epsilon = args.epsilon

    # No normalisation: the images stay exactly as ToTensor produces them,
    # and the clamp to [0, 1] below is applied in that same space.
    transform_train = transforms.Compose([transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    train_datasets = torchvision.datasets.MNIST(root=args.input_path,
                                                transform=transform_train,
                                                download=True,
                                                train=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_datasets,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    test_datasets = torchvision.datasets.MNIST(root=args.input_path,
                                               transform=transform_test,
                                               download=True,
                                               train=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                              batch_size=args.batch_size,
                                              shuffle=True)

    model = vgg.vgg16_bn(pretrained=False)

    # map_location lets a checkpoint saved on a GPU be loaded when running
    # on the CPU fallback.
    model.load_state_dict(
        torch.load(best_cla_model_path, map_location=device))
    print("Weights Loaded!")
    model.to(device)
    model.eval()
    # Only the gradient with respect to the input is needed, so skip the
    # parameter gradients that backward() would otherwise accumulate.
    for param in model.parameters():
        param.requires_grad_(False)

    criterion = nn.CrossEntropyLoss().to(device)

    def _attack(loader):
        """Attack every batch of ``loader``.

        Returns ``(records, correct_clean, correct_adv)`` where ``records``
        is the dict that gets pickled and the two counters are the number of
        correctly classified clean / adversarial samples.
        """
        images_clean, images_adv, labels = [], [], []
        preds_clean, preds_adv, noises = [], [], []
        correct_clean = 0
        correct_adv = 0

        for x_input, y_true in loader:
            x_input, y_true = x_input.to(device), y_true.to(device)
            x_input.requires_grad_()

            outputs = model(x_input)
            loss = criterion(outputs, y_true)
            loss.backward()  # fills x_input.grad

            _, y_pred = torch.max(outputs.detach(), 1)
            correct_clean += y_pred.eq(y_true).sum().item()

            # Single FGSM step, kept inside the [0, 1] pixel range.
            x_clean = x_input.detach()
            x_adversarial = torch.clamp(
                x_clean + epsilon * torch.sign(x_input.grad), 0, 1)

            # Quantise both images to integer pixel values. The builtin
            # `int` is what the removed `np.int` alias used to mean.
            origin_px = np.rint(x_clean.cpu().numpy() * 255).astype(int)
            adversarial_px = np.rint(
                x_adversarial.cpu().numpy() * 255).astype(int)

            noise = (adversarial_px - origin_px) / 255
            image_adversarial = adversarial_px / 255

            # Evaluation only: no autograd graph is needed here.
            with torch.no_grad():
                adversarial_batch = torch.from_numpy(
                    image_adversarial).float().to(device)
                _, y_pred_adversarial = torch.max(model(adversarial_batch), 1)
            correct_adv += y_pred_adversarial.eq(y_true).sum().item()

            # extend() over a batch array appends one entry per sample.
            preds_clean.extend(y_pred.cpu().numpy())
            preds_adv.extend(y_pred_adversarial.cpu().numpy())
            noises.extend(noise)
            images_adv.extend(image_adversarial)
            images_clean.extend(x_clean.cpu().numpy())
            labels.extend(y_true.cpu().numpy())

        records = {
            "images_clean": images_clean,
            "images_adv": images_adv,
            "labels": labels,
            "y_preds": preds_clean,
            "noises": noises,
            "y_preds_adversarial": preds_adv,
        }
        return records, correct_clean, correct_adv

    def _dump(path, records):
        """Pickle ``records`` to ``path``."""
        with open(path, "wb") as handle:
            pickle.dump(records, handle)

    train_records, y_correct_train, y_correct_train_adversarial = _attack(
        train_loader)
    test_records, y_correct_test, y_correct_test_adversarial = _attack(
        test_loader)

    total_images_test = len(test_datasets)
    acc_test_clean = y_correct_test / total_images_test * 100
    acc_test_adv = y_correct_test_adversarial / total_images_test * 100
    total_images_train = len(train_datasets)
    acc_train_clean = y_correct_train / total_images_train * 100
    acc_train_adv = y_correct_train_adversarial / total_images_train * 100

    train_summary = (
        "Train Set Accuracy Before: %.2f%% | Accuracy After: %.2f%%" %
        (acc_train_clean, acc_train_adv))
    test_summary = (
        "Test Set Accuracy Before: %.2f%% | Accuracy After: %.2f%%" %
        (acc_test_clean, acc_test_adv))

    print(train_summary)
    print("Train Set Total misclassification: %d" %
          (total_images_train - y_correct_train_adversarial))
    print(test_summary)
    print("Test Set Total misclassification: %d" %
          (total_images_test - y_correct_test_adversarial))

    with open(args.output_path_acc, "w") as f1:
        f1.write(train_summary)
        f1.write("\n")
        f1.write(test_summary)

    _dump(args.output_path_train, train_records)
    _dump(args.output_path_test, test_records)
Exemplo n.º 30
0
def main():
    """Set up data, networks, losses and optimizers, then run training.

    Everything is driven by the module-level ``config``: the datasets are
    built from ``config.DATA_FLIST``, the weights of ``netG``/``netD`` are
    optionally restored from ``model_logs/<config.MODEL_RESTORE>``, and the
    ``train``/``validate`` pair is run for a fixed number of epochs. After
    every epoch a checkpoint holding the epoch number and the two state dicts
    is written to ``log_dir`` twice: once under an epoch-specific name and
    once as ``latest_ckpt.pth.tar``.
    """
    logger_init()
    dataset_type = config.DATASET
    batch_size = config.BATCH_SIZE

    # Dataset setting
    logger.info("Initialize the dataset...")
    # Keyword arguments shared by the training and validation datasets.
    dataset_kwargs = dict(
        resize_shape=tuple(config.IMG_SHAPES),
        random_bbox_shape=config.RANDOM_BBOX_SHAPE,
        random_bbox_margin=config.RANDOM_BBOX_MARGIN,
        random_ff_setting=config.RANDOM_FF_SETTING,
    )

    # Index 0 of each file-list entry is used for training, index 1 for
    # validation.
    train_dataset = InpaintDataset(
        config.DATA_FLIST[dataset_type][0],
        {
            mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][0]
            for mask_type in config.MASK_TYPES
        },
        **dataset_kwargs
    )
    train_loader = train_dataset.loader(batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=16)

    val_dataset = InpaintDataset(
        config.DATA_FLIST[dataset_type][1],
        {
            mask_type: config.DATA_FLIST[config.MASKDATASET][mask_type][1]
            for mask_type in ('val',)
        },
        **dataset_kwargs
    )
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)

    # Collect a fixed list of validation samples (at most
    # config.STATIC_VIEW_SIZE) that is reused for every epoch.
    val_datas = []
    for data in val_loader:
        if len(val_datas) >= config.STATIC_VIEW_SIZE:
            break
        imgs = data[0]
        # Only keep samples whose first element has size 3 along dim 1
        # (presumably 3-channel images -- confirm against InpaintDataset).
        if imgs.size(1) == 3:
            val_datas.append(data)

    # Fresh loader, since the one above has been partially consumed.
    val_loader = val_dataset.loader(batch_size=1, shuffle=False, num_workers=1)
    logger.info("Finish the dataset initialization.")

    # Define the Network Structure
    logger.info("Define the Network Structure and Losses")
    netG = InpaintRUNNet(cuda0, n_in_channel=config.N_CHANNEL)
    netD = InpaintSADirciminator()
    netVGG = vgg16_bn(pretrained=True)
    sr_args = SRArgs(config.GPU_IDS[0])
    netSR = sr_model.Model(sr_args, sr_util.checkpoint(sr_args))

    if config.MODEL_RESTORE != '':
        whole_model_path = 'model_logs/{}'.format(config.MODEL_RESTORE)
        # map_location keeps the restore independent of the device the
        # checkpoint was written from; load_state_dict copies the values onto
        # whatever device each network's parameters already live on.
        checkpoint = torch.load(whole_model_path, map_location='cpu')
        netG_state_dict, netD_state_dict = (checkpoint['netG_state_dict'],
                                            checkpoint['netD_state_dict'])
        netG.load_state_dict(netG_state_dict)
        netD.load_state_dict(netD_state_dict)
        logger.info("Loading pretrained models from {} ...".format(
            config.MODEL_RESTORE))

    # Define loss
    recon_loss = ReconLoss(*(config.L1_LOSS_ALPHA))
    gan_loss = SNGenLoss(config.GAN_LOSS_ALPHA)
    # Both feature-based losses share the same extractor, moved to cuda1 once.
    vgg_features = netVGG.to(cuda1)
    perc_loss = PerceptualLoss(weight=config.PERC_LOSS_ALPHA,
                               feat_extractors=vgg_features)
    style_loss = StyleLoss(weight=config.STYLE_LOSS_ALPHA,
                           feat_extractors=vgg_features)
    dis_loss = SNDisLoss()
    lr, decay = config.LEARNING_RATE, config.WEIGHT_DECAY
    optG = torch.optim.Adam(netG.parameters(), lr=lr, weight_decay=decay)
    # netD is optimized with four times the learning rate of netG.
    optD = torch.optim.Adam(netD.parameters(), lr=4 * lr, weight_decay=decay)

    nets = {"netG": netG, "netD": netD, "vgg": netVGG, "netSR": netSR}

    losses = {
        "GANLoss": gan_loss,
        "ReconLoss": recon_loss,
        "StyleLoss": style_loss,
        "DLoss": dis_loss,
        "PercLoss": perc_loss,
    }

    opts = {
        "optG": optG,
        "optD": optD,
    }

    logger.info("Finish Define the Network Structure and Losses")

    # Start Training
    logger.info("Start Training...")
    num_epochs = 50

    for epoch_idx in range(num_epochs):
        train(nets,
              losses,
              opts,
              train_loader,
              epoch_idx,
              devices=(cuda0, cuda1),
              val_datas=val_datas)

        validate(nets, losses, opts, val_datas, epoch_idx,
                 devices=(cuda0, cuda1))

        # NOTE: .to(cpu0) is applied to the networks themselves before their
        # state dicts are taken; train() presumably moves them back to its
        # devices at the start of the next epoch -- confirm.
        # Optimizer state is not part of the checkpoint.
        saved_model = {
            'epoch': epoch_idx + 1,
            'netG_state_dict': netG.to(cpu0).state_dict(),
            'netD_state_dict': netD.to(cpu0).state_dict(),
        }
        torch.save(saved_model,
                   '{}/epoch_{}_ckpt.pth.tar'.format(log_dir, epoch_idx + 1))
        torch.save(saved_model,
                   '{}/latest_ckpt.pth.tar'.format(log_dir))