Exemplo n.º 1
0
    def __init__(self,
                 roots,
                 transform=None,
                 target_transform=None,
                 loader=default_loader):
        """Build parallel image datasets from several root directories.

        Each root is indexed independently; all roots must yield the same
        number of samples.
        """
        assert isinstance(roots, (tuple, list))
        self.classes_list = []
        self.class_to_idx_list = []
        self.imgs_list = []
        for root in roots:
            classes, class_to_idx = find_classes(root)
            imgs = make_dataset(root, class_to_idx)
            if not imgs:
                raise RuntimeError("Found 0 images in subfolders of: " +
                                   root + "\n"
                                   "Supported image extensions are: " +
                                   ",".join(IMG_EXTENSIONS))
            self.classes_list.append(classes)
            self.class_to_idx_list.append(class_to_idx)
            self.imgs_list.append(imgs)

        # Sanity check: every root must contain the same number of samples.
        expected = len(self.imgs_list[0])
        for imgs in self.imgs_list:
            assert len(imgs) == expected

        self.roots = roots
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
Exemplo n.º 2
0
def read_all_images(root, num_workers=4):
    """Load every image under *root* into memory, keyed by file path."""
    classes, class_to_idx = find_classes(root)
    dataset = make_dataset(root, class_to_idx)
    if not dataset:
        raise RuntimeError("Found 0 images in subfolders of: " + root + "\n" +
                           "Supported image extensions are: " +
                           ",".join(IMG_EXTENSIONS))

    paths = [sample[0] for sample in dataset]
    num_images = len(paths)

    print("Reading {0} images with {1} workers".format(num_images,
                                                       num_workers))
    if num_workers > 1:
        # Fan the reads out across workers.
        images = parallel_process(paths,
                                  read_image_for_pytorch,
                                  n_jobs=num_workers)
    else:
        images = [read_image_for_pytorch(p) for p in tqdm(paths)]

    # Map each file path to its decoded image.
    return {path: image for path, image in zip(paths, images)}
Exemplo n.º 3
0
def folder(root,
           num_workers=2,
           batch_size=64,
           img_size=224,
           sample_per_class=-1,
           data_augmentation=False):
    """Create a DataLoader over an ImageFolder rooted at *root*.

    When sample_per_class == -1 every sample is drawn uniformly; otherwise
    a balanced subset sampler limits samples per class.
    """
    base_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    data_transform = base_transform
    if data_augmentation:
        # Random crop/flip run before the deterministic resize+normalize.
        data_transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(0.25, 1)),
            transforms.RandomHorizontalFlip(),
            base_transform,
        ])

    data = datasets.ImageFolder(root=root, transform=data_transform)
    if sample_per_class == -1:
        sampler = RandomSampler(data)
    else:
        sampler = BalancedSubsetRandomSampler(data, sample_per_class,
                                              len(find_classes(root)[0]))

    return torch.utils.data.DataLoader(data,
                                       batch_size=batch_size,
                                       sampler=sampler,
                                       num_workers=num_workers)
Exemplo n.º 4
0
def train_test_split_for_dir(root_path: Path,
                             test_size: float,
                             random_state: int = 42):
    """Split an ImageFolder-style directory into train/val copies.

    The dataset under *root_path* (torchvision.datasets.ImageFolder layout)
    is split into train / test subsets, and each subset is copied to
    ``train/`` and ``val/`` directories created next to *root_path*.

    Args:
        root_path: Root of the ImageFolder-style dataset.
        test_size: Fraction of samples assigned to the validation split;
            must lie within [0, 1].
        random_state: Seed forwarded to ``train_test_split``.

    Raises:
        FileNotFoundError: If *root_path* does not exist.
        ValueError: If *test_size* is outside [0, 1].

    TODO: Consider building the train/test ImageFolder objects inside this
    function instead of only copying files.
    """
    # Same exception types as before, now with actionable messages.
    if not root_path.exists():
        raise FileNotFoundError(f"Dataset root does not exist: {root_path}")
    if not (0 <= test_size <= 1.0):
        raise ValueError(f"test_size must be in [0, 1], got {test_size}")

    classes, class_to_idx = find_classes(root_path)
    dataset = make_dataset(root_path, class_to_idx, IMG_EXTENSIONS)
    train, val = train_test_split(dataset,
                                  test_size=test_size,
                                  shuffle=True,
                                  random_state=random_state)

    split_dataset = {'train': train, 'val': val}

    dst_path_root = root_path.parent
    for set_ in ['train', 'val']:
        for file_path, class_ in tqdm(split_dataset[set_], desc=set_):
            file_path = Path(file_path)
            dst_dir = dst_path_root / set_ / str(class_)
            dst_dir.mkdir(exist_ok=True, parents=True)
            shutil.copy(file_path, dst_dir / file_path.name)
