def load_parameters(params_path,
                    url_base=("https://nnabla.org/pretrained-models/"
                              "nnabla-examples/eval_metrics/lpips")):
    """Load LPIPS parameters, downloading the file first if it is missing.

    Args:
        params_path (str): Local path of the parameter file. If no file
            exists there, it is downloaded from ``url_base`` using the
            basename of ``params_path`` as the remote file name.
        url_base (str): Remote directory the parameter file is fetched from.

    The parameters are then loaded with ``nn.load_parameters``.
    """
    if not os.path.isfile(params_path):
        from nnabla.utils.download import download
        # URLs always use '/': os.path.join would insert '\\' on Windows, and
        # os.path.basename (unlike split("/")) also understands native
        # Windows separators in params_path.
        url = "{}/{}".format(url_base.rstrip("/"),
                             os.path.basename(params_path))
        # The downloader writes straight to params_path, so make sure its
        # directory exists before fetching.
        params_dir = os.path.dirname(params_path)
        if params_dir and not os.path.isdir(params_dir):
            os.makedirs(params_dir)
        download(url, params_path, False)  # third arg: open_file=False
    nn.load_parameters(params_path)
def _load_nnp(self, rel_name, rel_url):
    '''Download an ImageNet NNP file (unless already present) and load it.

    The loaded model is stored in ``self.nnp`` as an ``NnpLoader``.

    Args:
        rel_name: relative path to where downloaded nnp is saved.
        rel_url: relative url path to where nnp is downloaded from.
    '''
    from nnabla.utils.download import download
    path_nnp = os.path.join(get_model_home(), 'imagenet/{}'.format(rel_name))
    # Build the URL with '/' explicitly. os.path.join would insert '\\' on
    # Windows and would drop the base entirely if rel_url started with '/'.
    url = '{}/imagenet/{}'.format(get_model_url_base().rstrip('/'),
                                  rel_url.lstrip('/'))
    logger.info('Downloading {} from {}'.format(rel_name, url))
    dir_nnp = os.path.dirname(path_nnp)
    if not os.path.isdir(dir_nnp):
        os.makedirs(dir_nnp)
    # allow_overwrite=False: an existing local copy is reused.
    download(url, path_nnp, open_file=False, allow_overwrite=False)
    print('Loading {}.'.format(path_nnp))
    self.nnp = NnpLoader(path_nnp)
def __init__(self, train=True, shuffle=False, rng=None):
    """Load the CIFAR-10 training or test split into memory.

    The ``cifar-10-python.tar.gz`` archive is obtained via ``download`` and
    its batch members are read directly from the tar stream.

    Args:
        train (bool): If True, read every ``data_batch`` member (50000
            samples); otherwise read the ``test_batch`` member (10000
            samples).
        shuffle (bool): Forwarded to the parent class constructor.
        rng: Random state forwarded to the parent constructor; when None a
            new ``np.random.RandomState()`` is stored in ``self.rng``.

    Raises:
        ValueError: If the archive has no member for the requested split.

    After loading, ``self._images`` has shape (N, 3, 32, 32) and
    ``self._labels`` has shape (N, 1).
    """
    super(Cifar10DataSource, self).__init__(shuffle=shuffle, rng=rng)

    self._train = train
    data_uri = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    logger.info('Getting labeled data from {}.'.format(data_uri))
    r = download(data_uri)  # file object returned
    try:
        with tarfile.open(fileobj=r, mode="r:gz") as fpin:
            # NOTE(review): each batch is a pickle, so allow_pickle=True runs
            # pickle loading on downloaded content. Acceptable only because
            # the source is the fixed official CIFAR-10 URL above.
            if train:
                # Training data: all data_batch members, in archive order.
                images = []
                labels = []
                for member in fpin.getmembers():
                    if "data_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images.append(data[b"data"])
                    labels.append(data[b"labels"])
                if not images:
                    raise ValueError(
                        'No "data_batch" member found in {}.'.format(data_uri))
                self._size = 50000
                self._images = np.concatenate(images).reshape(
                    self._size, 3, 32, 32)
                self._labels = np.concatenate(labels).reshape(-1, 1)
            else:
                # Validation data: the test_batch member.
                images = None
                labels = None
                for member in fpin.getmembers():
                    if "test_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images = data[b"data"]
                    labels = data[b"labels"]
                if images is None:
                    raise ValueError(
                        'No "test_batch" member found in {}.'.format(data_uri))
                self._size = 10000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
    finally:
        # Release the download stream even if extraction or parsing fails.
        r.close()
    # NOTE(review): repeats the message logged before the download; likely
    # meant to signal completion.
    logger.info('Getting labeled data from {}.'.format(data_uri))

    # Use the number of labels actually loaded as the dataset size.
    self._size = self._labels.size
    self._variables = ('x', 'y')
    if rng is None:
        rng = np.random.RandomState()
    self.rng = rng
    # keep data paths
    # NOTE(review): the consumer of this bounded queue is not visible here.
    self.data_history = queue.Queue(maxsize=128)
    self.reset()
def load_parameters(params_path,
                    url=("https://nnabla.org/pretrained-models/"
                         "nnabla-examples/eval_metrics/inceptions/"
                         "original_inception_v3.h5")):
    """Load Inception-v3 parameters, downloading the file first if missing.

    Args:
        params_path (str): Local path of the parameter file; it is only
            downloaded when no file exists at this path.
        url (str): Where to fetch the parameter file from when it is missing.
            Defaults to the original Inception-v3 weights.

    The parameters are then loaded with ``nn.load_parameters``.
    """
    if not os.path.isfile(params_path):
        from nnabla.utils.download import download
        # The downloader writes straight to params_path, so make sure its
        # directory exists before fetching.
        params_dir = os.path.dirname(params_path)
        if params_dir and not os.path.isdir(params_dir):
            os.makedirs(params_dir)
        download(url, params_path, open_file=False)
    nn.load_parameters(params_path)