コード例 #1
0
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    """Build a VGG variant and optionally load pretrained weights.

    Args:
        arch: key into ``model_urls`` naming the checkpoint to download.
        cfg: key into ``cfgs`` selecting the layer configuration.
        batch_norm: forwarded to ``make_layers`` as ``batch_norm=``.
        pretrained: if True, disable ``init_weights`` and load the state dict
            fetched from ``model_urls[arch]``.
        progress: forwarded to ``load_state_dict_from_url`` as ``progress=``.
        **kwargs: extra keyword arguments forwarded to ``VGG``.

    Returns:
        The constructed ``VGG`` model.
    """
    # Random initialisation is wasted work when weights are loaded right after.
    if pretrained:
        kwargs['init_weights'] = False
    features = make_layers(cfgs[cfg], batch_norm=batch_norm)
    net = VGG(features, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
コード例 #2
0
ファイル: vgg16_depth.py プロジェクト: sharifza/Depth-VRD
def vgg16_depth(pretrained=False, **kwargs):
    """Build the VGG 16-layer depth model (layer configuration ``'D'``).

    Args:
        pretrained (bool): if True, load the ``model_urls['vgg16']`` weights
            (ImageNet-pretrained) into the freshly built model.
        **kwargs: forwarded unchanged to ``VGG``.

    Returns:
        The ``VGG`` model.
    """
    net = VGG(make_layers(cfg['D']), **kwargs)
    if not pretrained:
        return net
    state = model_zoo.load_url(model_urls['vgg16'])
    net.load_state_dict(state)
    return net
コード例 #3
0
def vgg16(pretrained=True, **kwargs):
    """Build only the convolutional part of VGG-16 (configuration "D").

    The layer list stops after the last 512-channel conv block; it has no
    trailing ``'M'`` max-pool. The first ten modules of the feature stack
    have ``requires_grad`` set to False.

    Args:
        pretrained (bool): if True, disable ``init_weights`` and load the
            ``model_urls['vgg16']`` state dict (ImageNet-pretrained).
        **kwargs: forwarded to ``VGG``.

    Returns:
        ``features`` of the built model; the rest of the network is not returned.
    """
    layer_spec = [64, 64, 'M',
                  128, 128, 'M',
                  256, 256, 256, 'M',
                  512, 512, 512, 'M',
                  512, 512, 512]
    if pretrained:
        kwargs['init_weights'] = False

    net = VGG(make_layers(layer_spec), **kwargs)
    if pretrained:
        net.load_state_dict(model_zoo.load_url(model_urls['vgg16']))

    # Freeze the bottom layers: every parameter of the first ten feature modules.
    bottom_params = (p for m in net.features[0:10] for p in m.parameters())
    for p in bottom_params:
        p.requires_grad = False

    return net.features
コード例 #4
0
    def __init__(self,
                 pretrained=True,
                 model='vgg16',
                 requires_grad=True,
                 remove_fc=True,
                 show_params=False):
        """Build a VGG backbone from ``cfg[model]`` and optionally trim it.

        Args:
            pretrained: if True, copy the state dict of
                ``models.<model>(pretrained=True)`` into this network.
            model: VGG variant name. It is used as a key into ``cfg`` and
                ``ranges``, and as the constructor name looked up on ``models``.
            requires_grad: if False, set ``requires_grad = False`` on every parameter.
            remove_fc: if True, delete ``self.classifier`` (the fully-connected head).
            show_params: if True, print the name and size of each remaining parameter.
        """
        VGG.__init__(self, make_layers(cfg[model]))
        self.ranges = ranges[model]

        if pretrained:
            # Resolve the constructor by name instead of exec()-ing a formatted
            # string. Valid names give the same result, and an arbitrary
            # ``model`` value can no longer inject code.
            model_fn = getattr(models, model)
            self.load_state_dict(model_fn(pretrained=True).state_dict())

        if not requires_grad:
            for param in super().parameters():
                param.requires_grad = False

        if remove_fc:  # delete redundant fully-connected layer params, can save memory
            del self.classifier

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())
コード例 #5
0
    def __init__(self, batch_norm=True, pretrained=True):
        """Keep only the convolutional feature stack of a VGG-16 ("D") network.

        Args:
            batch_norm (bool): build the BatchNorm variant. This also selects
                the ``'vgg16_bn'`` checkpoint instead of ``'vgg16'``.
            pretrained (bool): if True, load that checkpoint before the
                features are taken.
        """
        super(Extractor, self).__init__()
        backbone = VGG(make_layers(cfg['D'], batch_norm))

        if pretrained:
            url_key = 'vgg16_bn' if batch_norm else 'vgg16'
            backbone.load_state_dict(model_zoo.load_url(model_urls[url_key]))

        self.features = backbone.features
コード例 #6
0
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


# VGG-16 layer configuration "D": integers are 3x3 conv output widths and
# 'M' is a 2x2 / stride-2 max-pool (see make_layers).
cfg = [
    64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512,
    512, 512, 'M'
]
model = VGG(make_layers(cfg, batch_norm=True))
path_pre_model = "/home/jjchu/My_Research/SegGAN/pre_model/vgg16_bn-6c64b313.pth"
# Load onto the CPU first so the checkpoint also works on GPU-less machines.
# The model is moved to the GPU below only when use_gpu is set.
model.load_state_dict(torch.load(path_pre_model, map_location="cpu"))

# The GPU transfer below refers to the network as ``model_ft``. Its original
# assignment (``model_ft = prevgg16(pretrained=True)``) was commented out,
# which left ``model_ft`` undefined and made ``model_ft.cuda()`` raise
# NameError. Alias it to the model built above.
model_ft = model