Exemplo n.º 5
0
    def __init__(
            self,
            root,
            test_fun,
            num_images,
            extensions=IMG_EXTENSIONS,
            transform=None,
            target_transform=None,
            loader=default_loader):
        """Filtered image-folder dataset.

        Samples come from make_dataset() using *test_fun* and *num_images*
        (their exact semantics live in make_dataset — confirm there).
        """
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root,
                            class_to_idx,
                            extensions,
                            test_fun,
                            num_images)
        if not imgs:
            raise RuntimeError(
                "Found 0 images in subfolders of: " +
                root +
                "\n"
                "Supported image extensions are: " +
                ",".join(IMG_EXTENSIONS))

        self.root = root
        self.samples = imgs
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.test_fun = test_fun
        self.num_images = num_images
Exemplo n.º 6
0
    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        """Image-folder dataset whose sample index is pickled inside *root*."""
        classes, class_to_idx = find_classes(root)
        cache_path = os.path.join(root, 'samples.pickle')
        if os.path.exists(cache_path):
            # Reuse the previously indexed sample list.
            with open(cache_path, 'rb') as rf:
                samples = pickle.load(rf)
            print('=> read {} samples from cache: {}'.format(len(samples), cache_path))
        else:
            samples = make_dataset(root, class_to_idx, extensions)
            if os.access(root, os.W_OK):
                # Only write the cache when the dataset dir is writable.
                print('=> caching {} samples to: {}'.format(len(samples), cache_path))
                with open(cache_path, 'wb') as wf:
                    pickle.dump(samples, wf)
        if not samples:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform
    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 loader=default_loader,
                 n_frames_apart=1,
                 download=False):
        """Dataset of image pairs taken *n_frames_apart* frames apart."""
        self.root = root
        if download:
            self.download()

        classes, class_to_idx = find_classes(root)
        imgs, actions = make_dataset(root, class_to_idx)
        if not imgs:
            raise RuntimeError("Found 0 images in subfolders of: " + root +
                               "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))
        # Invert the last action column — presumably an episode-end flag,
        # so `resets` marks frame boundaries; TODO confirm.
        resets = 1. - actions[:, -1]
        assert len(imgs) == len(resets)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.img_pairs = make_pair(imgs, resets, n_frames_apart,
                                   self._get_image, self.root)
Exemplo n.º 8
0
    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 loader=default_loader,
                 retun_idx=False):
        """ImageFolder-style dataset that can also return the sample index.

        Args:
            root: Dataset root directory.
            transform: Optional transform applied to each image.
            target_transform: Optional transform applied to each target.
            loader: Callable loading an image from a path.
            retun_idx: (sic) kept misspelled for backward compatibility;
                presumably makes __getitem__ also return the sample index —
                confirm against the class's __getitem__.
        """
        classes, class_to_idx = find_classes(root)
        IMG_EXTENSIONS = [
            '.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'
        ]

        # Fix: narrowed the former bare `except:` — only a TypeError from an
        # older make_dataset() signature (no extensions argument) should
        # trigger the two-argument fallback; other errors must propagate.
        try:
            imgs = make_dataset(root, class_to_idx, IMG_EXTENSIONS)
        except TypeError:
            imgs = make_dataset(root, class_to_idx)

        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + root +
                               "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.retun_idx = retun_idx
