# Example #1
# 0
def get_model(train_model):
    """Instantiate a backbone network by name.

    Args:
        train_model: model identifier. Plain names (``'resnet18'`` ...
            ``'resnet152'``) use the ``resnet`` module; names with a
            ``'_copy'`` suffix use the ``resnet_copy`` module. ``'vgg11'``,
            ``'vgg13'``, ``'vgg16'``, ``'vgg19'``, ``'nin'`` and
            ``'googlenet'`` call the corresponding constructors directly.

    Returns:
        A freshly constructed model, or ``None`` if ``train_model`` is not a
        recognised name.
    """
    # Lambdas defer the name lookup so only the requested constructor is
    # touched. Previously the 'resnet152_copy' entry was keyed as a second
    # 'resnet152', which made resnet_copy.resnet152 unreachable.
    builders = {
        'resnet18': lambda: resnet.resnet18(),
        'resnet34': lambda: resnet.resnet34(),
        'resnet50': lambda: resnet.resnet50(),
        'resnet101': lambda: resnet.resnet101(),
        'resnet152': lambda: resnet.resnet152(),
        'resnet18_copy': lambda: resnet_copy.resnet18(),
        'resnet34_copy': lambda: resnet_copy.resnet34(),
        'resnet50_copy': lambda: resnet_copy.resnet50(),
        'resnet101_copy': lambda: resnet_copy.resnet101(),
        'resnet152_copy': lambda: resnet_copy.resnet152(),
        'vgg11': lambda: vgg11(),
        'vgg13': lambda: vgg13(),
        'vgg16': lambda: vgg16(),
        'vgg19': lambda: vgg19(),
        'nin': lambda: nin(),
        'googlenet': lambda: googlenet(),
    }
    builder = builders.get(train_model)
    return builder() if builder is not None else None
def crowd_counting(dataloader, model_param_path, savecsv):
    '''
    Estimate a capped head count for every image and write them to a CSV.

    dataloader: iterable yielding ``(inputs, name)`` pairs. Only ``name[0]``
        is recorded, so a batch size of 1 is presumably expected (TODO confirm).
    model_param_path: path of a ``vgg19`` state dict saved with ``torch.save``.
    savecsv: destination path of the CSV, with columns ``file`` and
        ``man_count``.

    Each count is the sum of the predicted density map, capped at 100 and
    truncated to an int.
    '''
    device = torch.device("cuda")

    model = vgg19()
    # map_location lets checkpoints saved on any device load onto ``device``.
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)
    model.eval()

    output_dict = {"file": [], "man_count": []}

    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for inputs, name in tqdm(dataloader):
            et_count = torch.sum(model(inputs.to(device))).item()
            output_dict["file"].append(name[0])
            output_dict["man_count"].append(int(min(et_count, 100)))

    pd.DataFrame(output_dict).to_csv(savecsv,
                                     index=False,
                                     columns=["file", "man_count"])
def cal_mae(dataloader, model_param_path):
    '''
    Compute the counting error of a ``vgg19`` model over a test set.

    dataloader: iterable yielding ``(inputs, count, name)``. ``count[0]`` is
        taken as the ground-truth count, so batch size 1 is presumably
        expected (TODO confirm).
    model_param_path: path of a ``vgg19`` state dict saved with ``torch.save``.

    Prints the metrics and returns ``(mae, mse)``. The value named ``mse`` is
    the square root of the mean squared error (i.e. an RMSE).
    '''
    device = torch.device("cuda")

    model = vgg19()
    # map_location lets checkpoints saved on any device load onto ``device``.
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)
    model.eval()

    errors = []
    with torch.no_grad():
        for inputs, count, _name in tqdm(dataloader):
            outputs = model(inputs.to(device))
            # Signed error: ground truth minus predicted count.
            errors.append(count[0].item() - torch.sum(outputs).item())

    errors = np.array(errors)
    mse = np.sqrt(np.mean(np.square(errors)))
    mae = np.mean(np.abs(errors))
    log_str = 'Final Test: mae {}, mse {}'.format(mae, mse)
    print(log_str)
    return mae, mse
    def setup(self):
        """Prepare everything needed for Bayesian-loss training.

        This builds:
        - ``train``/``val`` ``Crowd`` datasets and their loaders;
        - a ``vgg19`` model and an Adam optimizer;
        - the ``Post_Prob`` posterior and the ``Bay_Loss`` criterion;
        - a ``Save_Handle`` and the best-metric trackers.

        When ``args.resume`` ends in ``.tar``, the model, the optimizer and
        the epoch counter are restored. A ``.pth`` file restores the model
        weights only.

        Raises:
            Exception: if no CUDA device is available.
        """
        args = self.args
        if not torch.cuda.is_available():
            raise Exception("gpu is not available")
        self.device = torch.device("cuda")
        self.device_count = torch.cuda.device_count()
        # Only the single-GPU version of this trainer is supported.
        assert self.device_count == 1
        logging.info('using {} gpus'.format(self.device_count))

        self.downsample_ratio = args.downsample_ratio
        phases = ('train', 'val')

        datasets = {}
        for phase in phases:
            datasets[phase] = Crowd(os.path.join(args.data_dir, phase),
                                    args.crop_size, args.downsample_ratio,
                                    args.is_gray, phase)
        self.datasets = datasets

        loaders = {}
        for phase in phases:
            training = phase == 'train'
            loaders[phase] = DataLoader(
                datasets[phase],
                collate_fn=train_collate if training else default_collate,
                batch_size=args.batch_size if training else 1,
                shuffle=training,
                num_workers=args.num_workers * self.device_count,
                pin_memory=training)
        self.dataloaders = loaders

        self.model = vgg19()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)

        self.start_epoch = 0
        if args.resume:
            suffix = args.resume.rsplit('.', 1)[-1]
            if suffix == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suffix == 'pth':
                weights = torch.load(args.resume, self.device)
                self.model.load_state_dict(weights)

        self.post_prob = Post_Prob(args.sigma, args.crop_size,
                                   args.downsample_ratio,
                                   args.background_ratio, args.use_background,
                                   self.device)
        self.criterion = Bay_Loss(args.use_background, self.device)
        self.save_list = Save_Handle(max_num=args.max_model_num)
        self.best_mae = self.best_mse = np.inf
        self.best_mae_1 = self.best_mse_1 = np.inf
        self.best_count = self.best_count_1 = 0