if use_gpu:
    model_ft = model_ft.cuda()
コード例 #7
0
def main():
    """Prune and fine-tune a VGG model, logging and checkpointing each epoch.

    Command-line options come from the module-level ``parser``. The model is
    obtained in one of three ways:

    * resumed from ``args.resume``, which also restores the optimizer;
    * built from pretrained weights: ``args.pretrained == 'pytorch'`` uses the
      model zoo, any other value is treated as a checkpoint path;
    * freshly initialised.

    In every case ``model.features`` is wrapped in ``torch.nn.DataParallel``
    and the model is moved to the GPU. Saved checkpoints therefore always use
    the ``features.module.*`` keys that the resume path loads. When not
    resuming, the model is passed to ``prune`` and a new SGD optimizer is
    created. Training then runs for epochs ``args.start_epoch`` to
    ``args.epochs - 1``.

    Side effects: assigns the globals ``args``, ``best_prec1``, ``train_log``
    and ``test_log``. Creates ``logs/prune/<arch>_<MMDD_HHMM>`` and
    ``checkpoints/prune/<arch>_<MMDD_HHMM>``; ``os.makedirs`` raises if they
    already exist. NOTE(review): when not resuming, ``best_prec1`` is read
    from a module-level initial value — confirm it is defined there.
    """
    global args, best_prec1, train_log, test_log
    args = parser.parse_args()

    # Per-run directories, named after the architecture and start time.
    dir_name = args.arch + '_' + datetime.datetime.now().strftime('%m%d_%H%M')
    log_dir = os.path.join('logs', 'prune', dir_name)
    checkpoint_dir = os.path.join('checkpoints', 'prune', dir_name)
    os.makedirs(log_dir)
    os.makedirs(checkpoint_dir)
    train_log = Logger(os.path.join(log_dir, 'train.log'))
    test_log = Logger(os.path.join(log_dir, 'test.log'))
    config_log = Logger(os.path.join(log_dir, 'config.log'))

    # Record every option. Close the config log even if a write fails.
    try:
        for k, v in vars(args).items():
            config_log.write(content="{k} : {v}".format(k=k, v=v), wrap=True, flush=True)
    finally:
        config_log.close()

    # create model
    print("=" * 89)
    print("=> creating model '{}'".format(args.arch))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']

            # The checkpoint stores the (possibly pruned) layer config, so the
            # architecture is rebuilt from it rather than from args.arch.
            vgg_cfg, batch_norm = checkpoint['cfg']
            from torchvision.models.vgg import VGG, make_layers
            model = VGG(make_layers(cfg=vgg_cfg, batch_norm=batch_norm), init_weights=False)
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
            model.load_state_dict(checkpoint['state_dict'])

            optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            return

    elif args.pretrained:
        if args.pretrained == 'pytorch':
            print("=> using pre-trained model from model zoo")
            model = models.__dict__[args.arch](pretrained=True)
            args.pretrained_parallel = False
            # Wrap and move to the GPU like every other branch. Otherwise the
            # model stays on the CPU while the criterion is on CUDA, and its
            # checkpoints lack the 'features.module.*' keys that the resume
            # path (which wraps features before load_state_dict) expects.
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = models.__dict__[args.arch]()
            # A checkpoint saved from a wrapped model needs the wrapper before
            # loading, so that its 'features.module.*' keys match.
            if args.pretrained_parallel:
                model.features = torch.nn.DataParallel(model.features)
                model.cuda()
            if os.path.isfile(args.pretrained):
                print("=> using pre-trained model '{}'".format(args.pretrained))
                checkpoint = torch.load(args.pretrained)
                model.load_state_dict(checkpoint['state_dict'])
                if not args.pretrained_parallel:
                    model.features = torch.nn.DataParallel(model.features)
                    model.cuda()
            else:
                print("=> no checkpoint found at '{}'".format(args.pretrained))
                return
    else:
        model = models.__dict__[args.arch]()
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # Deterministic evaluation transform, shared by the validation and
    # pruning (rcn) loaders.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, val_transform),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if not args.resume:
        # Shuffled loader over the validation set, consumed by prune().
        rcn_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, val_transform),
            batch_size=args.rcn_batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        prune(train_loader=train_loader, val_loader=val_loader, rcn_loader=rcn_loader,
              model=model, criterion=criterion)

        # Build the optimizer only after pruning, over the pruned model's parameters.
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(lr_decay_step=args.lr_decay_step,
                             optimizer=optimizer, epoch=epoch)

        # train for one epoch
        train(train_loader=train_loader, model=model, criterion=criterion,
              optimizer=optimizer, epoch=epoch, log=True)

        # evaluate on validation set
        prec1 = validate(val_loader=val_loader, model=model,
                         criterion=criterion, epoch=epoch, log=True)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'cfg': get_vgg_cfg(model),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best=is_best, checkpoint_dir=checkpoint_dir)
コード例 #8
0
ファイル: get_cnn.py プロジェクト: yazici/graph_distillation
def get_vgg(in_channels=3, **kwargs):
  """Return a VGG network built from layer configuration ``cfg['D']``.

  Args:
    in_channels: passed as the second positional argument of ``make_layers``.
    **kwargs: forwarded to ``VGG``.
  """
  # NOTE(review): torchvision's make_layers signature is (cfg, batch_norm=False),
  # so there a positional second argument would be read as batch_norm. This
  # presumably relies on a project-local make_layers(cfg, in_channels, ...) —
  # confirm.
  layers = make_layers(cfg['D'], in_channels)
  return VGG(layers, **kwargs)