Exemplo n.º 9
0
    def __init__(
        self, *roots, transforms=None, target_transforms=None, loader=default_loader
    ):
        """Combine several image-folder roots into one parallel dataset.

        Every root must contain the same number of classes and the same
        number of samples as the first root.
        """
        classes_ = []
        class_to_idx_ = []
        samples_ = []
        for root in roots:
            classes, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, IMG_EXTENSIONS)
            if len(samples) == 0:
                raise (
                    RuntimeError(
                        "Found 0 files in subfolders of: " + root + "\n"
                        "Supported extensions are: " + ",".join(IMG_EXTENSIONS)
                    )
                )
            classes_.append(classes)
            # BUG FIX: previously appended the accumulator list itself
            # (class_to_idx_) instead of this root's mapping.
            class_to_idx_.append(class_to_idx)
            samples_.append(samples)
            if len(samples_[0]) != len(samples):
                raise ValueError(
                    "Dataset folders must have the same number of samples."
                )
            if len(classes_[0]) != len(classes):
                raise ValueError(
                    "Dataset folders must have the same number of classes."
                )
        super().__init__(roots, samples_, transforms, target_transforms)

        self.loader = loader
        self.extensions = IMG_EXTENSIONS

        self.classes = classes_
        self.class_to_idx = class_to_idx_
Exemplo n.º 10
0
def load_format_paths(folder_path, extension):
    """Return (paths, class indices) numpy arrays for *extension* files."""
    found_classes, class_to_idx = find_classes(folder_path)
    samples = make_dataset(folder_path, class_to_idx, [extension])

    paths = np.array([sample[0] for sample in samples])
    classes = np.array([int(sample[1]) for sample in samples])
    return paths, classes
Exemplo n.º 11
0
 def __init__(self, root_path, train_dir, valid_dir):
     """Collect samples from the train and valid splits under *root_path*.

     NOTE(review): train_dir / valid_dir are unused; the 'train' and
     'valid' subdirectory names are hard-coded below.
     """
     self.classes, self.class_to_idx = find_classes(root_path / 'train')
     split_samples = [make_dataset(root_path / split,
                                   self.class_to_idx,
                                   extensions=IMG_EXTENSIONS)
                      for split in ('train', 'valid')]
     self.samples = split_samples[0] + split_samples[1]
Exemplo n.º 12
0
def read_labels(dir_path):
    """Reads labels and label indices from directory.

    Thin wrapper around ``find_classes``.

      Args:
        dir_path - path to data directory
      Returns:
        tuple of -
          classes - array of labels
          class_to_idx - dictionary of indices keyed by classes
    """
    return find_classes(dir_path)
Exemplo n.º 13
0
def create_image_to_label(directory, batch_size=16, ahead=4):
    """Build and persist an image->label membership matrix for an experiment.

    Supports Broden segmentation datasets (labels gathered per category
    from segmentation records) and ImageNet/ILSVRC folder datasets
    (one-hot over 1000 classes). The matrix is written as the
    'image_to_label' mmap of the experiment directory.

    Args:
        directory: Experiment directory containing dataset info.
        batch_size: Batch size for the segmentation prefetcher.
        ahead: Number of batches the prefetcher reads ahead.
    """
    ed = expdir.ExperimentDirectory(directory)
    info = ed.load_info()

    # Fix: was a Python 2 `print` statement (SyntaxError under Python 3,
    # which the rest of this function already targets).
    print(info.dataset)
    if 'broden' in info.dataset:
        ds = loadseg.SegmentationData(info.dataset)
        categories = ds.category_names()
        shape = (ds.size(), len(ds.label))

        pf = loadseg.SegmentationPrefetcher(ds,
                                            categories=categories,
                                            once=True,
                                            batch_size=batch_size,
                                            ahead=ahead,
                                            thread=False)

        image_to_label = np.zeros(shape, dtype='int32')

        batch_count = 0
        for batch in pf.batches():
            if batch_count % 100 == 0:
                print('Processing batch %d ...' % batch_count)
            for rec in batch:
                image_index = rec['i']
                for cat in categories:
                    # A record's category entry may be an ndarray or a list;
                    # mark every label index present for this image.
                    if ((type(rec[cat]) is np.ndarray and rec[cat].size > 0)
                            or type(rec[cat]) is list and len(rec[cat]) > 0):
                        image_to_label[image_index][np.unique(rec[cat])] = True
            batch_count += 1
    elif 'imagenet' in info.dataset or 'ILSVRC' in info.dataset:
        classes, class_to_idx = find_classes(info.dataset)
        imgs = make_dataset(info.dataset, class_to_idx)
        _, labels = zip(*imgs)
        labels = np.array(labels)

        # One-hot encode over the 1000 ImageNet classes.
        L = 1000
        shape = (len(labels), L)

        image_to_label = np.zeros(shape)

        for i in range(L):
            image_to_label[labels == i, i] = 1
    else:
        # Same AssertionError as before, now with a diagnostic message.
        assert False, 'unsupported dataset: %s' % info.dataset

    mmap = ed.open_mmap(part='image_to_label',
                        mode='w+',
                        dtype=bool,
                        shape=shape)
    mmap[:] = image_to_label[:]
    ed.finish_mmap(mmap)
    f = ed.mmap_filename(part='image_to_label')

    print('Finished and saved index_to_label at %s' % f)