# Example #5
# 0
def get_network(args):
    """Instantiate ``args.net`` and give it an ``args.num_classes``-way head.

    Supported names: ``vgg16``, ``vgg19``, ``alexnet``, ``mobilenet``,
    ``googlenet``, ``inception`` and ``resnet18/34/50/101/152``. Each model is
    imported lazily from the local ``models`` package.

    For ``mobilenet``, ``classifier[1]`` is replaced with a new linear layer.
    Every other model gets a new ``fc`` layer.

    Returns:
        The model, or ``None`` (after printing a message) when ``args.net``
        is not supported.
    """
    net_name = args.net
    if net_name == 'vgg16':
        from models.vgg import vgg16
        model = vgg16(args.num_classes, export_onnx=args.export_onnx)
    elif net_name == 'alexnet':
        from models.alexnet import alexnet
        model = alexnet(num_classes=args.num_classes,
                        export_onnx=args.export_onnx)
    elif net_name == 'mobilenet':
        from models.mobilenet import mobilenet_v2
        model = mobilenet_v2(pretrained=True, export_onnx=args.export_onnx)
    elif net_name == 'vgg19':
        from models.vgg import vgg19
        model = vgg19(args.num_classes, export_onnx=args.export_onnx)
    elif net_name == 'googlenet':
        from models.googlenet import googlenet
        model = googlenet(pretrained=True)
    elif net_name == 'inception':
        from models.inception import inception_v3
        model = inception_v3(args, pretrained=True,
                             export_onnx=args.export_onnx)
    elif net_name == 'resnet18':
        from models.resnet import resnet18
        model = resnet18(pretrained=True, export_onnx=args.export_onnx)
    elif net_name == 'resnet34':
        from models.resnet import resnet34
        model = resnet34(pretrained=True, export_onnx=args.export_onnx)
    elif net_name == 'resnet101':
        from models.resnet import resnet101
        model = resnet101(pretrained=True, export_onnx=args.export_onnx)
    elif net_name == 'resnet50':
        from models.resnet import resnet50
        model = resnet50(pretrained=True, export_onnx=args.export_onnx)
    elif net_name == 'resnet152':
        from models.resnet import resnet152
        model = resnet152(pretrained=True, export_onnx=args.export_onnx)
    else:
        print("The %s is not supported..." % net_name)
        return None

    if net_name == 'mobilenet':
        # NOTE(review): the new head takes 4x the old in_features; confirm
        # this matches the feature size models.mobilenet actually produces.
        head_inputs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(head_inputs * 4, args.num_classes)
    else:
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)
    return model
