Exemplo n.º 1
0
def load_parameters(params_path):
    """Load LPIPS parameters from ``params_path``, downloading them first if absent.

    When ``params_path`` is not an existing file, the file with the same base
    name is fetched from the nnabla pretrained-model server
    (``.../eval_metrics/lpips/<basename>``) and written to ``params_path``.

    Args:
        params_path (str): Local path of the parameter file.
    """
    if not os.path.isfile(params_path):
        from nnabla.utils.download import download
        url_base = "https://nnabla.org/pretrained-models/nnabla-examples/eval_metrics/lpips"
        # URLs always use '/'; os.path.join would insert '\\' on Windows.
        # os.path.basename also understands the platform's own separator,
        # unlike a hard-coded split("/").
        url = "{}/{}".format(url_base, os.path.basename(params_path))
        # Make sure the destination directory exists before writing into it.
        params_dir = os.path.dirname(params_path)
        if params_dir and not os.path.isdir(params_dir):
            os.makedirs(params_dir)
        download(url, params_path, False)
    nn.load_parameters(params_path)
Exemplo n.º 2
0
    def _load_nnp(self, rel_name, rel_url):
        '''Download an ImageNet ``.nnp`` file and load it into ``self.nnp``.

        Args:
            rel_name: relative path (under ``<model home>/imagenet/``) where
                the downloaded nnp is saved.
            rel_url: relative url path (under ``<model url base>/imagenet/``)
                from which the nnp is downloaded.
        '''
        from nnabla.utils.download import download
        path_nnp = os.path.join(get_model_home(),
                                'imagenet/{}'.format(rel_name))
        # URLs always use '/'; os.path.join would produce '\\' on Windows.
        url = '{}/imagenet/{}'.format(get_model_url_base().rstrip('/'), rel_url)
        logger.info('Downloading {} from {}'.format(rel_name, url))
        dir_nnp = os.path.dirname(path_nnp)
        if not os.path.isdir(dir_nnp):
            os.makedirs(dir_nnp)
        # allow_overwrite=False: presumably an existing file at path_nnp is
        # kept instead of re-downloaded — confirm against nnabla's download().
        download(url, path_nnp, open_file=False, allow_overwrite=False)
        print('Loading {}.'.format(path_nnp))
        self.nnp = NnpLoader(path_nnp)
Exemplo n.º 3
0
    def __init__(self, train=True, shuffle=False, rng=None):
        """Download the CIFAR-10 archive and load one split into memory.

        Args:
            train (bool): If True, load every archive member whose name
                contains ``data_batch`` (training split); otherwise load the
                member(s) containing ``test_batch``.
            shuffle (bool): Forwarded to the parent data source.
            rng: Forwarded to the parent data source; when None a fresh
                ``np.random.RandomState()`` is created and stored in
                ``self.rng``.

        Raises:
            RuntimeError: if the archive contains no member for the split.
        """
        super(Cifar10DataSource, self).__init__(shuffle=shuffle, rng=rng)

        self._train = train
        data_uri = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        logger.info('Getting labeled data from {}.'.format(data_uri))
        batch_key = "data_batch" if train else "test_batch"
        images = []
        labels = []
        r = download(data_uri)  # file object returned
        # Close the downloaded stream even if extraction or parsing fails.
        try:
            with tarfile.open(fileobj=r, mode="r:gz") as fpin:
                for member in fpin.getmembers():
                    if batch_key not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    # NOTE(review): the batches are pickles; allow_pickle=True
                    # trusts whatever the downloaded archive contains.
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images.append(data[b"data"])
                    labels.append(data[b"labels"])
        finally:
            r.close()

        if not images:
            raise RuntimeError(
                'No "{}" member found in {}.'.format(batch_key, data_uri))
        # Each row is a flattened 3x32x32 image; derive the count from the data.
        self._images = np.concatenate(images).reshape(-1, 3, 32, 32)
        self._labels = np.concatenate(labels).reshape(-1, 1)
        logger.info('Getting labeled data from {}.'.format(data_uri))

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState()
        self.rng = rng

        # keep data paths
        self.data_history = queue.Queue(maxsize=128)

        self.reset()
Exemplo n.º 4
0
def load_parameters(params_path):
    """Load Inception-v3 parameters, downloading the original weights if missing.

    Args:
        params_path (str): Local path of the parameter file. If it is not an
            existing file, ``original_inception_v3.h5`` is downloaded from the
            nnabla pretrained-model server and written to this path.
    """
    if not os.path.isfile(params_path):
        from nnabla.utils.download import download
        url = "https://nnabla.org/pretrained-models/nnabla-examples/eval_metrics/inceptions/original_inception_v3.h5"
        # Make sure the destination directory exists before writing into it.
        params_dir = os.path.dirname(params_path)
        if params_dir and not os.path.isdir(params_dir):
            os.makedirs(params_dir)
        download(url, params_path, False)
    nn.load_parameters(params_path)