Exemplo n.º 14
0
    def __init__(self, cfg, data_dir=None, transform=None, labeled=True):
        """Image dataset over *data_dir*; class labels only when *labeled*."""
        self.cfg = cfg
        self.transform = transform
        self.data_dir = data_dir
        self.labeled = labeled

        if not labeled:
            # Unlabeled data: just collect image paths.
            self.imgs = utils.get_images(self.data_dir, ['jpg', 'png'])
        else:
            self.classes, self.class_to_idx = find_classes(self.data_dir)
            # Reverse mapping: class index -> class name.
            self.int_to_class = {i: name for i, name in enumerate(self.classes)}
            self.imgs = make_dataset(self.data_dir, self.class_to_idx, ['jpg', 'png'])
Exemplo n.º 15
0
 def __init__(self, root, loader=None, transform=None):
     """Collect test-split image paths from a train/test dataset layout."""
     assert os.path.exists(root)
     # Class names come from the train split; test images carry no labels.
     self.classes, class_to_idx = find_classes(os.path.join(root, 'train'))
     test_dir = os.path.join(root, 'test', 'images')
     self.image_paths = []
     for dirpath, _, fnames in sorted(os.walk(test_dir)):
         for fname in sorted(fnames):
             if has_file_allowed_extension(fname, IMG_EXTENSIONS):
                 self.image_paths.append(os.path.join(dirpath, fname))
     self.loader = loader if loader is not None else pil_loader
     self.transform = transform
Exemplo n.º 16
0
 def _extract_dataset(self):
     """Return the sample list, loading it from the meta-file cache if valid.

     On a cache miss the dataset directory is re-scanned and the resulting
     (classes, class_to_idx, samples) triple is written back to the cache.

     Returns:
         The list of dataset samples.
     """
     meta_path = os.path.join(self.root, self.meta_file)
     if self._check_Integrity():
         logging.info(f"Load meta info into {self.meta_file}")
         # Fix: close the cache file deterministically instead of leaking
         # the handle from pickle.load(open(...)).
         with open(meta_path, 'rb') as f:
             self.classes, self.class_to_idx, self.samples = pickle.load(f)
     else:
         self.classes, self.class_to_idx = find_classes(self.root)
         self.samples = make_dataset(self.root, self.class_to_idx,
                                     self.extensions)
         # Fix: likewise close the handle used for writing the cache.
         with open(meta_path, 'wb') as f:
             pickle.dump((self.classes, self.class_to_idx, self.samples), f)
         logging.info(f"Processed dataset meta info into {self.meta_file}")
     return self.samples