# Example #6
# 0
def get_network(args, cfg):
    """ return given network

    Build the network named by ``args.net`` and move it to CUDA.

    Args:
        args: namespace providing ``net`` and ``pretrain``. ``gpuid`` is also
            read, but only for ``resnet18`` and ``resnet50``.
        cfg: config object. ``cfg.PARA.train.num_classes`` sets the output
            size of every model that accepts ``num_classes``.

    Returns:
        The constructed model on a CUDA device.

    Exits the process with status 1 when ``args.net`` is not supported.
    """
    # NOTE(review): only resnet18/resnet50 honour args.gpuid; all other
    # models go to the default CUDA device. Confirm whether that is intended.
    if args.net == 'lenet5':
        net = LeNet5().cuda()
    elif args.net == 'alexnet':
        net = alexnet(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg16':
        net = vgg16(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg13':
        net = vgg13(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg11':
        net = vgg11(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg19':
        net = vgg19(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg16_bn':
        net = vgg16_bn(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg13_bn':
        net = vgg13_bn(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg11_bn':
        net = vgg11_bn(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg19_bn':
        net = vgg19_bn(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'inceptionv3':
        net = inception_v3().cuda()
    elif args.net == 'resnet18':
        net = resnet18(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(args.gpuid)
    elif args.net == 'resnet34':
        net = resnet34(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'resnet50':
        net = resnet50(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda(args.gpuid)
    elif args.net == 'resnet101':
        net = resnet101(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'resnet152':
        net = resnet152(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'squeezenet':
        net = squeezenet1_0().cuda()
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shells and schedulers see the failure
        # (a bare sys.exit() reports success).
        sys.exit(1)

    return net
# Example #7
# 0
def get_network(args):
    """ return given network

    Build the CIFAR classifier named by ``args.net`` from the local
    ``models`` package.

    Args:
        args: namespace with these fields:
            ``task``: ``'cifar10'`` or ``'cifar100'``; fixes the class count.
            ``net``: a VGG (11/13/16/19) or ResNet (18/34/50/101/152) name.
            ``gpu``: when truthy, the model is moved to CUDA.

    Returns:
        The constructed model.

    Exits the process with status 1 when ``task`` or ``net`` is unsupported.
    """
    num_classes_by_task = {'cifar10': 10, 'cifar100': 100}
    if args.task not in num_classes_by_task:
        # Previously nclass was left unbound here and failed later with a
        # NameError.
        print('the task you have entered is not supported yet')
        sys.exit(1)
    nclass = num_classes_by_task[args.task]

    # VGG variants without batch normalisation.
    if args.net == 'vgg11':
        from models.vgg import vgg11
        net = vgg11(num_classes=nclass)
    elif args.net == 'vgg13':
        from models.vgg import vgg13
        net = vgg13(num_classes=nclass)
    elif args.net == 'vgg16':
        from models.vgg import vgg16
        net = vgg16(num_classes=nclass)
    elif args.net == 'vgg19':
        from models.vgg import vgg19
        net = vgg19(num_classes=nclass)
    elif args.net == 'resnet18':
        from models.resnet import resnet18
        net = resnet18(num_classes=nclass)
    elif args.net == 'resnet34':
        from models.resnet import resnet34
        net = resnet34(num_classes=nclass)
    elif args.net == 'resnet50':
        from models.resnet import resnet50
        net = resnet50(num_classes=nclass)
    elif args.net == 'resnet101':
        from models.resnet import resnet101
        net = resnet101(num_classes=nclass)
    elif args.net == 'resnet152':
        from models.resnet import resnet152
        net = resnet152(num_classes=nclass)
    else:
        print('the network name you have entered is not supported yet')
        # Non-zero status so callers and shells see the failure.
        sys.exit(1)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
# Example #8
# 0
def get_network(args):
    """Build an untrained VGG network from the local ``models.vgg`` module.

    Args:
        args: namespace whose ``net`` is one of ``'vgg11'``, ``'vgg13'``,
            ``'vgg16'`` or ``'vgg19'``.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.net`` is not a supported VGG variant. Previously
            this surfaced as an UnboundLocalError on ``return net``.
    """
    if args.net == 'vgg16':
        from models.vgg import vgg16
        net = vgg16()

    elif args.net == 'vgg11':
        from models.vgg import vgg11
        net = vgg11()

    elif args.net == 'vgg13':
        from models.vgg import vgg13
        net = vgg13()

    elif args.net == 'vgg19':
        from models.vgg import vgg19
        net = vgg19()

    else:
        raise ValueError(
            'unsupported network {!r}; expected one of '
            'vgg11, vgg13, vgg16, vgg19'.format(args.net))

    return net
def estimate_density_map(dataloader,
                         model_param_path,
                         index=None,
                         saveroot=None):
    '''
    Save estimated density maps as colour-mapped PNG images.

    dataloader: iterable yielding ``(inputs, count, name)``. ``name[0]`` is
        used as the file stem, so batch size 1 is presumably expected.
    model_param_path: path of a ``vgg19`` state dict saved with ``torch.save``.
    index: if given, only the ``index``-th batch is rendered. Otherwise every
        batch is rendered.
    saveroot: directory that receives ``<name>.png`` files; it is created if
        missing. Required despite its ``None`` default.

    Raises:
        TypeError: if ``saveroot`` is ``None``.
    '''
    # Fail fast with a clear message. Previously None reached os.path.exists
    # only after the model had been loaded.
    if saveroot is None:
        raise TypeError("estimate_density_map() requires a 'saveroot' directory")

    device = torch.device("cuda")

    model = vgg19()
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)

    os.makedirs(saveroot, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for i, (inputs, _count, name) in enumerate(tqdm(dataloader)):
            # Skip unwanted batches before copying them to the GPU.
            if index is not None and i != index:
                continue
            et_dmap = model(inputs.to(device))
            et_dmap = et_dmap.squeeze(0).squeeze(0).cpu().numpy()
            plt.imsave(os.path.join(saveroot, '{}.png'.format(name[0])),
                       et_dmap,
                       cmap=CM.jet)
            if index is not None:
                # The single requested map has been written.
                break
# Example #10
# 0
    def setup(self):
        """Prepare data, the counting model, its refiner and the optimizers.

        This builds:
        - ``train``/``val``/``test`` ``Crowd`` datasets and loaders;
        - ``CSRNet`` when ``args.net == 'csrnet'``, otherwise ``vgg19``;
        - an ``IndivBlur8`` refiner;
        - two Adam optimizers: one over the model, and one over the
          refiner's ``adapt`` parameters;
        - the MSE criterion and best-metric bookkeeping.

        A ``.tar`` resume file also restores the refiner, the model optimizer
        and the epoch counter.

        Raises:
            Exception: if no CUDA device is available.
        """
        args = self.args
        self.skip_test = args.skip_test
        if not torch.cuda.is_available():
            raise Exception("gpu is not available")
        self.device = torch.device("cuda")
        self.device_count = torch.cuda.device_count()
        # Only the single-GPU version of this trainer is supported.
        assert self.device_count == 1
        logging.info('using {} gpus'.format(self.device_count))

        self.downsample_ratio = args.downsample_ratio
        # No explicit image lists: every split gets im_list=None.
        lists = dict.fromkeys(['train', 'val', 'test'])
        workers = args.num_workers * self.device_count

        datasets = {}
        for phase in ('train', 'val'):
            datasets[phase] = Crowd(os.path.join(args.data_dir, phase),
                                    args.crop_size,
                                    args.downsample_ratio,
                                    args.is_gray, phase, args.resize,
                                    im_list=lists[phase])
        self.datasets = datasets

        loaders = {}
        for phase in ('train', 'val'):
            training = phase == 'train'
            loaders[phase] = DataLoader(
                datasets[phase],
                collate_fn=train_collate if training else default_collate,
                batch_size=args.batch_size if training else 1,
                shuffle=training,
                num_workers=workers,
                pin_memory=training)
        self.dataloaders = loaders

        # The test split is built with 'val' where the other splits pass their
        # own name, presumably to reuse validation-time processing.
        self.datasets['test'] = Crowd(os.path.join(args.data_dir, 'test'),
                                      args.crop_size,
                                      args.downsample_ratio,
                                      args.is_gray, 'val', args.resize,
                                      im_list=lists['test'])
        self.dataloaders['test'] = DataLoader(self.datasets['test'],
                                              collate_fn=default_collate,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=workers,
                                              pin_memory=False)
        print(len(self.dataloaders['train']))
        print(len(self.dataloaders['val']))

        self.model = CSRNet() if self.args.net == 'csrnet' else vgg19()

        self.refiner = IndivBlur8(s=args.s, downsample=self.downsample_ratio,
                                  softmax=args.soft)
        # Only the refiner's ``adapt`` sub-module is handed to dml_optimizer.
        refine_params = list(self.refiner.adapt.parameters())

        self.model.to(self.device)
        self.refiner.to(self.device)
        self.optimizer = optim.Adam(list(self.model.parameters()),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)
        self.dml_optimizer = torch.optim.Adam(refine_params, lr=1e-7,
                                              weight_decay=args.weight_decay)

        self.start_epoch = 0
        if args.resume:
            suffix = args.resume.rsplit('.', 1)[-1]
            if suffix == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.refiner.load_state_dict(checkpoint['refine_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suffix == 'pth':
                weights = torch.load(args.resume, self.device)
                self.model.load_state_dict(weights)

        self.crit = torch.nn.MSELoss(reduction='sum')

        self.save_list = Save_Handle(max_num=args.max_model_num)
        self.test_flag = False
        stages = ('val', 'test')
        self.best_mae = {stage: np.inf for stage in stages}
        self.best_mse = {stage: np.inf for stage in stages}
        self.best_epoch = {stage: 0 for stage in stages}
# Example #11
# 0
def _build_test_model(args, num_classes):
    """Construct the unwrapped network named by ``args.arch``; None if unknown."""
    arch = args.arch
    if arch == 'efficientnet_b0':
        if args.pretrained:
            return EfficientNet.from_pretrained("efficientnet-b0",
                                                quantize=args.quantize,
                                                num_classes=num_classes)
        return EfficientNet.from_name(
            "efficientnet-b0",
            quantize=args.quantize,
            override_params={'num_classes': num_classes})
    if arch == 'mobilenet_v1':
        # Pretrained weights are loaded after DataParallel wrapping; see
        # _load_mobilenet_v1_pretrained.
        return mobilenet_v1(quantize=args.quantize, num_classes=num_classes)
    if arch == 'mobilenet_v2':
        return mobilenet_v2(pretrained=args.pretrained,
                            num_classes=num_classes,
                            quantize=args.quantize)
    if arch == 'resnet18':
        return resnet18(pretrained=args.pretrained,
                        num_classes=num_classes,
                        quantize=args.quantize)
    if arch == 'resnet50':
        return resnet50(pretrained=args.pretrained,
                        num_classes=num_classes,
                        quantize=args.quantize)
    if arch == 'resnet152':
        return resnet152(pretrained=args.pretrained,
                         num_classes=num_classes,
                         quantize=args.quantize)
    if arch == 'resnet164':
        return resnet_164(num_classes=num_classes, quantize=args.quantize)
    if arch == 'vgg11':
        return vgg11(pretrained=args.pretrained,
                     num_classes=num_classes,
                     quantize=args.quantize)
    if arch == 'vgg19':
        return vgg19(pretrained=args.pretrained,
                     num_classes=num_classes,
                     quantize=args.quantize)
    return None


def _load_mobilenet_v1_pretrained(model, checkpoint_path, num_classes):
    """Load checkpoint weights into a wrapped mobilenet_v1, dropping ``fc`` keys unless 1000 classes."""
    checkpoint = torch.load(checkpoint_path)
    state_dict = checkpoint['state_dict']

    if num_classes != 1000:
        # The checkpoint head does not match the new class count.
        state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}

    res = model.load_state_dict(state_dict, strict=False)

    # Only quantizer state and the dropped head may legitimately be missing.
    for missing_key in res.missing_keys:
        assert 'quantize' in missing_key or 'fc' in missing_key


def test_model(args):
    """Evaluate a binary classifier selected by ``args.arch`` on the test set.

    The model is wrapped in ``torch.nn.DataParallel`` on CUDA. For
    ``mobilenet_v1`` with ``args.pretrained``, weights come from
    ``args.resume``. Otherwise, ``args.resume`` (when set and it is a file)
    restores a checkpoint and sets ``args.start_epoch``.

    Returns:
        Whatever ``validate`` returns (presumably top-1 precision; TODO
        confirm).

    Exits the process with status 1 when ``args.arch`` is unknown.
    """
    num_classes = 2
    model = _build_test_model(args, num_classes)
    if model is None:
        logging.error('No such model.')
        # Non-zero status: an unknown architecture is a failure.
        sys.exit(1)
    model = torch.nn.DataParallel(model).cuda()

    if args.arch == 'mobilenet_v1' and args.pretrained:
        _load_mobilenet_v1_pretrained(model, args.resume, num_classes)

    if args.resume and not args.pretrained:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
                args.resume, checkpoint['epoch']))
        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = False
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)
    criterion = nn.CrossEntropyLoss().cuda()

    with torch.no_grad():
        prec1 = validate(args, test_loader, model, criterion, 0)
    return prec1
# Example #12
# 0
    def setup(self):
        """Prepare data, the ``vgg19`` counter, the optimizer and the loss.

        This builds:
        - ``train``/``val``/``test`` ``Crowd`` datasets and loaders;
        - a ``vgg19`` model with an Adam optimizer;
        - the ``Full_Post_Prob`` posterior and ``Full_Cov_Gaussian_Loss``
          criterion;
        - best-metric bookkeeping.

        ``args.resume`` optionally restores a ``.tar`` checkpoint (model,
        optimizer, epoch) or ``.pth`` model weights.

        Raises:
            Exception: if no CUDA device is available.
        """
        args = self.args
        self.loss = args.loss
        self.skip_test = args.skip_test
        self.add = args.add
        if not torch.cuda.is_available():
            raise Exception("gpu is not available")
        self.device = torch.device("cuda")
        self.device_count = torch.cuda.device_count()
        # Only the single-GPU version of this trainer is supported.
        assert self.device_count == 1
        logging.info('using {} gpus'.format(self.device_count))

        self.downsample_ratio = args.downsample_ratio
        # No explicit image lists: every split gets im_list=None.
        lists = dict.fromkeys(['train', 'val', 'test'])
        workers = args.num_workers * self.device_count

        datasets = {}
        for phase in ('train', 'val'):
            datasets[phase] = Crowd(os.path.join(args.data_dir, phase),
                                    args.crop_size,
                                    args.downsample_ratio,
                                    args.is_gray,
                                    phase,
                                    im_list=lists[phase])
        self.datasets = datasets

        loaders = {}
        for phase in ('train', 'val'):
            training = phase == 'train'
            loaders[phase] = DataLoader(
                datasets[phase],
                collate_fn=train_collate if training else default_collate,
                batch_size=self.args.batch_size if training else 1,
                shuffle=training,
                num_workers=workers,
                pin_memory=training)
        self.dataloaders = loaders

        # The test split is built with 'val' where the other splits pass their
        # own name, presumably to reuse validation-time processing.
        self.datasets['test'] = Crowd(os.path.join(args.data_dir, 'test'),
                                      args.crop_size,
                                      args.downsample_ratio,
                                      args.is_gray,
                                      'val',
                                      im_list=lists['test'])
        self.dataloaders['test'] = DataLoader(self.datasets['test'],
                                              collate_fn=default_collate,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=workers,
                                              pin_memory=False)

        self.model = vgg19(down=self.downsample_ratio,
                           bn=args.bn,
                           o_cn=args.o_cn)
        self.model.to(self.device)
        self.optimizer = optim.Adam(list(self.model.parameters()), lr=args.lr)

        self.start_epoch = 0
        if args.resume:
            suffix = args.resume.rsplit('.', 1)[-1]
            if suffix == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suffix == 'pth':
                weights = torch.load(args.resume, self.device)
                self.model.load_state_dict(weights)

        self.post_prob = Full_Post_Prob(args.sigma,
                                        args.alpha,
                                        args.crop_size,
                                        args.downsample_ratio,
                                        args.background_ratio,
                                        args.use_background,
                                        self.device,
                                        add=self.add,
                                        minx=args.minx,
                                        ratio=args.ratio)
        self.criterion = Full_Cov_Gaussian_Loss(args.use_background,
                                                self.device,
                                                weight=self.args.weight,
                                                reg=args.reg)

        self.save_list = Save_Handle(max_num=args.max_model_num)
        self.test_flag = False
        stages = ('val', 'test')
        self.best_mae = {stage: np.inf for stage in stages}
        self.best_mse = {stage: np.inf for stage in stages}
        self.best_epoch = {stage: 0 for stage in stages}
# Example #13
# 0
def _build_classifier(model_name, device):
    """Instantiate the network called ``model_name`` and move it to ``device``.

    Raises:
        ValueError: if ``model_name`` is not a supported name.
    """
    # Lambdas defer lookup so only the requested module is touched.
    builders = {
        "densenet": lambda: densenet.DenseNet(),
        "alexnet": lambda: alexnet.AlexNet(),
        "googlenet": lambda: googlenet.GoogLeNet(),
        "mobilenet": lambda: mobilenet.MobileNetV2(),
        "mnasnet": lambda: mnasnet.mnasnet1_0(),
        "squeezenet": lambda: squeezenet.SqueezeNet(),
        "resnet": lambda: resnet.resnet50(),
        "vgg": lambda: vgg.vgg19(),
        "shufflenetv2": lambda: shufflenetv2.shufflenet_v2_x1_0(),
    }
    if model_name not in builders:
        raise ValueError('unsupported model_name: {!r}'.format(model_name))
    return builders[model_name]().to(device)


def _primary_logits(out):
    """Return the main logits when a model (e.g. GoogLeNet) emits a tuple/list."""
    return out[0] if isinstance(out, (tuple, list)) else out


def _train(net, train_loader, criterion, optimizer, device, epochs,
           log_every=10):
    """Train ``net`` in place.

    Returns:
        ``(accuracy_history, loss_history)``: the per-epoch training accuracy
        in percent, and the loss of each epoch's final batch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    accuracy_history = []
    loss_history = []
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        correct = 0
        total = 0
        loss = None
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            out = _primary_logits(net(images))
            _, predicted = torch.max(out.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = criterion(out, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Report the mean loss over the last ``log_every`` batches.
            # The old code divided by 2000 while logging every 10 batches.
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('train loss [%d, %d] loss: %.03f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                print("train acc %.03f" % (100.0 * correct / total))
                running_loss = 0.0
        if total == 0:
            raise ValueError('train_loader yielded no samples')
        accuracy = 100.0 * correct / total
        accuracy_history.append(accuracy)
        print('accurary={}'.format(accuracy))
        loss_history.append(loss.item())
    return accuracy_history, loss_history


def _evaluate(net, test_loader, device):
    """Return ``(correct, total)`` over ``test_loader``, printing per-batch diagnostics."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            start_time = time.time()
            print('time_start', start_time)
            num_images = images.size(0)
            print('num_images', num_images)
            images = images.to(device)
            labels = labels.to(device)
            print("GroundTruth", labels)
            # Index the model output once. The old code indexed twice for
            # GoogLeNet, which kept only the first sample's logits.
            out = _primary_logits(net(images))
            _, predicted = torch.max(out, 1)
            print("predicted", predicted)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            # Per-image time uses the real batch size; the last batch may be
            # smaller than args.batch_size.
            print('time_usage', (time.time() - start_time) / num_images)
    return correct, total


def _plot_training_curves(accuracy_history, loss_history, epochs, figure_path):
    """Plot training accuracy and loss per epoch, save to ``figure_path`` and show."""
    x = range(0, epochs)
    plt.subplot(2, 1, 1)
    plt.plot(x, accuracy_history, 'o-')
    plt.title('Train accuracy vs. epoches')
    plt.ylabel('Train accuracy')
    plt.subplot(2, 1, 2)
    plt.plot(x, loss_history, '.-')
    plt.xlabel('Train loss vs. epoches')
    plt.ylabel('Train loss')
    plt.savefig(figure_path)
    plt.show()


def main(args):
    """Train an image classifier on ImageFolder data, evaluate it and save it.

    ``args`` must provide these fields:
        ``batch_size``, ``learning_rate``, ``epoch``;
        ``train_images`` and ``test_images`` (ImageFolder roots);
        ``model_name`` (one of the names accepted by ``_build_classifier``);
        ``output_dir``, which receives the accuracy/loss figure and
        ``<model_name>.pth`` (the model ``state_dict``).

    Training now runs whenever ``main`` is called. The old code nested it
    under ``if __name__ == '__main__'``, so it was skipped when imported.
    """
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.Resize(size=(227, 227)),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # convert the image to a Tensor scaled to [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = torchvision.datasets.ImageFolder(root=args.train_images,
                                                     transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    # Read the validation data from its folder.
    validation_dataset = torchvision.datasets.ImageFolder(
        root=args.test_images, transform=transform)
    print(validation_dataset.class_to_idx)
    test_loader = torch.utils.data.DataLoader(validation_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True)

    net = _build_classifier(args.model_name, device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)

    accuracy_history, loss_history = _train(net, train_loader, criterion,
                                            optimizer, device, args.epoch)

    correct_test, total_test = _evaluate(net, test_loader, device)
    print('total_test', total_test)
    print('correct_test', correct_test)
    test_accuracy = (100.0 * correct_test / total_test
                     if total_test else float('nan'))
    print('accurary={}'.format(test_accuracy))

    _plot_training_curves(
        accuracy_history, loss_history, args.epoch,
        os.path.join(args.output_dir,
                     'accuracy_epoch' + str(args.epoch) + '.png'))

    # Save the trained weights. The old call passed the output directory
    # string as the object and the model name as the path.
    torch.save(net.state_dict(),
               os.path.join(args.output_dir, args.model_name + '.pth'))
# Example #14
# 0
def get_model(class_num):
    """Build the configured CNN and give it a ``class_num``-way output head.

    The architecture is selected by the module-level settings ``MODEL_TYPE``
    (alexnet / vgg / squeezenet / resnet / densenet / inception) and
    ``MODEL_DEPTH_OR_VERSION``; ``FINETUNE`` is passed through as
    ``pretrained``. An unsupported combination prints an error and exits the
    process with status 1.

    Args:
        class_num: number of classes the replaced output layer should produce.

    Returns:
        The model with its final classification layer replaced.
    """
    if (MODEL_TYPE == 'alexnet'):
        model = alexnet.alexnet(pretrained=FINETUNE)
    elif (MODEL_TYPE == 'vgg'):
        if (MODEL_DEPTH_OR_VERSION == 11):
            model = vgg.vgg11(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 13):
            model = vgg.vgg13(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 16):
            model = vgg.vgg16(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 19):
            model = vgg.vgg19(pretrained=FINETUNE)
        else:
            print('Error : VGG should have depth of either [11, 13, 16, 19]')
            sys.exit(1)
    elif (MODEL_TYPE == 'squeezenet'):
        if (MODEL_DEPTH_OR_VERSION == 0 or MODEL_DEPTH_OR_VERSION == 'v0'):
            model = squeezenet.squeezenet1_0(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 1 or MODEL_DEPTH_OR_VERSION == 'v1'):
            model = squeezenet.squeezenet1_1(pretrained=FINETUNE)
        else:
            print('Error : Squeezenet should have version of either [0, 1]')
            sys.exit(1)
    elif (MODEL_TYPE == 'resnet'):
        if (MODEL_DEPTH_OR_VERSION == 18):
            model = resnet.resnet18(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 34):
            model = resnet.resnet34(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 50):
            model = resnet.resnet50(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 101):
            model = resnet.resnet101(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 152):
            model = resnet.resnet152(pretrained=FINETUNE)
        else:
            print(
                'Error : Resnet should have depth of either [18, 34, 50, 101, 152]'
            )
            sys.exit(1)
    elif (MODEL_TYPE == 'densenet'):
        if (MODEL_DEPTH_OR_VERSION == 121):
            model = densenet.densenet121(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 169):
            model = densenet.densenet169(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 161):
            model = densenet.densenet161(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 201):
            model = densenet.densenet201(pretrained=FINETUNE)
        else:
            print(
                'Error : Densenet should have depth of either [121, 169, 161, 201]'
            )
            sys.exit(1)
    elif (MODEL_TYPE == 'inception'):
        if (MODEL_DEPTH_OR_VERSION == 3 or MODEL_DEPTH_OR_VERSION == 'v3'):
            model = inception.inception_v3(pretrained=FINETUNE)
        else:
            print('Error : Inception should have version of either [3, ]')
            sys.exit(1)
    else:
        print(
            'Error : Network should be either [alexnet / squeezenet / vgg / resnet / densenet / inception]'
        )
        sys.exit(1)

    # Replace the final classification layer so the network emits class_num
    # scores instead of the constructor's default.
    if MODEL_TYPE in ('alexnet', 'vgg'):
        # The last entry (index 6) of the classifier Sequential is the output Linear.
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(num_ftrs, class_num)
    elif MODEL_TYPE in ('resnet', 'inception'):
        model.fc = nn.Linear(model.fc.in_features, class_num)
        # Inception may also carry an auxiliary classifier; give it the same
        # class count. Assumes it exposes an ``fc`` Linear as torchvision's
        # Inception3 does -- confirm against the inception module in use.
        aux = getattr(model, 'AuxLogits', None)
        if aux is not None:
            aux.fc = nn.Linear(aux.fc.in_features, class_num)
    elif MODEL_TYPE == 'densenet':
        model.classifier = nn.Linear(model.classifier.in_features, class_num)
    elif MODEL_TYPE == 'squeezenet':
        # SqueezeNet has no Linear head. Assumes torchvision's layout, where
        # classifier[1] is the final 1x1 Conv2d producing the class scores and
        # ``num_classes`` may be read in forward() -- confirm for this module.
        final_conv = model.classifier[1]
        model.classifier[1] = nn.Conv2d(final_conv.in_channels, class_num,
                                        kernel_size=1)
        model.num_classes = class_num

    return model
Example #15
0
def _get_normalize(dataset):
    """Return the per-channel ``transforms.Normalize`` used for ``dataset``.

    Raises:
        Exception: if ``dataset`` is not one of cifar10/cifar100/cub/webvision.
    """
    stats = {
        'cifar10': ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        'cifar100': ([0.5071, 0.4865, 0.4409], [0.2634, 0.2528, 0.2719]),
        'cub': ([0.4862, 0.4973, 0.4293], [0.2230, 0.2185, 0.2472]),
        'webvision': ([0.49274242, 0.46481857, 0.41779366],
                      [0.26831809, 0.26145372, 0.27042758]),
    }
    if dataset not in stats:
        raise Exception('Unknown dataset: {}'.format(dataset))
    mean, std = stats[dataset]
    return transforms.Normalize(mean=mean, std=std)


def _build_datasets(dataset, train_transform, val_transform):
    """Return ``(train_dataset, val_dataset, num_classes)`` for ``dataset``.

    CIFAR data is downloaded into ``./data/``; CUB and WebVision are read as
    ``ImageFolder`` trees from fixed local paths.

    Raises:
        Exception: if ``dataset`` is not recognised.
    """
    if dataset == 'cifar10':
        return (datasets.CIFAR10('./data/', train=True, download=True, transform=train_transform),
                datasets.CIFAR10('./data/', train=False, download=True, transform=val_transform),
                10)
    if dataset == 'cifar100':
        return (datasets.CIFAR100('./data/', train=True, download=True, transform=train_transform),
                datasets.CIFAR100('./data/', train=False, download=True, transform=val_transform),
                100)
    if dataset == 'cub':
        return (datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/train/',
                                     transform=train_transform),
                datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/test/',
                                     transform=val_transform),
                200)
    if dataset == 'webvision':
        return (datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/train',
                                     transform=train_transform),
                datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/val',
                                     transform=val_transform),
                1000)
    raise Exception('Unknown dataset: {}'.format(dataset))


def _build_model(args, num_classes):
    """Instantiate the (non-pretrained) network named by ``args.model``/``args.depth``.

    Raises:
        Exception: for an unsupported model name or depth.
    """
    if args.model == 'AlexNet':
        return alexnet(pretrained=False, num_classes=num_classes)
    if args.model == 'Inception':
        return inception_v3(pretrained=False, num_classes=num_classes)

    if args.model == 'VGG':
        use_batch_normalization = True  # default use Batch Normalization
        if use_batch_normalization:
            builders = {11: vgg11_bn, 13: vgg13_bn, 16: vgg16_bn, 19: vgg19_bn}
        else:
            builders = {11: vgg11, 13: vgg13, 16: vgg16, 19: vgg19}
        depth_error = 'Unsupport VGG detph: {}, optional depths: 11, 13, 16 or 19'
    elif args.model == 'ResNet':
        builders = {18: resnet18, 34: resnet34, 50: resnet50,
                    101: resnet101, 152: resnet152}
        depth_error = 'Unsupport ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'
    elif args.model == 'MPN-COV-ResNet':
        builders = {18: mpn_cov_resnet18, 34: mpn_cov_resnet34,
                    50: mpn_cov_resnet50, 101: mpn_cov_resnet101,
                    152: mpn_cov_resnet152}
        depth_error = 'Unsupport MPN-COV-ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'
    else:
        # Include the offending name (the old message had no placeholder).
        raise Exception('Unsupport model: {}'.format(args.model))

    if args.depth not in builders:
        raise Exception(depth_error.format(args.depth))
    return builders[args.depth](pretrained=False, num_classes=num_classes)


def main():
    """Train and validate an image classifier configured from the command line.

    Parses ``parser``'s arguments into the module-level ``args``. It builds
    the dataset and model selected by ``args.dataset``, ``args.model`` and
    ``args.depth``, and optionally resumes from ``args.resume``. It then runs
    epochs ``args.start_epoch`` .. ``args.epochs - 1`` of Nesterov SGD and
    validates after each epoch. A checkpoint is saved through
    ``save_checkpoint`` every epoch, flagged ``is_best`` when the validation
    error ties or beats the module-level ``best_err1``.
    """
    global args, best_err1
    args = parser.parse_args()

    # TensorBoard configure
    if args.tensorboard:
        configure('%s_checkpoints/%s' % (args.dataset, args.expname))

    # CUDA: enumerate GPUs in PCI bus order and expose only the requested ids.
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_ids)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        cudnn.benchmark = True  # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        loader_kwargs = {'num_workers': 2, 'pin_memory': True}
    else:
        loader_kwargs = {'num_workers': 2}

    # Data loading code
    normalize = _get_normalize(args.dataset)

    # Transforms: random resized crop, plus a horizontal flip when augmenting.
    train_steps = [transforms.RandomResizedCrop(args.train_image_size)]
    if args.augment:
        train_steps.append(transforms.RandomHorizontalFlip())
    train_transform = transforms.Compose(
        train_steps + [transforms.ToTensor(), normalize])
    val_transform = transforms.Compose([
        transforms.Resize(args.test_image_size),
        transforms.CenterCrop(args.test_crop_image_size),
        transforms.ToTensor(),
        normalize
    ])

    # Datasets
    train_dataset, val_dataset, num_classes = _build_datasets(
        args.dataset, train_transform, val_transform)

    # Data Loader
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batchsize, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False, **loader_kwargs)

    # Create model
    model = _build_model(args, num_classes)

    # Get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    if use_cuda:
        model = model.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            # Remap GPU-saved tensors to CPU when no GPU is available.
            checkpoint = torch.load(args.resume,
                                    map_location=None if use_cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            best_err1 = checkpoint['best_err1']
            model.load_state_dict(checkpoint['state_dict'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # Evaluate on validation set
        err1 = validate(val_loader, model, criterion, epoch)

        # Remember best err1 and save checkpoint
        is_best = (err1 <= best_err1)
        best_err1 = min(err1, best_err1)
        print("Current best accuracy (error):", best_err1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best)

    print("Best accuracy (error):", best_err1)
Example #16
0
def get_model(args):
    """Build the network selected by ``args.model`` and the params to optimise.

    Args:
        args: options namespace. These fields are always read: ``model``,
            ``pretrained`` and ``n_classes``. Per model, ``n_channels`` is
            read for alexnet and derpnet, and ``model_depth`` for vgg,
            vgg_attn and resnet. The fine-tuning path also reads
            ``pretrain_path`` and ``arch``.

    Returns:
        ``(model, parameters)`` where ``parameters`` is ``model.parameters()``.
    """
    assert args.model in [
        'derpnet', 'alexnet', 'resnet', 'vgg', 'vgg_attn', 'inception'
    ]

    if args.model == 'alexnet':
        model = alexnet.alexnet(pretrained=args.pretrained,
                                n_channels=args.n_channels,
                                num_classes=args.n_classes)
    elif args.model == 'inception':
        model = inception.inception_v3(pretrained=args.pretrained,
                                       aux_logits=False,
                                       progress=True,
                                       num_classes=args.n_classes)
    elif args.model == 'vgg':
        assert args.model_depth in [11, 13, 16, 19]
        # NOTE(review): depth 19 builds the plain vgg19 while 11/13/16 build
        # the batch-norm variants -- kept as-is; confirm this is intended.
        if args.model_depth == 11:
            builder = vgg.vgg11_bn
        elif args.model_depth == 13:
            builder = vgg.vgg13_bn
        elif args.model_depth == 16:
            builder = vgg.vgg16_bn
        elif args.model_depth == 19:
            builder = vgg.vgg19
        model = builder(pretrained=args.pretrained,
                        progress=True,
                        num_classes=args.n_classes)

    elif args.model == 'vgg_attn':
        assert args.model_depth in [11, 13, 16, 19]
        # Each depth maps to its own constructor (previously every depth
        # built vgg11_bn).
        if args.model_depth == 11:
            builder = vgg_attn.vgg11_bn
        elif args.model_depth == 13:
            builder = vgg_attn.vgg13_bn
        elif args.model_depth == 16:
            builder = vgg_attn.vgg16_bn
        elif args.model_depth == 19:
            builder = vgg_attn.vgg19_bn
        model = builder(pretrained=args.pretrained,
                        progress=True,
                        num_classes=args.n_classes)

    elif args.model == 'derpnet':
        model = derp_net.Net(n_channels=args.n_channels,
                             num_classes=args.n_classes)

    elif args.model == 'resnet':
        assert args.model_depth in [10, 18, 34, 50, 101, 152, 200]

        if args.model_depth == 10:
            builder = resnet.resnet10
        elif args.model_depth == 18:
            builder = resnet.resnet18
        elif args.model_depth == 34:
            builder = resnet.resnet34
        elif args.model_depth == 50:
            builder = resnet.resnet50
        elif args.model_depth == 101:
            builder = resnet.resnet101
        elif args.model_depth == 152:
            builder = resnet.resnet152
        elif args.model_depth == 200:
            builder = resnet.resnet200
        model = builder(pretrained=args.pretrained,
                        num_classes=args.n_classes)

    # Fine-tune from a checkpoint file only for models that do not load
    # their own pretrained weights above (alexnet / vgg / resnet do).
    fine_tune = (args.pretrained and args.pretrain_path
                 and args.model not in ('alexnet', 'vgg', 'resnet'))
    if fine_tune:
        from collections import OrderedDict
        import types

        print('loading pretrained model {}'.format(args.pretrain_path))
        pretrain = torch.load(args.pretrain_path)
        assert args.arch == pretrain['arch']

        # Drop the first 7 characters of every key (presumably the 'module.'
        # prefix added by nn.DataParallel -- confirm) and skip the 'fc'
        # weights, which are re-created below for the new class count.
        pretrain_dict = OrderedDict(
            (key[7:], value)
            for key, value in pretrain['state_dict'].items()
            if key[7:9] != 'fc')

        # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance
        # Swap in the project's tolerant loader for this instance only.
        model.load_state_dict = types.MethodType(load_my_state_dict, model)
        model.load_state_dict(pretrain_dict)

        if args.model == 'densenet':
            model.classifier = nn.Linear(model.classifier.in_features,
                                         args.n_classes)
        else:
            model.fc = nn.Linear(model.fc.in_features, args.n_classes)

        # Fine-tune every parameter, not only the new head.
        return model, model.parameters()

    return model, model.parameters()
Example #17
0
def get_model(args, model_path=None):
    """Build a ResNet / VGG / AlexNet classifier with an ``args.num_classes`` head.

    :param args: super arguments. Reads ``scratch``, ``model``,
        ``num_classes``; when not training from scratch also ``root_path``
        and ``pretrained_models_path`` (the pretrained-weights directory).
    :param model_path: if not None, load already trained model parameters
        from this checkpoint (its ``'model'`` entry).
    :return: model
    :raises ValueError: if ``args.model`` is not a supported architecture.
    """
    if args.scratch:  # train model from scratch
        pretrained = False
        model_dir = None
        print("=> Loading model '{}' from scratch...".format(args.model))
    else:  # train model with pretrained model
        pretrained = True
        model_dir = os.path.join(args.root_path, args.pretrained_models_path)
        print("=> Loading pretrained model '{}'...".format(args.model))

    build_kwargs = dict(pretrained=pretrained, model_dir=model_dir)
    name = args.model
    if name == 'resnet18':
        model = resnet18(**build_kwargs)
    elif name == 'resnet34':
        model = resnet34(**build_kwargs)
    elif name == 'resnet50':
        model = resnet50(**build_kwargs)
    elif name == 'resnet101':
        model = resnet101(**build_kwargs)
    elif name == 'resnet152':
        model = resnet152(**build_kwargs)
    elif name == 'vgg11':
        model = vgg11(**build_kwargs)
    elif name == 'vgg11_bn':
        model = vgg11_bn(**build_kwargs)
    elif name == 'vgg13':
        model = vgg13(**build_kwargs)
    elif name == 'vgg13_bn':
        model = vgg13_bn(**build_kwargs)
    elif name == 'vgg16':
        model = vgg16(**build_kwargs)
    elif name == 'vgg16_bn':
        model = vgg16_bn(**build_kwargs)
    elif name == 'vgg19':
        model = vgg19(**build_kwargs)
    elif name == 'vgg19_bn':
        model = vgg19_bn(**build_kwargs)
    elif name == 'alexnet':
        model = alexnet(**build_kwargs)
    else:
        # Previously an unknown name left `model` unbound (UnboundLocalError).
        raise ValueError('Unsupported model: {}'.format(name))

    # Replace the output layer so the network predicts args.num_classes.
    if name.startswith('resnet'):
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)
    else:
        # VGG and AlexNet: index 6 of the classifier is the output Linear.
        model.classifier[6] = nn.Linear(model.classifier[6].in_features,
                                        args.num_classes)

    # Load already trained model parameters and go on training
    if model_path is not None:
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model'])

    return model
Example #18
0
    elif args.arch == 'mobilenet_v1':
        model = mobilenet_v1(num_classes=num_classes)

    elif args.arch == 'mobilenet_v2':
        model = mobilenet_v2(num_classes=num_classes)

    elif args.arch == 'resnet18':
        model = resnet18(num_classes=num_classes)

    elif args.arch == 'resnet50':
        model = resnet50(num_classes=num_classes)

    elif args.arch == 'resnet152':
        model = resnet152(num_classes=num_classes)

    elif args.arch == 'resnet164':
        model = resnet_164(num_classes=num_classes)

    elif args.arch == 'vgg11':
        model = vgg11(num_classes=num_classes)

    elif args.arch == 'vgg19':
        model = vgg19(num_classes=num_classes)

    else:
        print('No such model.')
        sys.exit()

    count_model_param_flops(model, input_res=input_res)
if __name__ == '__main__':
    # Evaluate a trained vgg19 counting model on the test split: print the
    # per-image count error, then a summary over all images.
    import math  # local: only the summary at the end needs it

    args = parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip()  # set vis gpu

    # Named `dataset` so the script does not rebind a module-level
    # `datasets` name that other code in this file uses as a module.
    dataset = Crowd(os.path.join(args.data_dir, 'test'),
                    512,
                    8,
                    is_gray=False,
                    method='val')
    dataloader = torch.utils.data.DataLoader(dataset,
                                             1,
                                             shuffle=False,
                                             num_workers=8,
                                             pin_memory=False)
    device = torch.device('cuda')
    model = vgg19()
    model.to(device)
    model.load_state_dict(
        torch.load(os.path.join(args.save_dir, 'best_model.pth'), device))
    # Inference only: put any dropout / batch-norm layers in eval mode.
    model.eval()

    epoch_minus = []  # ground-truth count minus predicted count, per image
    for inputs, count, name in dataloader:
        inputs = inputs.to(device)
        assert inputs.size(0) == 1, 'the batch size should equal to 1'
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            gt_count = count[0].item()
            pred_count = torch.sum(outputs).item()
            temp_minu = gt_count - pred_count
            print(name, temp_minu, gt_count, pred_count)
            epoch_minus.append(temp_minu)

    if epoch_minus:
        n_images = len(epoch_minus)
        mae = sum(abs(err) for err in epoch_minus) / n_images
        # "mse" here is the root of the mean squared error.
        mse = math.sqrt(sum(err * err for err in epoch_minus) / n_images)
        print('Final Test: mae {}, mse {}'.format(mae, mse))
Example #20
0
# args.net -> (submodule of ``models``, factory name, factory takes num_classes)
_NETWORKS = {
    # Yang added none bn vggs
    'vgg16': ('vgg', 'vgg16', True),
    'vgg13': ('vgg', 'vgg13', True),
    'vgg11': ('vgg', 'vgg11', True),
    'vgg19': ('vgg', 'vgg19', True),
    'vgg16bn': ('vgg', 'vgg16_bn', True),
    'vgg13bn': ('vgg', 'vgg13_bn', True),
    'vgg11bn': ('vgg', 'vgg11_bn', True),
    'vgg19bn': ('vgg', 'vgg19_bn', True),
    'densenet121': ('densenet', 'densenet121', False),
    'densenet161': ('densenet', 'densenet161', False),
    'densenet169': ('densenet', 'densenet169', False),
    'densenet201': ('densenet', 'densenet201', False),
    'googlenet': ('googlenet', 'googlenet', True),
    'inceptionv3': ('inceptionv3', 'inceptionv3', False),
    'inceptionv4': ('inceptionv4', 'inceptionv4', False),
    'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2', False),
    'xception': ('xception', 'xception', True),
    'scnet': ('sphereconvnet', 'sphereconvnet', True),
    'sphereresnet18': ('sphereconvnet', 'resnet18', True),
    'sphereresnet32': ('sphereconvnet', 'sphereresnet32', True),
    'plainresnet32': ('sphereconvnet', 'plainresnet32', True),
    'ynet18': ('ynet', 'resnet18', True),
    'ynet34': ('ynet', 'resnet34', True),
    'ynet50': ('ynet', 'resnet50', True),
    'ynet101': ('ynet', 'resnet101', True),
    'ynet152': ('ynet', 'resnet152', True),
    'resnet18': ('resnet', 'resnet18', True),
    'resnet34': ('resnet', 'resnet34', True),
    'resnet50': ('resnet', 'resnet50', True),
    'resnet101': ('resnet', 'resnet101', True),
    'resnet152': ('resnet', 'resnet152', True),
    'preactresnet18': ('preactresnet', 'preactresnet18', True),
    'preactresnet34': ('preactresnet', 'preactresnet34', True),
    'preactresnet50': ('preactresnet', 'preactresnet50', True),
    'preactresnet101': ('preactresnet', 'preactresnet101', True),
    'preactresnet152': ('preactresnet', 'preactresnet152', True),
    'resnext50': ('resnext', 'resnext50', True),
    'resnext101': ('resnext', 'resnext101', True),
    'resnext152': ('resnext', 'resnext152', True),
    'shufflenet': ('shufflenet', 'shufflenet', False),
    'shufflenetv2': ('shufflenetv2', 'shufflenetv2', False),
    'squeezenet': ('squeezenet', 'squeezenet', False),
    'mobilenet': ('mobilenet', 'mobilenet', True),
    'mobilenetv2': ('mobilenetv2', 'mobilenetv2', True),
    'nasnet': ('nasnet', 'nasnet', True),
    'attention56': ('attention', 'attention56', False),
    'attention92': ('attention', 'attention92', False),
    'seresnet18': ('senet', 'seresnet18', True),
    'seresnet34': ('senet', 'seresnet34', True),
    'seresnet50': ('senet', 'seresnet50', True),
    'seresnet101': ('senet', 'seresnet101', True),
    'seresnet152': ('senet', 'seresnet152', True),
}

# args.task -> number of output classes
_NUM_CLASSES = {'cifar10': 10, 'cifar100': 100}


def get_network(args):
    """ return given network

    Builds the network named by ``args.net`` from the ``models`` package,
    importing only the submodule that is needed. Networks whose factory takes
    ``num_classes`` receive 10 for ``args.task == 'cifar10'`` and 100 for
    ``'cifar100'``. The net is moved to the GPU when ``args.gpu`` is truthy.

    Prints an error and exits with status 1 when ``args.net`` is unknown, or
    when the net needs a class count and ``args.task`` is unknown.
    """
    spec = _NETWORKS.get(args.net)
    if spec is None:
        print('the network name you have entered is not supported yet')
        sys.exit(1)
    module_name, factory_name, takes_nclass = spec

    if takes_nclass:
        nclass = _NUM_CLASSES.get(args.task)
        if nclass is None:
            print('the task you have entered is not supported yet: {}'.format(args.task))
            sys.exit(1)

    factory = getattr(importlib.import_module('models.' + module_name),
                      factory_name)
    if takes_nclass:
        net = factory(num_classes=nclass)
    else:
        # NOTE(review): these factories are called without num_classes, so
        # they use their own default class count regardless of args.task --
        # confirm that matches the task being trained.
        net = factory()

    if args.gpu:  #use_gpu
        net = net.cuda()

    return net