Exemplo n.º 17
0
    def __init__(self,
                 root,
                 loader,
                 extensions,
                 transform=None,
                 target_transform=None,
                 class_map=None):
        """Image-folder dataset with an optional persisted class mapping.

        When *class_map* names an existing JSON file, the class->index
        mapping is loaded from it; otherwise the mapping is discovered
        from *root* (and written to *class_map* when a path was given).
        """
        if class_map is not None and os.path.exists(class_map):
            # Reuse a previously saved class->index mapping.
            with open(class_map) as f:
                c_map = json.load(f)
            classes = list(c_map)
            class_to_idx = c_map
        else:
            classes, class_to_idx = find_classes(root)
            if class_map is not None:
                # Persist the discovered mapping for later runs.
                with open(class_map, "w") as f:
                    json.dump(class_to_idx, f)

        samples = make_dataset(root, class_to_idx, extensions)
        if not samples:
            raise RuntimeError("Found 0 files in subfolders of: " + root +
                               "\n"
                               "Supported extensions are: " +
                               ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform
Exemplo n.º 18
0
    def __init__(self, root, transform):
        """Eagerly load every image under *root* into memory."""
        self.classes, self.class_to_idx = find_classes(root)
        # (path, target) pairs for every image file under root.
        self.samples = make_dataset(root,
                                    self.class_to_idx,
                                    extensions=IMG_EXTENSIONS)
        self.loader = default_loader
        self.transform = transform

        # Decode all images up front; targets are kept in a parallel list.
        self.data = [self.loader(path) for path, _ in self.samples]
        self.targets = [target for _, target in self.samples]
Exemplo n.º 19
0
    def __init__(self, data_path, image_cache, do_random_flips=False,
                 normalization=transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))):
        """Dataset over *data_path* whose images are served from *image_cache*."""
        classes, class_to_idx = find_classes(data_path)
        imgs = make_dataset(data_path, class_to_idx)
        if not imgs:
            raise RuntimeError("Found 0 images in subfolders of: " + data_path + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = data_path
        self.image_cache = image_cache
        self.do_random_flips = do_random_flips
        self.normalization = normalization
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
Exemplo n.º 20
0
def gen_gradcam_image(model_parameter, data_dir, result_dir):
    """Render Grad-CAM heatmaps for every image under *data_dir*.

    For each input image one heatmap per class is produced, with the class
    name and its softmax score embedded in the output filename. Output is
    mirrored into results/<result_dir>/<class>/. Generation is skipped
    entirely when the result directory already exists and is non-empty.
    """
    cnn = utils.model_load(model_parameter)

    input_size = vis_utils.get_input_size(cnn.meta['base_model'])
    target_layer = vis_utils.get_conv_layer(cnn.meta['base_model'])
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    visualizer = Visualize(cnn,
                           preprocess,
                           target_layer,
                           num_classes=cnn.meta['num_classes'],
                           retainModel=False)

    # Idempotence guard: never overwrite an existing, non-empty result set.
    result_dir = os.sep.join(["results", result_dir])
    if os.path.isdir(result_dir):
        if len(os.listdir(result_dir)) != 0:
            print("Result path exist and not empty, not generating")
            return
    else:
        os.mkdir(result_dir)

    from torchvision.datasets.folder import find_classes, make_dataset
    classes, class_to_idx = find_classes(data_dir)
    dataset = make_dataset(data_dir, class_to_idx)

    # Mirror the per-class sub-directory layout inside the result directory.
    for class_name in os.listdir(data_dir):
        os.mkdir(os.sep.join([result_dir, class_name]))

    for img_path, idx in dataset:
        # Path relative to data_dir, reused to place output in the class dir.
        target_img_path = img_path.replace(data_dir + os.sep, "")
        print("Processing: ", target_img_path)

        img_pil = Image.open(img_path)
        img_pil = img_pil.resize((input_size, input_size))

        visualizer.input_image(img_pil)
        x = visualizer.get_prediction_output()
        x = F.softmax(Variable(x)).data

        # NOTE(review): this inner loop rebinds `idx`, shadowing the class
        # index from the outer dataset loop — confirm that is intentional.
        for pred_c, idx in class_to_idx.items():
            gradcam = visualizer.get_gradcam_heatmap(idx)[0]
            target_img = target_img_path.replace(
                '.jpg', '_{}_{:.4f}.jpg'.format(pred_c, x[0][idx]))
            gradcam.save(os.sep.join([result_dir, target_img]))
Exemplo n.º 21
0
    def __init__(self,
                 root,
                 loader,
                 extensions,
                 transform=None,
                 target_transform=None):
        """Image-folder dataset whose index is pickled next to the data dir."""
        chunk_name = os.path.basename(root)
        chunk_dir = os.path.dirname(root)
        self._cache_path = os.path.join(chunk_dir,
                                        "{}.cache".format(chunk_name))

        if os.path.exists(self._cache_path):
            # Fast path: restore the previously pickled index.
            logger.info("Load cache from: {}".format(self._cache_path))
            with open(self._cache_path, 'rb') as cache_file:
                data = pickle.load(cache_file)
            classes = data['classes']
            class_to_idx = data['class_to_idx']
            samples = data['samples']
        else:
            logger.info("Create dataset for dir: {}".format(root))
            classes, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, extensions)
            if not samples:
                raise RuntimeError("Found 0 files in subfolders of: " + root +
                                   "\n"
                                   "Supported extensions are: " +
                                   ",".join(extensions))

            logger.info("Save cache to: {}".format(self._cache_path))
            payload = dict(classes=classes,
                           class_to_idx=class_to_idx,
                           samples=samples)
            with open(self._cache_path, 'wb') as cache_file:
                pickle.dump(payload, cache_file, pickle.HIGHEST_PROTOCOL)

        logger.info("Dataset samples: {}".format(len(samples)))

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform
Exemplo n.º 22
0
    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 loader=folder.default_loader):
        """Image-folder dataset built on torchvision's folder helpers."""
        classes, class_to_idx = folder.find_classes(root)
        imgs = folder.make_dataset(root, class_to_idx, IMG_EXTENSIONS)
        if not imgs:
            # NOTE(review): the error text lists folder.IMG_EXTENSIONS while
            # make_dataset used the module-level IMG_EXTENSIONS — confirm
            # these are the same set.
            raise RuntimeError("Found 0 images in subfolders of: " + root +
                               "\n"
                               "Supported image extensions are: " +
                               ",".join(folder.IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
Exemplo n.º 23
0
    def __init__(self,
                 data_path,
                 image_cache,
                 do_random_flips=False,
                 normalization=transforms.Normalize((0.5, 0.5, 0.5),
                                                    (0.5, 0.5, 0.5))):
        """Cache-backed image dataset rooted at *data_path*."""
        classes, class_to_idx = find_classes(data_path)
        imgs = make_dataset(data_path, class_to_idx)
        if not imgs:
            raise RuntimeError("Found 0 images in subfolders of: " +
                               data_path + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = data_path
        self.image_cache = image_cache
        self.do_random_flips = do_random_flips
        self.normalization = normalization
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
Exemplo n.º 24
0
def read_all_images(root, num_workers=4):
    """Read every image under *root*; return a {path: image} dict."""
    classes, class_to_idx = find_classes(root)
    dataset = make_dataset(root, class_to_idx)
    if not dataset:
        raise RuntimeError("Found 0 images in subfolders of: " + root + "\n" +
                           "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

    paths = [sample[0] for sample in dataset]
    print("Reading {0} images with {1} workers".format(len(paths), num_workers))

    if num_workers > 1:
        # Fan the reads out across workers.
        images = parallel_process(paths, read_image_for_pytorch, n_jobs=num_workers)
    else:
        images = [read_image_for_pytorch(p) for p in tqdm(paths)]

    # Map each file path to its decoded image.
    return dict(zip(paths, images))
Exemplo n.º 25
0
def index_imagenet(root, index_save_path):
    """Index an ImageFolder-style directory and persist the index as JSON.

    The data folder is expected to arrange samples as::

        root/class_x/xxx.ext
        root/class_x/xxy.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext

    Args:
        root (string): Root directory path.
        index_save_path (string): Where to save the index.

    Returns:
        index (dict) containing the following items:
            classes (list): List of the class names.
            class_to_idx (dict): Dict with items (class_name, class_index).
            samples (list): List of (sample path, class_index) tuples
    """
    classes, class_to_idx = find_classes(root)
    samples = make_dataset(root, class_to_idx, IMG_EXTENSIONS)
    if not samples:
        raise RuntimeError("Found 0 files in subfolders of: " + root +
                           "\nSupported extensions are: " +
                           ",".join(IMG_EXTENSIONS))
    index = dict(classes=classes,
                 class_to_idx=class_to_idx,
                 samples=samples)
    with open(index_save_path, 'w') as json_file:
        json.dump(index, json_file, indent=0)
    return index
Exemplo n.º 26
0
def gen_gradcam_numpy(model_parameter, data_dir):
    """Compute Grad-CAM intensity maps for all images and save them as .npy.

    Each image's normalized heatmap (for the image's own class index) is
    appended to one flat array, saved under
    results/<model_parameter>/gradcam.npy.
    """
    cnn = utils.model_load(model_parameter)

    input_size = vis_utils.get_input_size(cnn.meta['base_model'])
    target_layer = vis_utils.get_conv_layer(cnn.meta['base_model'])
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    visualizer = Visualize(cnn,
                           preprocess,
                           target_layer,
                           num_classes=cnn.meta['num_classes'],
                           retainModel=False)

    result = np.empty(0)

    from torchvision.datasets.folder import find_classes, make_dataset
    classes, class_to_idx = find_classes(data_dir)
    dataset = make_dataset(data_dir, class_to_idx)

    for img_path, idx in dataset:
        print("Processing: ", img_path.replace(data_dir + os.sep, ""))

        img_pil = Image.open(img_path)
        img_pil = img_pil.resize((input_size, input_size))

        visualizer.input_image(img_pil)

        gradcam = visualizer.get_gradcam_intensity(idx)
        gradcam = vis_utils.normalize(gradcam)
        # NOTE(review): np.append reallocates the whole array each iteration
        # (quadratic overall); collecting into a list and concatenating once
        # would be linear, but would change the per-iteration shape printout.
        result = np.append(result, gradcam)
        print(result.shape)

    np.save(os.sep.join(['results', model_parameter, 'gradcam.npy']), result)
Exemplo n.º 27
0
            img = transform(img)

        return img

    def __len__(self):
        # Dataset size = number of loaded images.
        return len(self.images)


if __name__ == '__main__':
    args = parser.parse_args()
    print(args)

    eps = 1e-6

    inception = load_patched_inception_v3()
    _, class2id = find_classes(args.img)
    total_class = len(class2id)

    if args.model == 'dcgan':
        from model import Generator

    elif args.model == 'resnet':
        from model_resnet import Generator

    generator = Generator(args.code, total_class).to(device)
    generator.load_state_dict(torch.load(args.checkpoint))
    generator.eval()

    fids = []

    for class_name, id in class2id.items():
Exemplo n.º 28
0
from torch.autograd import Variable
import utils.dataset as dataset
import torchvision.datasets.folder as folder
from utils.BasicTrainTest import *
import utils.models as models
import warnings
from time import time

warnings.filterwarnings('ignore')

use_gpu = torch.cuda.is_available()

# Example inputs: dataset root and one sample image path.
root = './test/'
path = './test/沙发/10.jpg'

classes, class_to_idx = folder.find_classes('./test/')

# Invert class->index into index->class for decoding predictions.
idx_to_class = {}

for classname, idx in class_to_idx.items():
    idx_to_class[idx] = classname

# Restore a trained SqueezeNet (20 classes) from checkpoint for inference.
checkpointPath = './checkpoint/squeezenet1_10.766.pth'
state = torch.load(checkpointPath)
model = models.squeezenet1_1(20, pretrained=False)
if use_gpu:
    model = model.cuda()
model.load_state_dict(state)

model.eval()
Exemplo n.º 29
0
# Seed both CPU and (when available) the selected CUDA device.
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu_id)

# Duplicate the per-class sample spec so sample[0]/sample[1] (used for the
# train and val loaders below) exist even when one value is given —
# presumably intentional; confirm for multi-value specs.
sample = [int(i) for i in args.sample_per_class.split(',')] * 2

if args.dataset not in datasets.available_datasets:
    # custom folder dataset
    import os
    train_data = os.path.join(args.dataset, 'train')
    test_data = os.path.join(args.dataset, 'val')

    from torchvision.datasets.folder import find_classes
    num_classes = len(find_classes(train_data)[0])
    train_loader = datasets.folder(train_data,
                                   batch_size=args.batch_size,
                                   data_augmentation=args.data_augmentation,
                                   sample_per_class=sample[0])
    test_loader = datasets.folder(test_data,
                                  batch_size=args.batch_size,
                                  data_augmentation=args.data_augmentation,
                                  sample_per_class=sample[1])
else:
    num_classes = datasets.num_classes[args.dataset]
    dataset = datasets.__dict__[args.dataset]
    train_loader, test_loader = dataset(
        batch_size=args.batch_size,
        data_augmentation=args.data_augmentation,
        download=False,