Example #1
def save_parameters(path):
    """Save all parameters into a file with the specified format.

    Currently hdf5 and protobuf formats are supported.

    Args:
      path : path or file object
    """
    _, ext = os.path.splitext(path)
    params = get_parameters(grad_only=False)
    if ext == '.h5':
        import h5py
        with h5py.File(path, 'w') as hd:
            for i, (k, v) in enumerate(iteritems(params)):
                hd[k] = v.d
                hd[k].attrs['need_grad'] = v.need_grad
                # To preserve order of parameters
                hd[k].attrs['index'] = i
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        for variable_name, variable in params.items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad

        with open(path, "wb") as f:
            f.write(proto.SerializeToString())
    else:
        logger.critical('Only hdf5 and protobuf formats are supported.')
        assert False
    logger.info("Parameter save ({}): {}".format(ext, path))
Example #2
    def open(self, filename=None):
        if filename is None:
            filename = self._base_uri
        else:
            if self._file_type == 's3':
                filename = urljoin(self._base_uri.replace(
                    's3://', 'http://'), filename.replace('\\', '/')).replace('http://', 's3://')
            elif self._file_type == 'http':
                filename = urljoin(self._base_uri, filename.replace('\\', '/'))
            else:
                filename = os.path.abspath(os.path.join(os.path.dirname(
                    self._base_uri.replace('\\', '/')), filename.replace('\\', '/')))
        f = None
        if self._file_type == 's3':
            uri_header, uri_body = filename.split('://', 1)
            us = uri_body.split('/')
            bucketname = us.pop(0)
            key = '/'.join(us)
            logger.info('Opening {}'.format(key))
            f = StringIO(self._s3_bucket.Object(key).get()['Body'].read())
        elif self._file_type == 'http':
            f = request.urlopen(filename)
        else:
            f = open(filename, 'rb')
        # Yield-then-close: intended to be used as a context manager
        # (wrapped with contextlib.contextmanager), e.g. "with reader.open(name) as f:".
        yield f
        f.close()
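
A hedged usage sketch, assuming the method above belongs to NNabla's FileReader and is wrapped with contextlib.contextmanager (the yield/close pattern suggests it); the base URI and file name below are hypothetical.

from nnabla.utils.data_source_loader import FileReader

reader = FileReader('/data/dataset/train.csv')   # base URI; could also be s3:// or http(s)://
with reader.open('images/0001.png') as f:        # relative names are resolved against the base URI
    raw = f.read()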
Example #3
def _create_dataset(uri, batch_size, shuffle, no_image_normalization, cache_dir, overwrite_cache, create_cache_explicitly, prepare_data_iterator):
    class Dataset:
        pass
    dataset = Dataset()
    dataset.uri = uri
    dataset.normalize = not no_image_normalization

    if prepare_data_iterator:
        if cache_dir == '':
            cache_dir = None
        if cache_dir and create_cache_explicitly:
            if not os.path.exists(cache_dir) or overwrite_cache:
                if not os.path.exists(cache_dir):
                    os.mkdir(cache_dir)
                logger.info('Creating cache data for "' + uri + '"')
                with data_iterator_csv_dataset(uri, batch_size, shuffle, normalize=False, cache_dir=cache_dir) as di:
                    index = 0
                    while index < di.size:
                        progress('', (1.0 * di.position) / di.size)
                        di.next()
                        index += batch_size
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
        elif not cache_dir or overwrite_cache or not os.path.exists(cache_dir):
            if cache_dir and not os.path.exists(cache_dir):
                os.mkdir(cache_dir)
            dataset.data_iterator = (lambda: data_iterator_csv_dataset(
                uri, batch_size, shuffle, normalize=dataset.normalize, cache_dir=cache_dir))
        else:
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
    else:
        dataset.data_iterator = None
    return dataset
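
A minimal sketch of how the helper above might be called; it is a private helper, so the argument values are purely illustrative and 'train.csv' is a hypothetical CSV dataset.

dataset = _create_dataset(uri='train.csv',
                          batch_size=64,
                          shuffle=True,
                          no_image_normalization=False,
                          cache_dir='/tmp/train.cache',
                          overwrite_cache=False,
                          create_cache_explicitly=True,
                          prepare_data_iterator=True)
with dataset.data_iterator() as di:   # DataIterator supports the with-statement
    batch = di.next()                 # tuple of arrays, one per CSV variable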
Example #4
    def __init__(self,
                 data_source,
                 batch_size,
                 rng=None,
                 epoch_begin_callbacks=[],
                 epoch_end_callbacks=[]):
        logger.info('Using DataIterator')
        if rng is None:
            rng = numpy.random.RandomState(313)
        self._rng = rng
        self._shape = None       # Only use with padding
        self._data_position = 0  # Only use with padding

        self._data_source = data_source
        self._variables = data_source.variables
        self._batch_size = batch_size
        self._epoch = -1

        self._epoch_end_callbacks = list(epoch_end_callbacks)
        self._epoch_begin_callbacks = list(epoch_begin_callbacks)

        self._size = data_source.size

        self._reset()
        self._closed = False
        atexit.register(self.close)
Example #5
    def __init__(self, base_uri):
        self._base_uri = base_uri
        if base_uri[0:5].lower() == 's3://':
            self._file_type = 's3'
            uri_header, uri_body = self._base_uri.split('://', 1)
            us = uri_body.split('/')
            bucketname = us.pop(0)
            self._s3_base_key = '/'.join(us)
            logger.info('Creating session for S3 bucket {}'.format(bucketname))

            import boto3
            self._s3_bucket = boto3.session.Session().resource('s3').Bucket(bucketname)

        elif base_uri[0:7].lower() == 'http://' or base_uri[0:8].lower() == 'https://':
            self._file_type = 'http'
        else:
            self._file_type = 'file'
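
The constructor above only inspects the URI prefix to decide how files will be opened later; a small sketch with hypothetical URIs:

FileReader('/data/train.csv')                 # -> self._file_type == 'file'
FileReader('https://example.com/train.csv')   # -> 'http'
FileReader('s3://my-bucket/data/train.csv')   # -> 's3' (builds a boto3 Bucket handle, so boto3 must be installed)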
Example #6
    def add(self, index, value):
        """Add a value to the series.

        Args:
            index (int): Index.
            value (float): Value.

        """
        self.buf.append(value)
        if (index - self.flush_at) < self.interval:
            return
        value = np.mean(self.buf)
        if self.verbose:
            logger.info("iter={} {{{}}}={}".format(index, self.name, value))
        if self.fd is not None:
            print("{} {:g}".format(index, value), file=self.fd)
        self.flush_at = index
        self.buf = []
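
A short usage sketch with NNabla's monitor utilities; the values fed in below are random stand-ins for a real training loss.

import numpy as np
import nnabla.monitor as M

monitor = M.Monitor('tmp.monitor')
loss_series = M.MonitorSeries('Training loss', monitor, interval=10)

for i in range(100):
    loss_series.add(i, float(np.random.rand()))  # buffered, then averaged and flushed every `interval` iterations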
Example #7
def load_parameters(path, proto=None):
    """Load parameters from a file with the specified format.

    Args:
      path : path or file object
    """
    _, ext = os.path.splitext(path)
    if proto is None:
        proto = nnabla_pb2.NNablaProtoBuf()
    if ext == '.h5':
        import h5py
        with h5py.File(path, 'r') as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))
            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]
                var = get_parameter_or_create(key, ds.shape,
                                              need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]
                parameter = proto.parameter.add()
                parameter.variable_name = key
                parameter.shape.dim.extend(var.shape)
                parameter.data.extend(numpy.array(var.d).flatten().tolist())
                parameter.need_grad = var.need_grad
    elif ext == '.protobuf':
        with open(path, 'rb') as f:
            proto.MergeFromString(f.read())
            for parameter in proto.parameter:
                var = get_parameter_or_create(
                    parameter.variable_name, parameter.shape.dim)
                param = numpy.reshape(parameter.data, parameter.shape.dim)
                var.d = param
                var.need_grad = parameter.need_grad
    logger.info("Parameter load ({}): {}".format(format, path))
    return proto
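
A round-trip sketch, assuming the function above is what NNabla exposes as nn.load_parameters; the file name is hypothetical.

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((4, 16))
PF.affine(x, 8, name='fc')              # create some parameters
nn.save_parameters('params.protobuf')

nn.clear_parameters()
nn.load_parameters('params.protobuf')   # repopulates the global parameter scope
print(sorted(nn.get_parameters().keys()))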
Example #8
    def add(self, index, var):
        """Add a minibatch of images to the monitor.

        Args:
            index (int): Index.
            var (:obj:`~nnabla.Variable`, :obj:`~nnabla.NdArray`, or :obj:`~numpy.ndarray`):
                A minibatch of images with ``(N, ..., C, H, W)`` format.
                If C == 2, a blue channel filled with ones is appended. If
                C > 3, only the first three channels are kept.

        """
        import nnabla as nn
        from scipy.misc import imsave
        if index != 0 and (index + 1) % self.interval != 0:
            return
        if isinstance(var, nn.Variable):
            data = var.d.copy()
        elif isinstance(var, nn.NdArray):
            data = var.data.copy()
        else:
            assert isinstance(var, np.ndarray)
            data = var.copy()
        assert data.ndim > 2
        channels = data.shape[-3]
        data = data.reshape(-1, *data.shape[-3:])
        data = data[:min(data.shape[0], self.num_images)]
        data = self.normalize_method(data)
        if channels > 3:
            data = data[:, :3]
        elif channels == 2:
            data = np.concatenate(
                [data, np.ones((data.shape[0], 1) + data.shape[-2:])], axis=1)
        path_tmpl = os.path.join(self.save_dir, '{:06d}-{}.png')
        for j in range(min(self.num_images, data.shape[0])):
            img = data[j].transpose(1, 2, 0)
            if img.shape[-1] == 1:
                img = img[..., 0]
            path = path_tmpl.format(index, '{:03d}'.format(j))
            imsave(path, img)
        if self.verbose:
            logger.info("iter={} {{{}}} are written to {}.".format(
                index, self.name, path_tmpl.format(index, '*')))
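
A usage sketch for the image monitor; note that the method above relies on scipy.misc.imsave, so an older SciPy is assumed. The random batch stands in for real images.

import numpy as np
import nnabla.monitor as M

monitor = M.Monitor('tmp.monitor')
monitor_image = M.MonitorImage('reconstruction', monitor, interval=1, num_images=4)

batch = np.random.rand(8, 3, 28, 28)   # (N, C, H, W), values in [0, 1]
monitor_image.add(0, batch)            # writes up to num_images PNG files under the monitor directory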
Example #9
def test_data_iterator_csv_dataset(test_data_csv_png_10,
                                   test_data_csv_png_20,
                                   size,
                                   batch_size,
                                   shuffle,
                                   normalize,
                                   with_memory_cache,
                                   with_file_cache,
                                   with_context):

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    if size == 10:
        csvfilename = test_data_csv_png_10
    elif size == 20:
        csvfilename = test_data_csv_png_20

    logger.info(csvfilename)

    if with_context:
        with data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache) as di:
            check_data_iterator_result(di, batch_size, shuffle, normalize)
    else:
        di = data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache)
        check_data_iterator_result(di, batch_size, shuffle, normalize)
        di.close()
Example #10
def download(url):
    filename = url.split('/')[-1]
    cache = os.path.join(get_data_home(), filename)
    if os.path.exists(cache):
        logger.info("> {} in cache.".format(cache))
        logger.info("> If you have any issue when using this file, ")
        logger.info("> manually remove the file and try download again.")
    else:
        r = request.urlopen(url)
        try:
            if six.PY2:
                content_length = int(r.info().dict['content-length'])
            elif six.PY3:
                content_length = int(r.info()['Content-Length'])
        except:
            content_length = 0
        unit = 1000000
        content = b''
        with tqdm(total=content_length, desc=filename) as t:
            while True:
                data = r.read(unit)
                l = len(data)
                t.update(l)
                if l == 0:
                    break
                content += data
        with open(cache, 'wb') as f:
            f.write(content)
    return open(cache, 'rb')
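
A hedged usage sketch; the URL is hypothetical. The helper caches the file under get_data_home() and returns it opened in binary mode, so the caller is responsible for closing it.

f = download('https://example.com/datasets/train-images.gz')
raw = f.read()
f.close()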
Example #11
    def add(self, index):
        """Calculate time elapsed from the point previously called
        this method or this object is created to this is called.

        Args:
            index (int): Index to be displayed, and be used to take intervals.

        """
        if (index - self.flush_at) < self.interval:
            return
        now = time.time()
        elapsed = now - self.lap
        elapsed_total = now - self.start
        it = index - self.flush_at
        self.lap = now
        if self.verbose:
            logger.info("iter={} {{{}}}={}[sec/{}iter] {}[sec]".format(
                index, self.name, elapsed, it, elapsed_total))
        if self.fd is not None:
            print("{} {} {} {}".format(index, elapsed,
                                       it, elapsed_total), file=self.fd)
        self.flush_at = index
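
A short usage sketch; time.sleep stands in for one training iteration.

import time
import nnabla.monitor as M

monitor = M.Monitor('tmp.monitor')
monitor_time = M.MonitorTimeElapsed('Elapsed time', monitor, interval=10)

for i in range(100):
    time.sleep(0.01)
    monitor_time.add(i)   # logs seconds per `interval` iterations plus total elapsed time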
Example #12
    def __init__(self, cachedir, shuffle=False, rng=None, normalize=False):
        super(CacheDataSource, self).__init__(shuffle=shuffle, rng=rng)
        self._cachedir = cachedir
        self._normalize = normalize
        self._filereader = FileReader(self._cachedir)
        self._filenames = self._filereader.listdir()

        self._generation = -1
        self._cache_files = []
        for filename in self._filenames:
            length = -1
            with self._filereader.open_cache(filename) as cache:
                if self._variables is None:
                    self._variables = list(cache.keys())
                for k, v in cache.items():
                    if length < 0:
                        length = len(v)
                    else:
                        assert(length == len(v))
                self._cache_files.append((filename, length))
                logger.info('{} {}'.format(filename, length))

        logger.info('{}'.format(len(self._cache_files)))
        self.reset()
Example #13
    def __init__(self, train=True, shuffle=False, rng=None):
        super(MnistDataSource, self).__init__(shuffle=shuffle)
        self._train = train
        if self._train:
            image_uri = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
            label_uri = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
        else:
            image_uri = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
            label_uri = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'

        logger.info('Getting label data from {}.'.format(label_uri))
        # With Python 3 we could write this logic as follows, but with
        # Python 2, gzip does not accept a file-like object and
        # urllib.request does not support the 'with' statement.
        #
        #   with request.urlopen(label_uri) as r, gzip.open(r) as f:
        #       _, size = struct.unpack('>II', f.read(8))
        #       self._labels = numpy.frombuffer(f.read(), numpy.uint8).reshape(-1, 1)
        #
        r = download(label_uri)
        data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
        _, size = struct.unpack('>II', data[0:8])
        self._labels = numpy.frombuffer(data[8:], numpy.uint8).reshape(-1, 1)
        r.close()
        logger.info('Getting label data done.')

        logger.info('Getting image data from {}.'.format(image_uri))
        r = download(image_uri)
        data = zlib.decompress(r.read(), zlib.MAX_WBITS | 32)
        _, size, height, width = struct.unpack('>IIII', data[0:16])
        self._images = numpy.frombuffer(data[16:], numpy.uint8).reshape(
            size, 1, height, width)
        r.close()
        logger.info('Getting image data done.')

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = numpy.random.RandomState(313)
        self.rng = rng
        self.reset()
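
A usage sketch, assuming the class above is importable (it lives in the MNIST example's mnist_data module in nnabla-examples):

from nnabla.utils.data_iterator import data_iterator

ds = MnistDataSource(train=True, shuffle=True)
with data_iterator(ds, batch_size=64) as di:
    x, y = di.next()   # x: (64, 1, 28, 28) images, y: (64, 1) labels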
Example #14
def save(filename, contents, include_params=False):
    '''Save network definition, inference/training execution
    configurations etc.

    Args:
        filename (str): Filename to store information. The file
            extension is used to determine the saving file format.
            ``.nnp``: (Recommended) Creating a zip archive with nntxt (network
            definition etc.) and h5 (parameters).
            ``.nntxt``: Protobuf in text format.
            ``.protobuf``: Protobuf in binary format (unsafe in terms of
            backward compatibility).
        contents (dict): Information to store.
        include_params (bool): Includes parameter into single file. This is
            ignored when the extension of filename is nnp.

    Example:
        The currently supported fields of ``contents`` are ``networks`` and
        ``executors``. The following example creates an MLP with two inputs
        and two outputs, and saves the network structure and the initialized
        parameters.

        .. code-block:: python

            import nnabla as nn
            import nnabla.functions as F
            import nnabla.parametric_functions as PF

            batch_size = 16
            x0 = nn.Variable([batch_size, 100])
            x1 = nn.Variable([batch_size, 100])
            h1_0 = PF.affine(x0, 100, name='affine1_0')
            h1_1 = PF.affine(x1, 100, name='affine1_0')
            h1 = F.tanh(h1_0 + h1_1)
            h2 = F.tanh(PF.affine(h1, 50, name='affine2'))
            y0 = PF.affine(h2, 10, name='affiney_0')
            y1 = PF.affine(h2, 10, name='affiney_1')

            contents = {
                'networks': [
                    {'name': 'net1',
                     'batch_size': batch_size,
                     'outputs': {'y0': y0, 'y1': y1},
                     'names': {'x0': x0, 'x1': x1}}],
                'executors': [
                    {'name': 'runtime',
                     'network': 'net1',
                     'data': ['x0', 'x1'],
                     'output': ['y0', 'y1']}]}
            save('net.nnp', contents)
    '''
    _, ext = os.path.splitext(filename)
    if ext == '.nntxt' or ext == '.prototxt':
        logger.info("Saving {} as prototxt".format(filename))
        proto = create_proto(contents, include_params)
        with open(filename, 'w') as file:
            text_format.PrintMessage(proto, file)
    elif ext == '.protobuf':
        logger.info("Saving {} as protobuf".format(filename))
        proto = create_proto(contents, include_params)
        with open(filename, 'wb') as file:
            file.write(proto.SerializeToString())
    elif ext == '.nnp':
        logger.info("Saving {} as nnp".format(filename))
        try:
            tmpdir = tempfile.mkdtemp()
            save('{}/network.nntxt'.format(tmpdir),
                 contents,
                 include_params=False)

            with open('{}/nnp_version.txt'.format(tmpdir), 'w') as file:
                file.write('{}\n'.format(nnp_version))

            save_parameters('{}/parameter.protobuf'.format(tmpdir))

            with zipfile.ZipFile(filename, 'w') as nnp:
                nnp.write('{}/nnp_version.txt'.format(tmpdir),
                          'nnp_version.txt')
                nnp.write('{}/network.nntxt'.format(tmpdir), 'network.nntxt')
                nnp.write('{}/parameter.protobuf'.format(tmpdir),
                          'parameter.protobuf')
        finally:
            shutil.rmtree(tmpdir)
Example #15
def save(filename,
         contents,
         include_params=False,
         variable_batch_size=True,
         extension=".nnp",
         parameters=None):
    '''Save network definition, inference/training execution
    configurations etc.

    Args:
        filename (str or file object): Filename to store information. The file
            extension is used to determine the saving file format.
            ``.nnp``: (Recommended) Creating a zip archive with nntxt (network
            definition etc.) and h5 (parameters).
            ``.nntxt``: Protobuf in text format.
            ``.protobuf``: Protobuf in binary format (unsafe in terms of
             backward compatibility).
        contents (dict): Information to store.
        include_params (bool): Includes parameter into single file. This is
            ignored when the extension of filename is nnp.
        variable_batch_size (bool):
            By ``True``, the first dimension of all variables is considered
            as batch size, and left as a placeholder
            (more specifically ``-1``). The placeholder dimension will be
            filled during/after loading.
        extension (str): If ``filename`` is a file-like object, ``extension`` must be one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".

    Example:
        The following example creates an MLP with two inputs and two
        outputs, and saves the network structure and the initialized
        parameters.

        .. code-block:: python

            import nnabla as nn
            import nnabla.functions as F
            import nnabla.parametric_functions as PF
            from nnabla.utils.save import save

            batch_size = 16
            x0 = nn.Variable([batch_size, 100])
            x1 = nn.Variable([batch_size, 100])
            h1_0 = PF.affine(x0, 100, name='affine1_0')
            h1_1 = PF.affine(x1, 100, name='affine1_0')
            h1 = F.tanh(h1_0 + h1_1)
            h2 = F.tanh(PF.affine(h1, 50, name='affine2'))
            y0 = PF.affine(h2, 10, name='affiney_0')
            y1 = PF.affine(h2, 10, name='affiney_1')

            contents = {
                'networks': [
                    {'name': 'net1',
                     'batch_size': batch_size,
                     'outputs': {'y0': y0, 'y1': y1},
                     'names': {'x0': x0, 'x1': x1}}],
                'executors': [
                    {'name': 'runtime',
                     'network': 'net1',
                     'data': ['x0', 'x1'],
                     'output': ['y0', 'y1']}]}
            save('net.nnp', contents)


        To get a trainable model, use the following code instead.

        .. code-block:: python

            contents = {
            'global_config': {'default_context': ctx},
            'training_config':
                {'max_epoch': args.max_epoch,
                 'iter_per_epoch': args_added.iter_per_epoch,
                 'save_best': True},
            'networks': [
                {'name': 'training',
                 'batch_size': args.batch_size,
                 'outputs': {'loss': loss_t},
                 'names': {'x': x, 'y': t, 'loss': loss_t}},
                {'name': 'validation',
                 'batch_size': args.batch_size,
                 'outputs': {'loss': loss_v},
                 'names': {'x': x, 'y': t, 'loss': loss_v}}],
            'optimizers': [
                {'name': 'optimizer',
                 'solver': solver,
                 'network': 'training',
                 'dataset': 'mnist_training',
                 'weight_decay': 0,
                 'lr_decay': 1,
                 'lr_decay_interval': 1,
                 'update_interval': 1}],
            'datasets': [
                {'name': 'mnist_training',
                 'uri': 'MNIST_TRAINING',
                 'cache_dir': args.cache_dir + '/mnist_training.cache/',
                 'variables': {'x': x, 'y': t},
                 'shuffle': True,
                 'batch_size': args.batch_size,
                 'no_image_normalization': True},
                {'name': 'mnist_validation',
                 'uri': 'MNIST_VALIDATION',
                 'cache_dir': args.cache_dir + '/mnist_test.cache/',
                 'variables': {'x': x, 'y': t},
                 'shuffle': False,
                 'batch_size': args.batch_size,
                 'no_image_normalization': True
                 }],
            'monitors': [
                {'name': 'training_loss',
                 'network': 'validation',
                 'dataset': 'mnist_training'},
                {'name': 'validation_loss',
                 'network': 'validation',
                 'dataset': 'mnist_validation'}],
            }


    '''
    ctx = FileHandlerContext()
    if isinstance(filename, str):
        _, ext = os.path.splitext(filename)
    else:
        ext = extension
    include_params = False if ext == '.nnp' else include_params
    ctx.proto = create_proto(contents, include_params, variable_batch_size)
    ctx.parameters = parameters
    file_savers = get_default_file_savers()
    save_files(ctx, file_savers, filename, ext)
    logger.info("Model file is saved as ({}): {}".format(ext, filename))
def test_imsave_and_imread(tmpdir, backend, grayscale, size, channel_first,
                           as_uint16, num_channels, auto_scale, img):
    # preprocess
    _change_backend(backend)

    tmpdir.ensure(dir=True)
    tmppath = tmpdir.join("tmp.png")
    img_file = tmppath.strpath

    ref_size_axis = 0
    if channel_first and len(img.shape) == 3:
        img = img.transpose((2, 0, 1))
        ref_size_axis = 1

    # do imsave
    def save_image_function():
        image_utils.imsave(img_file,
                           img,
                           channel_first=channel_first,
                           as_uint16=as_uint16,
                           auto_scale=auto_scale)

    if check_imsave_condition(backend, img, as_uint16, auto_scale):
        save_image_function()
    else:
        with pytest.raises(ValueError):
            save_image_function()

        return True

    # do imread
    def read_image_function():
        return image_utils.imread(img_file,
                                  grayscale=grayscale,
                                  size=size,
                                  channel_first=channel_first,
                                  as_uint16=as_uint16,
                                  num_channels=num_channels)

    if not grayscale and num_channels in [0, 1]:
        with pytest.raises(ValueError):
            _ = read_image_function()

        return True
    else:
        read_image = read_image_function()

    logger.info(read_image.shape)
    # ---check size---
    ref_size = img.shape[ref_size_axis:ref_size_axis +
                         2] if size is None else size
    size_axis = 1 if len(read_image.shape) == 3 and channel_first else 0
    assert read_image.shape[size_axis:size_axis + 2] == ref_size

    # ---check channels---
    if num_channels == 0 or (num_channels == -1 and
                             (len(img.shape) == 2 or grayscale)):
        assert len(read_image.shape) == 2
    else:
        channel_axis = 0 if channel_first else -1
        ref_channels = num_channels if num_channels > 0 else img.shape[
            channel_axis]
        assert read_image.shape[channel_axis] == ref_channels

    # ---check dtype---
    if as_uint16 or img.dtype == np.uint16:
        assert read_image.dtype == np.uint16
    else:
        assert read_image.dtype == np.uint8

    # ---check close between before imsave and after imread---
    if size is None and not grayscale and img.shape == read_image.shape:
        scaler = get_scale_factor(img, auto_scale, as_uint16)
        dtype = img.dtype if img.dtype in [np.uint8, np.uint16] else np.float32

        assert_allclose((img.astype(dtype) * scaler).astype(read_image.dtype),
                        read_image)
Example #17
def set_global_seed(seed: int) -> None:
    np.random.seed(seed=seed)
    py_random.seed(seed)
    nn.seed(seed)
    logger.info("Set seed to {}".format(seed))
Example #18
def main():
    """
    Main script.
    Steps:
    * Setup calculation environment
    * Initialize data iterator.
    * Create Networks
    * Create Solver.
    * Training Loop.
    *   Training
    *   Test
    * Save
    """

    # Set args
    args = get_args(monitor_path='tmp.monitor.vae',
                    max_iter=60000,
                    model_save_path=None,
                    learning_rate=3e-4,
                    batch_size=100,
                    weight_decay=0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Initialize data provider
    di_l = data_iterator_mnist(args.batch_size, True)
    di_t = data_iterator_mnist(args.batch_size, False)

    # Network
    shape_x = (1, 28, 28)
    shape_z = (50, )
    x = nn.Variable((args.batch_size, ) + shape_x)
    loss_l = vae(x, shape_z, test=False)
    loss_t = vae(x, shape_z, test=True)

    # Create solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors for training and validation
    monitor = M.Monitor(args.model_save_path)
    monitor_training_loss = M.MonitorSeries("Training loss",
                                            monitor,
                                            interval=600)
    monitor_test_loss = M.MonitorSeries("Test loss", monitor, interval=600)
    monitor_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=600)

    # Training Loop.
    for i in range(args.max_iter):

        # Initialize gradients
        solver.zero_grad()

        # Forward, backward and update
        x.d, _ = di_l.next()
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Forward for test
        x.d, _ = di_t.next()
        loss_t.forward(clear_no_need_grad=True)

        # Monitor for logging
        monitor_training_loss.add(i, loss_l.d.copy())
        monitor_test_loss.add(i, loss_t.d.copy())
        monitor_time.add(i)

    # Save the model
    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_%06d.h5' % args.max_iter))
Example #19
def load(filenames,
         prepare_data_iterator=True,
         batch_size=None,
         exclude_parameter=False,
         parameter_only=False,
         extension=".nntxt"):
    '''load
    Load network information from files.

    Args:
        filenames (list): File-like object or list of filenames.
        extension (str): If ``filenames`` is a file-like object, ``extension`` must be one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
    Returns:
        dict: Network information.
    '''
    class Info:
        pass

    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()

    # optimizer checkpoint
    opti_proto = nnabla_pb2.NNablaProtoBuf()
    OPTI_BUF_EXT = ['.optimizer']
    opti_h5_files = {}
    tmpdir = tempfile.mkdtemp()

    if isinstance(filenames, list) or isinstance(filenames, tuple):
        pass
    elif isinstance(filenames, str) or hasattr(filenames, 'read'):
        filenames = [filenames]

    for filename in filenames:
        if isinstance(filename, str):
            _, ext = os.path.splitext(filename)
        else:
            ext = extension

        # TODO: There are some known problems here.
        #   - Even when a protobuf file includes a network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with get_file_handle_load(filename, ext) as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2 byte characters may be used for file name or folder name.'
                        )
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename, extension=ext)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename, extension=ext)
            else:
                logger.info('Skip loading parameter.')

        elif ext == '.nnp':
            with get_file_handle_load(filename, ext) as nnp:
                for name in nnp.namelist():
                    _, ext = os.path.splitext(name)
                    if name == 'nnp_version.txt':
                        pass  # TODO currently do nothing with version.
                    elif ext in ['.nntxt', '.prototxt']:
                        if not parameter_only:
                            with nnp.open(name, 'r') as f:
                                text_format.Merge(f.read(), proto)
                        if len(proto.parameter) > 0:
                            if not exclude_parameter:
                                with nnp.open(name, 'r') as f:
                                    nn.load_parameters(f, extension=ext)
                    elif ext in ['.protobuf', '.h5']:
                        if not exclude_parameter:
                            with nnp.open(name, 'r') as f:
                                nn.load_parameters(f, extension=ext)
                        else:
                            logger.info('Skip loading parameter.')
                    elif ext in OPTI_BUF_EXT:
                        buf_type = get_buf_type(name)
                        if buf_type == 'protobuf':
                            with nnp.open(name, 'r') as f:
                                with get_file_handle_load(
                                        f, '.protobuf') as opti_p:
                                    opti_proto.MergeFromString(opti_p.read())
                        elif buf_type == 'h5':
                            nnp.extract(name, tmpdir)
                            opti_h5_files[name] = os.path.join(tmpdir, name)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except:
                pass
        try:
            x = nn.Variable()
            y = nn.Variable()
            func = F.ReLU(default_context, inplace=True)
            func.setup([x], [y])
            func.forward([x], [y])
        except:
            logger.warn('Fallback to CPU context.')
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.local_rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    info.datasets = _datasets(
        proto, prepare_data_iterator if prepare_data_iterator is not None else
        info.training_config.max_epoch > 0)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)
    _load_optimizer_checkpoint(opti_proto, opti_h5_files, info)
    shutil.rmtree(tmpdir)

    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)

    info.executors = _executors(proto, info.networks)

    return info
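
A minimal usage sketch; 'net.nnp' is a hypothetical archive like the one produced by save() in Example #15. Parameters are loaded into the global parameter scope as a side effect.

info = load(['net.nnp'], prepare_data_iterator=False, batch_size=1)
# info.networks / info.executors now hold the reconstructed graphs and executor settings.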
Example #20
    def __init__(self,
                 data_source,
                 cache_dir=None,
                 cache_file_name_prefix='cache',
                 shuffle=False,
                 rng=None):
        self._tempdir_created = False
        logger.info('Using DataSourceWithFileCache')
        super(DataSourceWithFileCache, self).__init__(shuffle=shuffle, rng=rng)
        self._cache_file_name_prefix = cache_file_name_prefix
        self._cache_dir = cache_dir
        logger.info('Cache Directory is {}'.format(self._cache_dir))

        self._cache_size = int(
            nnabla_config.get('DATA_ITERATOR', 'data_source_file_cache_size'))
        logger.info('Cache size is {}'.format(self._cache_size))

        self._num_of_threads = int(
            nnabla_config.get('DATA_ITERATOR',
                              'data_source_file_cache_num_of_threads'))
        logger.info('Num of thread is {}'.format(self._num_of_threads))

        self._cache_file_format = nnabla_config.get('DATA_ITERATOR',
                                                    'cache_file_format')
        logger.info('Cache file format is {}'.format(self._cache_file_format))

        self._thread_lock = threading.Lock()

        self._size = data_source._size
        self._variables = data_source.variables
        self._data_source = data_source
        self._generation = -1
        self._cache_positions = []
        self._total_cached_size = 0
        self._cache_file_names = []
        self._cache_file_order = []
        self._cache_file_start_positions = []
        self._cache_file_data_orders = []

        self._current_cache_file_index = -1
        self._current_cache_data = None

        self.shuffle = shuffle
        self._order = list(range(self._size))

        # __enter__
        if self._cache_dir is None:
            self._tempdir_created = True
            if nnabla_config.get('DATA_ITERATOR',
                                 'data_source_file_cache_location') != '':
                self._cache_dir = tempfile.mkdtemp(dir=nnabla_config.get(
                    'DATA_ITERATOR', 'data_source_file_cache_location'))
            else:
                self._cache_dir = tempfile.mkdtemp()
            logger.info('Tempdir for cache {} created.'.format(
                self._cache_dir))
        self._closed = False
        atexit.register(self.close)

        self._create_cache()
        self._create_cache_file_position_table()
Example #21
    def __init__(self, train=True, shuffle=False, rng=None):
        super(Cifar10DataSource, self).__init__(shuffle=shuffle)

        # Lock
        lockfile = os.path.join(get_data_home(), "cifar10.lock")
        start_time = time.time()
        while True:  # busy-lock due to communication between process spawn by mpirun
            try:
                fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
                break
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
                if (time.time() - start_time) >= 60 * 30:  # wait for 30min
                    raise Exception(
                        "Timeout occurred. If a cifar10.lock file exists in $HOME/nnabla_data, it should be deleted.")

            time.sleep(5)

        self._train = train
        data_uri = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        logger.info('Getting labeled data from {}.'.format(data_uri))
        r = download(data_uri)  # file object returned
        with tarfile.open(fileobj=r, mode="r:gz") as fpin:
            # Training data
            if train:
                images = []
                labels = []
                for member in fpin.getmembers():
                    if "data_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes")
                    images.append(data[b"data"])
                    labels.append(data[b"labels"])
                self._size = 50000
                self._images = np.concatenate(
                    images).reshape(self._size, 3, 32, 32)
                self._labels = np.concatenate(labels).reshape(-1, 1)
            # Validation data
            else:
                for member in fpin.getmembers():
                    if "test_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes")
                    images = data[b"data"]
                    labels = data[b"labels"]
                self._size = 10000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
        r.close()
        logger.info('Getting labeled data done.')

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

        # Unlock
        os.close(fd)
        os.unlink(lockfile)
Example #22
def main():
    """
    Main script.

    Steps:
    * Get and set context.
    * Load Dataset
    * Initialize DataIterator.
    * Create Networks
    *   Net for Labeled Data
    *   Net for Unlabeled Data
    *   Net for Test Data
    * Create Solver.
    * Training Loop.
    *   Test
    *   Training
    *     by Labeled Data
    *       Calculate Supervised Loss
    *     by Unlabeled Data
    *       Calculate Virtual Adversarial Noise
    *       Calculate Unsupervised Loss
    """

    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    shape_x = (1, 28, 28)
    n_h = args.n_units
    n_y = args.n_class

    # Load MNIST Dataset
    from mnist_data import load_mnist, data_iterator_mnist
    images, labels = load_mnist(train=True)
    rng = np.random.RandomState(706)
    inds = rng.permutation(len(images))

    def feed_labeled(i):
        j = inds[i]
        return images[j], labels[j]

    def feed_unlabeled(i):
        j = inds[i]
        return images[j], labels[j]

    di_l = data_iterator_simple(feed_labeled,
                                args.n_labeled,
                                args.batchsize_l,
                                shuffle=True,
                                rng=rng,
                                with_file_cache=False)
    di_u = data_iterator_simple(feed_unlabeled,
                                args.n_train,
                                args.batchsize_u,
                                shuffle=True,
                                rng=rng,
                                with_file_cache=False)
    di_v = data_iterator_mnist(args.batchsize_v, train=False)

    # Create networks
    # feed-forward-net building function
    def forward(x, test=False):
        return mlp_net(x, n_h, n_y, test)

    # Net for learning labeled data
    xl = nn.Variable((args.batchsize_l, ) + shape_x, need_grad=False)
    yl = forward(xl, test=False)
    tl = nn.Variable((args.batchsize_l, 1), need_grad=False)
    loss_l = F.mean(F.softmax_cross_entropy(yl, tl))

    # Net for learning unlabeled data
    xu = nn.Variable((args.batchsize_u, ) + shape_x, need_grad=False)
    yu = forward(xu, test=False)
    y1 = yu.get_unlinked_variable()
    y1.need_grad = False

    noise = nn.Variable((args.batchsize_u, ) + shape_x, need_grad=True)
    r = noise / (F.sum(noise**2, [1, 2, 3], keepdims=True))**0.5
    r.persistent = True
    y2 = forward(xu + args.xi_for_vat * r, test=False)
    y3 = forward(xu + args.eps_for_vat * r, test=False)
    loss_k = F.mean(distance(y1, y2))
    loss_u = F.mean(distance(y1, y3))

    # Net for evaluating validation data
    xv = nn.Variable((args.batchsize_v, ) + shape_x, need_grad=False)
    hv = forward(xv, test=True)
    tv = nn.Variable((args.batchsize_v, 1), need_grad=False)
    err = F.mean(F.top_n_error(hv, tv, n=1))

    # Create solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor training and validation stats.
    import nnabla.monitor as M
    monitor = M.Monitor(args.model_save_path)
    monitor_verr = M.MonitorSeries("Test error", monitor, interval=240)
    monitor_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=240)

    # Training Loop.
    t0 = time.time()

    for i in range(args.max_iter):

        # Validation Test
        if i % args.val_interval == 0:
            valid_error = calc_validation_error(di_v, xv, tv, err,
                                                args.val_iter)
            monitor_verr.add(i, valid_error)

        #################################
        ## Training by Labeled Data #####
        #################################

        # forward, backward and update
        xl.d, tl.d = di_l.next()
        xl.d = xl.d / 255
        solver.zero_grad()
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        #################################
        ## Training by Unlabeled Data ###
        #################################

        # Calculate y without noise, only once.
        xu.d, _ = di_u.next()
        xu.d = xu.d / 255
        yu.forward(clear_buffer=True)

        ##### Calculate Adversarial Noise #####
        # Do power method iteration
        noise.d = np.random.normal(size=xu.shape).astype(np.float32)
        for k in range(args.n_iter_for_power_method):
            r.grad.zero()
            loss_k.forward(clear_no_need_grad=True)
            loss_k.backward(clear_buffer=True)
            noise.data.copy_from(r.grad)

        ##### Calculate loss for unlabeled data #####
        # forward, backward and update
        solver.zero_grad()
        loss_u.forward(clear_no_need_grad=True)
        loss_u.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        ##### Learning rate update #####
        if i % args.iter_per_epoch == 0:
            solver.set_learning_rate(solver.learning_rate() *
                                     args.learning_rate_decay)
        monitor_time.add(i)

    # Evaluate the final model by the error rate with validation dataset
    valid_error = calc_validation_error(di_v, xv, tv, err, args.val_iter)
    monitor_verr.add(i, valid_error)
    monitor_time.add(i)

    # Save the model.
    parameter_file = os.path.join(args.model_save_path,
                                  'params_%06d.h5' % args.max_iter)
    nn.save_parameters(parameter_file)
Example #23
def _protobuf_file_saver(ctx, filename, ext):
    logger.info("Saving {} as protobuf".format(filename))
    with get_file_handle_save(filename, ext) as f:
        f.write(ctx.proto.SerializeToString())
Example #24
def _nntxt_file_saver(ctx, filename, ext):
    logger.info("Saving {} as prototxt".format(filename))
    with get_file_handle_save(filename, ext) as f:
        text_format.PrintMessage(ctx.proto, f)
Example #25
def multiprocess_save_cache(create_cache_args):
    def _process_row(row, args):
        def _get_value(value, is_vector=False):
            try:
                if is_vector:
                    value = [float(value)]
                else:
                    value = float(value)
                return value
            except ValueError:
                pass
            ext = (os.path.splitext(value)[1]).lower()
            with args._filereader.open(value) as f:
                value = load(ext)(f, normalize=args._normalize)
            return value

        values = collections.OrderedDict()
        if len(row) == len(args._columns):
            for column, column_value in enumerate(row):
                variable, index, label = args._columns[column]
                if index is None:
                    values[variable] = _get_value(column_value, is_vector=True)
                else:
                    if variable not in values:
                        values[variable] = []
                    values[variable].append(_get_value(column_value))
        return values.values()

    (position, cache_csv), cc_args = create_cache_args
    cc_args = SimpleNamespace(**cc_args)
    cache_data = []
    for row in cache_csv:
        cache_data.append(tuple(_process_row(row, cc_args)))

    if len(cache_data) > 0:
        start_position = position + 1 - len(cache_data)
        end_position = position
        cache_filename = os.path.join(
            cc_args._output_cache_dirname,
            '{}_{:08d}_{:08d}{}'.format(cc_args._cache_file_name_prefix,
                                        start_position, end_position,
                                        cc_args._cache_file_format))

        logger.info('Creating cache file {}'.format(cache_filename))

        data = collections.OrderedDict([(n, []) for n in cc_args._variables])
        for _, cd in enumerate(cache_data):
            for i, n in enumerate(cc_args._variables):
                if isinstance(cd[i], numpy.ndarray):
                    d = cd[i]
                else:
                    d = numpy.array(cd[i]).astype(numpy.float32)
                data[n].append(d)
        try:
            if cc_args._cache_file_format == ".h5":
                h5 = h5py.File(cache_filename, 'w')
                for k, v in data.items():
                    h5.create_dataset(k, data=v)
                h5.close()
            else:
                retry_count = 1
                is_create_cache_incomplete = True
                while is_create_cache_incomplete:
                    try:
                        with open(cache_filename, 'wb') as f:
                            for v in data.values():
                                numpy.save(f, v)
                        is_create_cache_incomplete = False
                    except OSError:
                        retry_count += 1
                        if retry_count > 10:
                            raise
                        logger.info(
                            'Creating cache retry {}/10'.format(retry_count))
        except:
            logger.critical(
                'An error occurred while creating cache file from dataset.')
            for k, v in data.items():
                size = v[0].shape
                for d in v:
                    if size != d.shape:
                        logger.critical(
                            'The sizes of data "{}" are not the same. ({} != {})'
                            .format(k, size, d.shape))
            raise

        cc_args._cache_file_name_and_data_nums_list.append(
            (cache_filename, len(cache_data)))
        progress(
            'Create cache',
            len(cc_args._cache_file_name_and_data_nums_list) /
            cc_args._cache_file_count)
Example #26
    def create(self,
               output_cache_dirname,
               normalize=True,
               cache_file_name_prefix='cache'):

        cache_file_format = nnabla_config.get('DATA_ITERATOR',
                                              'cache_file_format')
        logger.info('Cache file format is {}'.format(cache_file_format))

        progress(None)

        cache_file_name_and_data_nums_list = multiprocessing.Manager().list()

        csv_position_and_data = []
        csv_row = []
        for _position in range(self._size):
            csv_row.append(self._csv_data[self._order[_position]])
            if len(csv_row) == self._cache_size:
                csv_position_and_data.append((_position, csv_row))
                csv_row = []
        if len(csv_row):
            csv_position_and_data.append((self._size - 1, csv_row))

        self_args = {
            '_cache_file_name_prefix': cache_file_name_prefix,
            '_cache_file_format': cache_file_format,
            '_cache_file_name_and_data_nums_list':
            cache_file_name_and_data_nums_list,
            '_output_cache_dirname': output_cache_dirname,
            '_variables': self._variables,
            '_filereader': self._filereader,
            '_normalize': normalize,
            '_columns': self._columns,
            '_cache_file_count': len(csv_position_and_data)
        }

        # Notice:
        #   We have to place a gc.collect() here, since we found that
        #   Python may run a garbage collection in a child process that
        #   releases objects created by its parent process, which in turn
        #   may touch CUDA APIs that have not been initialized in the
        #   child process. Placing a gc.collect() here avoids such cases.
        gc.collect()

        progress('Create cache', 0)
        with closing(multiprocessing.Pool(self._process_num)) as pool:
            pool.map(multiprocess_save_cache,
                     ((i, self_args) for i in csv_position_and_data))
        progress('Create cache', 1.0)

        logger.info('The total of cache files is {}'.format(
            len(cache_file_name_and_data_nums_list)))

        # Create Index
        index_filename = os.path.join(output_cache_dirname, "cache_index.csv")
        cache_index_rows = sorted(cache_file_name_and_data_nums_list,
                                  key=lambda x: x[0])
        with open(index_filename, 'w') as f:
            writer = csv.writer(f, lineterminator='\n')
            for file_name, data_nums in cache_index_rows:
                writer.writerow((os.path.basename(file_name), data_nums))

        # Create Info
        if cache_file_format == ".npy":
            info_filename = os.path.join(output_cache_dirname,
                                         "cache_info.csv")
            with open(info_filename, 'w') as f:
                writer = csv.writer(f, lineterminator='\n')
                for variable in self._variables:
                    writer.writerow((variable, ))

        # Create original.csv
        if self._original_source_uri is not None:
            shutil.copy(self._original_source_uri,
                        os.path.join(output_cache_dirname, "original.csv"))

        # Create order.csv
        if self._order is not None and \
                self._original_order is not None:
            with open(os.path.join(output_cache_dirname, "order.csv"),
                      'w') as o:
                writer = csv.writer(o, lineterminator='\n')
                for orders in zip(self._original_order, self._order):
                    writer.writerow(list(orders))
Example #27
def main(**kwargs):
    # set training args
    args = AttrDict(kwargs)

    assert os.path.exists(
        args.config
    ), f"{args.config} is not found. Please make sure the config file exists."
    conf = read_yaml(args.config)

    comm = init_nnabla(ext_name="cudnn",
                       device_id=args.device_id,
                       type_config="float",
                       random_pseed=True)
    if args.sampling_interval is None:
        args.sampling_interval = 1

    use_timesteps = list(
        range(0, conf.num_diffusion_timesteps, args.sampling_interval))
    if use_timesteps[-1] != conf.num_diffusion_timesteps - 1:
        # The last step should be included always.
        use_timesteps.append(conf.num_diffusion_timesteps - 1)

    # setup model variance type
    model_var_type = ModelVarType.FIXED_SMALL
    if "model_var_type" in conf:
        model_var_type = ModelVarType.get_vartype_from_key(conf.model_var_type)

    model = Model(beta_strategy=conf.beta_strategy,
                  use_timesteps=use_timesteps,
                  model_var_type=model_var_type,
                  num_diffusion_timesteps=conf.num_diffusion_timesteps,
                  attention_num_heads=conf.num_attention_heads,
                  attention_resolutions=conf.attention_resolutions,
                  scale_shift_norm=conf.ssn,
                  base_channels=conf.base_channels,
                  channel_mult=conf.channel_mult,
                  num_res_blocks=conf.num_res_blocks)

    # load parameters
    assert os.path.exists(
        args.h5
    ), f"{args.h5} is not found. Please make sure the h5 file exists."
    nn.parameter.load_parameters(args.h5)

    # Generate
    # sampling
    B = args.batch_size
    num_samples_per_iter = B * comm.n_procs
    num_iter = (args.samples + num_samples_per_iter -
                1) // num_samples_per_iter

    local_saved_cnt = 0
    for i in range(num_iter):
        logger.info(f"Generate samples {i + 1} / {num_iter}.")
        sample_out, _, x_starts = model.sample(shape=(B, ) +
                                               conf.image_shape[1:],
                                               dump_interval=1,
                                               use_ema=args.ema,
                                               progress=comm.rank == 0,
                                               use_ddim=args.ddim)

        # scale back to [0, 255]
        sample_out = (sample_out + 1) * 127.5

        if args.tiled:
            save_path = os.path.join(args.output_dir,
                                     f"gen_{local_saved_cnt}_{comm.rank}.png")
            save_tiled_image(sample_out.astype(np.uint8), save_path)
            local_saved_cnt += 1
        else:
            for b in range(B):
                save_path = os.path.join(
                    args.output_dir, f"gen_{local_saved_cnt}_{comm.rank}.png")
                imsave(save_path,
                       sample_out[b].astype(np.uint8),
                       channel_first=True)
                local_saved_cnt += 1

        # create video for x_starts
        if args.save_xstart:
            clips = []
            for j in range(len(x_starts)):
                xstart = x_starts[j][1]
                assert isinstance(xstart, np.ndarray)
                im = get_tiled_image(np.clip((xstart + 1) * 127.5, 0, 255),
                                     channel_last=False).astype(np.uint8)
                clips.append(im)

            clip = mp.ImageSequenceClip(clips, fps=5)
            clip.write_videofile(
                os.path.join(
                    args.output_dir,
                    f"pred_x0_along_time_{local_saved_cnt}_{comm.rank}.mp4"))
Beispiel #28
0
    def _save_cache_to_file(self):
        '''
        Store cache data into a file.

        Data is stored in HDF5 format in the configured cache directory.
        The cache file name format is "cache_START_END.h5".
        '''
        if self._cache_dir is None:
            raise DataSourceWithFileCacheError(
                'Use this class in a "with" statement if you do not specify a cache dir.'
            )
        cache_data = OrderedDict()

        def get_data(args):
            pos = args[0]
            q = args[1]
            retry = 1
            while True:
                if retry > 10:
                    logger.log(
                        99, '_get_data() retry count exceeded; giving up.')
                    raise RuntimeError('_get_data() failed after 10 retries.')
                d = self._data_source._get_data(pos)
                if d is not None:
                    break
                logger.log(
                    99,
                    '_get_data() failed. Retry count {}/10.'.format(retry))
                retry += 1

            q.put((pos, d))

        q = Queue()
        with closing(ThreadPool(processes=self._num_of_threads)) as pool:
            pool.map(get_data, [(pos, q) for pos in self._cache_positions])

        while len(cache_data) < len(self._cache_positions):
            index, data = q.get()
            cache_data[index] = data
        start_position = self.position - len(cache_data) + 1
        end_position = self.position
        cache_filename = os.path.join(
            self._cache_dir,
            '{}_{:08d}_{:08d}{}'.format(self._cache_file_name_prefix,
                                        start_position, end_position,
                                        self._cache_file_format))

        data = OrderedDict([(n, []) for n in self._data_source.variables])
        for pos in sorted(cache_data):
            cd = cache_data[pos]
            for i, n in enumerate(self._data_source.variables):
                if isinstance(cd[i], numpy.ndarray):
                    d = cd[i]
                else:
                    d = numpy.array(cd[i]).astype(numpy.float32)
                data[n].append(d)

        logger.info('Creating cache file {}'.format(cache_filename))
        if self._cache_file_format == ".h5":
            h5 = h5py.File(cache_filename, 'w')
            for k, v in data.items():
                h5.create_dataset(k, data=v)
            h5.close()
        else:
            retry_count = 1
            is_create_cache_incomplete = True
            while is_create_cache_incomplete:
                try:
                    with open(cache_filename, 'wb') as f:
                        for v in data.values():
                            numpy.save(f, v)
                    is_create_cache_incomplete = False
                except OSError:
                    retry_count += 1
                    if retry_count > 10:
                        raise
                    logger.info(
                        'Creating cache retry {}/10'.format(retry_count))

        self._cache_file_names.append(cache_filename)
        self._cache_file_order.append(len(self._cache_file_order))
        self._cache_file_data_orders.append(list(range(len(cache_data))))
        self._cache_positions = []
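Once written, an ".h5" cache file produced by _save_cache_to_file() can be inspected directly with h5py. A minimal sketch; the file name below is hypothetical but follows the "cache_START_END.h5" convention described in the docstring:

import h5py

# Hypothetical cache file created by the method above.
with h5py.File("cache_00000000_00000099.h5", "r") as h5:
    for name, ds in h5.items():
        print(name, ds.shape, ds.dtype)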
Beispiel #29
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--context',
        '-c',
        type=str,
        default='cudnn',
        help="Extension module. 'cudnn' is highly recommended.")
    parser.add_argument("--device-id",
                        "-d",
                        type=str,
                        default='-1',
                        help='A list of device ids to use, e.g., `0,1,2,3`.\
                        This is only valid if you specify `-c cudnn`.')
    parser.add_argument("--type-config",
                        "-t",
                        type=str,
                        default='float',
                        help='Type configuration.')
    parser.add_argument('--search',
                        '-s',
                        action='store_true',
                        help='Whether it is searching for the architecture.')
    parser.add_argument('--algorithm',
                        '-a',
                        type=str,
                        default='DartsSeacher',
                        choices=runner.__all__,
                        help='Which algorithm to use.')
    parser.add_argument('--config-file',
                        '-f',
                        type=str,
                        help='The configuration file for the experiment.')
    parser.add_argument('--output-path',
                        '-o',
                        type=str,
                        help='Path to save the monitoring log files.')

    options = parser.parse_args()

    config = json.load(open(
        options.config_file)) if options.config_file else dict()
    hparams = config['hparams']

    hparams.update({k: v for k, v in vars(options).items() if v is not None})

    # setup cuda visible
    if hparams['device_id'] != '-1':
        os.environ["CUDA_VISIBLE_DEVICES"] = hparams['device_id']

    # setup context for nnabla
    ctx = get_extension_context(hparams['context'],
                                device_id='0',
                                type_config=hparams['type_config'])

    # setup for distributed training
    hparams['comm'] = CommunicatorWrapper(ctx)
    hparams['event'] = StreamEventHandler(int(hparams['comm'].ctx.device_id))

    nn.set_default_context(hparams['comm'].ctx)

    if hparams['comm'].n_procs > 1 and hparams['comm'].rank == 0:
        n_procs = hparams['comm'].n_procs
        logger.info(f'Distributed training with {n_procs} processes.')

    # build the model
    name, attributes = list(config['network'].items())[0]
    algorithm = contrib.__dict__[name]
    model = algorithm.SearchNet(**attributes) if hparams['search'] else \
        algorithm.TrainNet(**attributes)

    # Get all arguments for the runner
    conf = Configuration(config)

    runner.__dict__[hparams['algorithm']](model,
                                          optimizer=conf.optimizer,
                                          dataloader=conf.dataloader,
                                          args=conf.hparams).run()
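main() above only assumes that the JSON config contains an 'hparams' section and a single-entry 'network' section mapping a module name to its constructor arguments. A sketch of such a config, expressed as a Python dict; all nested names and values are hypothetical:

# Hypothetical config structure; only the top-level keys are implied by main().
example_config = {
    "hparams": {"output_path": "./log", "batch_size": 64},
    "network": {
        "darts": {"in_channels": 3, "num_classes": 10}
    },
}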
Beispiel #30
0
    def __init__(self, train=True, shuffle=False, rng=None):
        super(Cifar100DataSource, self).__init__(shuffle=shuffle)

        # Lock
        lockfile = os.path.join(get_data_home(), "cifar100.lock")
        start_time = time.time()
        while True:  # busy-lock due to communication between processes spawned by mpirun
            try:
                fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
                break
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
                if (time.time() - start_time) >= 60 * 30:  # wait for 30min
                    raise Exception(
                        "Timeout occurred. If cifar100.lock remains in $HOME/nnabla_data, it should be deleted."
                    )

            time.sleep(5)

        self._train = train
        data_uri = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
        logger.info('Getting labeled data from {}.'.format(data_uri))
        r = download(data_uri)  # file object returned
        with tarfile.open(fileobj=r, mode="r:gz") as fpin:
            # Training data
            if train:
                images = []
                labels = []
                for member in fpin.getmembers():
                    if "train" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes")
                    images = data[b"data"]
                    labels = data[b"fine_labels"]
                self._size = 50000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
            # Validation data
            else:
                for member in fpin.getmembers():
                    if "test" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes")
                    images = data[b"data"]
                    labels = data[b"fine_labels"]
                self._size = 10000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
        r.close()
        logger.info('Getting labeled data from {} done.'.format(data_uri))

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

        # Unlock
        os.close(fd)
        os.unlink(lockfile)
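The lock/unlock dance above (os.O_CREAT | os.O_EXCL plus a timeout and polling loop) is a generic pattern for coordinating processes spawned by mpirun. A minimal sketch of the same idea as a reusable context manager; this is not part of nnabla, and the timeout values are arbitrary:

import errno
import os
import time
from contextlib import contextmanager


@contextmanager
def exclusive_lock(lockfile, timeout=30 * 60, poll=5):
    # Acquire the lock by creating the file exclusively; spin until it succeeds.
    start = time.time()
    while True:
        try:
            fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
            break
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
            if time.time() - start >= timeout:
                raise RuntimeError("Could not acquire {}".format(lockfile))
            time.sleep(poll)
    try:
        yield
    finally:
        # Release the lock.
        os.close(fd)
        os.unlink(lockfile)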
Beispiel #31
0
def main():
    """
    Main script.

    Steps:
    * Get and set context.
    * Load Dataset
    * Initialize DataIterator.
    * Create Networks
    *   Net for Labeled Data
    *   Net for Unlabeled Data
    *   Net for Test Data
    * Create Solver.
    * Training Loop.
    *   Test
    *   Training
    *     by Labeled Data
    *       Calculate Cross Entropy Loss 
    *     by Unlabeled Data
    *       Estimate Adversarial Direction
    *       Calculate LDS Loss
    """

    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    shape_x = (1, 28, 28)
    n_h = args.n_units
    n_y = args.n_class

    # Load MNIST Dataset
    from mnist_data import MnistDataSource
    with MnistDataSource(train=True) as d:
        x_t = d.images
        t_t = d.labels
    with MnistDataSource(train=False) as d:
        x_v = d.images
        t_v = d.labels
    x_t = np.array(x_t / 256.0).astype(np.float32)
    x_t, t_t = x_t[:args.n_train], t_t[:args.n_train]
    x_v, t_v = x_v[:args.n_valid], t_v[:args.n_valid]

    # Create Semi-supervised Datasets
    x_l, t_l, x_u, _ = split_dataset(x_t, t_t, args.n_labeled, args.n_class)
    x_u = np.r_[x_l, x_u]
    x_v = np.array(x_v / 256.0).astype(np.float32)

    # Create DataIterators for datasets of labeled, unlabeled and validation
    di_l = DataIterator(args.batchsize_l, [x_l, t_l])
    di_u = DataIterator(args.batchsize_u, [x_u])
    di_v = DataIterator(args.batchsize_v, [x_v, t_v])

    # Create networks
    # feed-forward-net building function
    def forward(x, test=False):
        return mlp_net(x, n_h, n_y, test)

    # Net for learning labeled data
    xl = nn.Variable((args.batchsize_l,) + shape_x, need_grad=False)
    hl = forward(xl, test=False)
    tl = nn.Variable((args.batchsize_l, 1), need_grad=False)
    loss_l = F.mean(F.softmax_cross_entropy(hl, tl))

    # Net for learning unlabeled data
    xu = nn.Variable((args.batchsize_u,) + shape_x, need_grad=False)
    r = nn.Variable((args.batchsize_u,) + shape_x, need_grad=True)
    eps = nn.Variable((args.batchsize_u,) + shape_x, need_grad=False)
    loss_u, yu = vat(xu, r, eps, forward, distance)

    # Net for evaluating validation data
    xv = nn.Variable((args.batchsize_v,) + shape_x, need_grad=False)
    hv = forward(xv, test=True)
    tv = nn.Variable((args.batchsize_v, 1), need_grad=False)

    # Create solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor training and validation stats.
    import nnabla.monitor as M
    monitor = M.Monitor(args.model_save_path)
    monitor_verr = M.MonitorSeries("Test error", monitor, interval=240)
    monitor_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=240)

    # Training Loop.
    t0 = time.time()

    for i in range(args.max_iter):

        # Validation Test
        if i % args.val_interval == 0:
            n_error = calc_validation_error(
                di_v, xv, tv, hv, args.val_iter)
            monitor_verr.add(i, n_error)

        #################################
        ## Training by Labeled Data #####
        #################################

        # input minibatch of labeled data into variables
        xl.d, tl.d = di_l.next()

        # initialize gradients
        solver.zero_grad()

        # forward, backward and update
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        #################################
        ## Training by Unlabeled Data ###
        #################################

        # input minibatch of unlabeled data into variables
        xu.d, = di_u.next()

        ##### Calculate Adversarial Noise #####

        # Sample random noise
        n = np.random.normal(size=xu.shape).astype(np.float32)

        # Normalize the noise vector and feed it into the variable
        r.d = get_direction(n)

        # Set xi, the power-method scaling parameter.
        eps.data.fill(args.xi_for_vat)

        # Calculate y without noise, only once.
        yu.forward(clear_buffer=True)

        # Do power method iteration
        for k in range(args.n_iter_for_power_method):
            # Initialize gradient to receive value
            r.grad.zero()

            # forward, backward, without update
            loss_u.forward(clear_no_need_grad=True)
            loss_u.backward(clear_buffer=True)

            # Normalize the gradient vector and feed it into the variable
            r.d = get_direction(r.g)

        ##### Calculate loss for unlabeled data #####

        # Clear remained gradients
        solver.zero_grad()

        # Set epsilon, the adversarial noise scaling parameter.
        eps.data.fill(args.eps_for_vat)

        # forward, backward and update
        loss_u.forward(clear_no_need_grad=True)
        loss_u.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        ##### Learning rate update #####
        if i % args.iter_per_epoch == 0:
            solver.set_learning_rate(
                solver.learning_rate() * args.learning_rate_decay)
        monitor_time.add(i)

    # Evaluate the final model by the error rate with validation dataset
    valid_error = calc_validation_error(di_v, xv, tv, hv, args.val_iter)
    monitor_verr.add(i, valid_error)
    monitor_time.add(i)

    # Save the model.
    nnp_file = os.path.join(
        args.model_save_path, 'vat_%06d.nnp' % args.max_iter)
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': args.batchsize_v,
             'outputs': {'y': hv},
             'names': {'x': xv}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x'],
             'output': ['y']}]}
    save.save(nnp_file, runtime_contents)

    from cpp_forward_check import check_cpp_forward
    check_cpp_forward(args.model_save_path, [xv.d], [xv], hv, nnp_file)
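get_direction() is not defined in this snippet; judging from how it is used above (turning random noise or a gradient into a unit perturbation per sample), a plausible implementation could look like the following. This is an assumption, not the original helper:

import numpy as np


def get_direction(d, eps=1e-8):
    # Hypothetical helper: L2-normalize each sample's perturbation,
    # keeping the original (batch, ...) shape.
    flat = d.reshape(d.shape[0], -1)
    norm = np.sqrt((flat ** 2).sum(axis=1, keepdims=True)) + eps
    return (flat / norm).reshape(d.shape)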
Beispiel #32
0
def main():
    args = get_args()
    # Get context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    image = io.imread(args.test_image)
    if image.ndim == 2:
        image = color.gray2rgb(image)
    elif image.shape[-1] == 4:
        image = image[..., :3]

    if args.context == 'cudnn':
        if not os.path.isfile(args.cnn_face_detction_model):
            # The block of code below downloads the CNN-based face detection
            # model file provided by dlib and saves it in the directory where
            # this script is executed.
            print("Downloading the face detection CNN. Please wait...")
            url = "http://dlib.net/files/mmod_human_face_detector.dat.bz2"
            from nnabla.utils.data_source_loader import download
            download(url, url.split('/')[-1], False)
            # get the decompressed data.
            data = bz2.BZ2File(url.split('/')[-1]).read()
            # write to dat file.
            open(url.split('/')[-1][:-4], 'wb').write(data)
        face_detector = dlib.cnn_face_detection_model_v1(
            args.cnn_face_detction_model)
        detected_faces = face_detector(
            cv2.cvtColor(image[..., ::-1].copy(), cv2.COLOR_BGR2GRAY))
        detected_faces = [[
            d.rect.left(),
            d.rect.top(),
            d.rect.right(),
            d.rect.bottom()
        ] for d in detected_faces]
    else:
        face_detector = dlib.get_frontal_face_detector()
        detected_faces = face_detector(
            cv2.cvtColor(image[..., ::-1].copy(), cv2.COLOR_BGR2GRAY))
        detected_faces = [[d.left(), d.top(),
                           d.right(), d.bottom()] for d in detected_faces]

    if len(detected_faces) == 0:
        print("Warning: No faces were detected.")
        return None

    # Load FAN weights
    with nn.parameter_scope("FAN"):
        print("Loading FAN weights...")
        nn.load_parameters(args.model)

    # Load ResNetDepth weights
    if args.landmarks_type_3D:
        with nn.parameter_scope("ResNetDepth"):
            print("Loading ResNetDepth weights...")
            nn.load_parameters(args.resnet_depth_model)

    landmarks = []
    for i, d in enumerate(detected_faces):
        center = [d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0]
        center[1] = center[1] - (d[3] - d[1]) * 0.12
        scale = (d[2] - d[0] + d[3] - d[1]) / args.reference_scale
        inp = crop(image, center, scale)
        inp = nn.Variable.from_numpy_array(inp.transpose((2, 0, 1)))
        inp = F.reshape(F.mul_scalar(inp, 1 / 255.0), (1, ) + inp.shape)
        with nn.parameter_scope("FAN"):
            out = fan(inp, args.network_size)[-1]
        pts, pts_img = get_preds_fromhm(out, center, scale)
        pts, pts_img = F.reshape(pts, (68, 2)) * \
            4, F.reshape(pts_img, (68, 2))

        if args.landmarks_type_3D:
            heatmaps = np.zeros((68, 256, 256), dtype=np.float32)
            for i in range(68):
                if pts.d[i, 0] > 0:
                    heatmaps[i] = draw_gaussian(heatmaps[i], pts.d[i], 2)
            heatmaps = nn.Variable.from_numpy_array(heatmaps)
            heatmaps = F.reshape(heatmaps, (1, ) + heatmaps.shape)
            with nn.parameter_scope("ResNetDepth"):
                depth_pred = F.reshape(
                    resnet_depth(F.concatenate(inp, heatmaps, axis=1)),
                    (68, 1))
            pts_img = F.concatenate(pts_img,
                                    depth_pred * (1.0 / (256.0 /
                                                         (200.0 * scale))),
                                    axis=1)

        landmarks.append(pts_img.d)
    visualize(landmarks, image, args.output)
Beispiel #33
0
def save_checkpoint(path, current_iter, solvers):
    """Saves the checkpoint file which contains the params and its state info.

        Args:
            path: Path to the directory the checkpoint file is stored in.
            current_iter: Current iteration of the training loop.
            solvers: A dictionary mapping identifiers to solvers, e.g.:
                     solvers = {"identifier_for_solver_0": solver_0,
                                "identifier_for_solver_1": solver_1, ...}
                     The keys are used only for the state filenames, so they can be anything.
                     You can also pass a single solver object if only one solver exists.
                     In that case, "" is used as the identifier.

        Examples:
            # Create computation graph with parameters.
            pred = construct_pred_net(input_Variable, ...)

            # Create solver and set parameters.
            solver = S.Adam(learning_rate)
            solver.set_parameters(nn.get_parameters())

            # If you have another_solver like,
            # another_solver = S.Sgd(learning_rate)
            # another_solver.set_parameters(nn.get_parameters())

            # Training loop.
            for i in range(start_point, max_iter):
                solver.zero_grad()
                pred.forward()
                pred.backward()
                solver.update()
                save_checkpoint(path, i, solver)

                # If you have another_solver,
                # save_checkpoint(path, i,
                      {"solver": solver, "another_solver": another})

        Notes:
            It generates a checkpoint file (.json) that looks like:
            checkpoint_1000 = {
                    "":{
                        "states_path": <path to the states file>
                        "params_names":["conv1/conv/W", ...],
                        "num_update":1000
                       },
                    "current_iter": 1000
                    }

            If you have multiple solvers:
            checkpoint_1000 = {
                    "generator":{
                        "states_path": <path to the states file>,
                        "params_names":["deconv1/conv/W", ...],
                        "num_update":1000
                       },
                    "discriminator":{
                        "states_path": <path to the states file>,
                        "params_names":["conv1/conv/W", ...],
                        "num_update":1000
                       },
                    "current_iter": 1000
                    }

    """

    if isinstance(solvers, nn.solver.Solver):
        solvers = {"": solvers}

    checkpoint_info = dict()

    for solvername, solver_obj in solvers.items():
        prefix = "{}_".format(solvername.replace(
            "/", "_")) if solvername else ""
        partial_info = dict()

        # save solver states.
        states_fname = prefix + 'states_{}.h5'.format(current_iter)
        states_path = os.path.join(path, states_fname)
        solver_obj.save_states(states_path)
        partial_info["states_path"] = states_path

        # save registered parameters' name. (just in case)
        params_names = [k for k in solver_obj.get_parameters().keys()]
        partial_info["params_names"] = params_names

        # save the number of solver update.
        num_update = getattr(solver_obj.get_states()[params_names[0]], "t")
        partial_info["num_update"] = num_update

        checkpoint_info[solvername] = partial_info

    # save parameters.
    params_fname = 'params_{}.h5'.format(current_iter)
    params_path = os.path.join(path, params_fname)
    nn.parameter.save_parameters(params_path)
    checkpoint_info["params_path"] = params_path
    checkpoint_info["current_iter"] = current_iter

    checkpoint_fname = 'checkpoint_{}.json'.format(current_iter)
    filename = os.path.join(path, checkpoint_fname)

    with open(filename, 'w') as f:
        json.dump(checkpoint_info, f)

    logger.info("Checkpoint save (.json): {}".format(filename))

    return
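The checkpoint JSON written above can be inspected with the standard json module, for example to find out which iteration to resume from. A minimal sketch; the file name is hypothetical:

import json

# Hypothetical checkpoint file produced by save_checkpoint(path, 1000, solver).
with open("checkpoint_1000.json") as f:
    info = json.load(f)
print(info["current_iter"], info["params_path"])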
Beispiel #34
0
    def create(self,
               output_cache_dirname,
               normalize=True,
               cache_file_name_prefix='cache'):

        self._normalize = normalize
        self._cache_file_name_prefix = cache_file_name_prefix
        self._cache_dir = output_cache_dirname

        self._cache_file_format = nnabla_config.get('DATA_ITERATOR',
                                                    'cache_file_format')
        logger.info('Cache file format is {}'.format(self._cache_file_format))

        progress(None)

        csv_position_and_data = []
        csv_row = []
        for _position in range(self._size):
            csv_row.append(self._csv_data[self._order[_position]])
            if len(csv_row) == self._cache_size:
                csv_position_and_data.append((_position, csv_row))
                csv_row = []
        if len(csv_row):
            csv_position_and_data.append((self._size - 1, csv_row))

        progress('Create cache', 0)
        with closing(ThreadPool(processes=self._num_of_threads)) as pool:
            cache_index_rows = pool.map(self._save_cache,
                                        csv_position_and_data)
        progress('Create cache', 1.0)

        # Create Index
        index_filename = os.path.join(output_cache_dirname, "cache_index.csv")
        with open(index_filename, 'w') as f:
            writer = csv.writer(f, lineterminator='\n')
            for row in cache_index_rows:
                if row:
                    # row: (file_path, data_nums)
                    writer.writerow((os.path.basename(row[0]), row[1]))

        # Create Info
        if self._cache_file_format == ".npy":
            info_filename = os.path.join(output_cache_dirname,
                                         "cache_info.csv")
            with open(info_filename, 'w') as f:
                writer = csv.writer(f, lineterminator='\n')
                for variable in self._variables:
                    writer.writerow((variable, ))

        # Create original.csv
        if self._original_source_uri is not None:
            shutil.copy(self._original_source_uri,
                        os.path.join(output_cache_dirname, "original.csv"))

        # Create order.csv
        if self._order is not None and \
                self._original_order is not None:
            with open(os.path.join(output_cache_dirname, "order.csv"),
                      'w') as o:
                writer = csv.writer(o, lineterminator='\n')
                for orders in zip(self._original_order, self._order):
                    writer.writerow(list(orders))
Beispiel #35
0
def load_snapshot(load_dir: str,
                  file_name: str = "pointnet_classification.h5") -> None:
    logger.info("Load network parameters")
    model_file_path = os.path.join(load_dir, file_name)
    nn.load_parameters(path=model_file_path)
Beispiel #36
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('input',
                        type=str,
                        nargs='+',
                        help='Source file or directory.')
    parser.add_argument('output', type=str, help='Destination directory.')
    parser.add_argument('-W',
                        '--width',
                        type=int,
                        default=320,
                        help='width of output image (default:320)')
    parser.add_argument('-H',
                        '--height',
                        type=int,
                        default=320,
                        help='height of output image (default:320)')
    parser.add_argument(
        '-m',
        '--mode',
        default='trimming',
        choices=['trimming', 'padding'],
        help='shaping mode (trimming or padding)  (default:trimming)')
    parser.add_argument(
        '-S',
        '--shuffle',
        choices=['True', 'False'],
                        help='Shuffle mode. If not specified, train:True, val:False.' +
                        ' Otherwise the specified value is used for both.')
    parser.add_argument('-N',
                        '--file-cache-size',
                        type=int,
                        default=100,
                        help='num of data in cache file (default:100)')
    parser.add_argument('-C',
                        '--cache-type',
                        default='npy',
                        choices=['h5', 'npy'],
                        help='cache format (h5 or npy) (default:npy)')
    parser.add_argument('--thinning',
                        type=int,
                        default=1,
                        help='Thinning rate')

    args = parser.parse_args()
    ############################################################################
    # Analyze tar
    # If it consists only of members matching the regular expression
    # 'n[0-9]{8}\.tar', it is treated as a training data archive.
    # If it consists only of members matching the regular expression
    # 'ILSVRC2012_val_[0-9]{8}\.JPEG', it is treated as a validation data archive.

    archives = {'train': None, 'val': None}
    for inputarg in args.input:
        print('Checking input file [{}]'.format(inputarg))
        archive = tarfile.open(inputarg)
        is_train = False
        is_val = False
        names = []
        for name in archive.getnames():
            if re.match(r'n[0-9]{8}\.tar', name):
                if is_val:
                    print('Train data {} is included in the validation tar'.format(
                        name))
                    exit(-1)
                is_train = True
            elif re.match(r'ILSVRC2012_val_[0-9]{8}\.JPEG', name):
                if is_train:
                    print('Validation data {} is included in the train tar'.format(
                        name))
                    exit(-1)
                is_val = True
            else:
                print('Invalid member {} included in tar file'.format(name))
                exit(-1)
            names.append(name)
        if is_train:
            if archives['train'] is None:
                archives['train'] = (archive, names)
            else:
                print('Please specify only 1 training tar archive.')
                exit(-1)
        if is_val:
            if archives['val'] is None:
                archives['val'] = (archive, names)
            else:
                print('Please specify only 1 validation tar archive.')
                exit(-1)

    # Read labels of the validation data (use ascending order of wordnet_id)
    validation_ground_truth = []
    g_file = VALIDATION_DATA_LABEL
    with open(g_file, 'r') as f:
        for l in f.readlines():
            validation_ground_truth.append(int(l.rstrip()))

    ############################################################################
    # Prepare logging
    tmpdir = tempfile.mkdtemp()
    logfilename = os.path.join(tmpdir, 'nnabla.log')

    # Temporarily chdir to tmpdir just before importing nnabla to reflect nnabla.conf.
    cwd = os.getcwd()
    os.chdir(tmpdir)
    with open('nnabla.conf', 'w') as f:
        f.write('[LOG]\n')
        f.write('log_file_name = {}\n'.format(logfilename))
        f.write('log_file_format = %(funcName)s : %(message)s\n')
        f.write('log_console_level = CRITICAL\n')

    from nnabla.config import nnabla_config
    os.chdir(cwd)

    ############################################################################
    # Data iterator setting
    nnabla_config.set('DATA_ITERATOR', 'cache_file_format',
                      '.' + args.cache_type)
    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size',
                      str(args.file_cache_size))
    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_num_of_threads',
                      '1')

    if not os.path.isdir(args.output):
        os.makedirs(args.output)

    ############################################################################
    # Prepare status monitor
    from nnabla.utils.progress import configure_progress
    configure_progress(None, _progress)

    ############################################################################
    # Converter

    try:
        if archives['train'] is not None:
            from nnabla.logger import logger
            logger.info('StartCreatingCache')
            archive, names = archives['train']
            output = os.path.join(args.output, 'train')
            if not os.path.isdir(output):
                os.makedirs(output)
            _create_train_cache(archive, output, names, args)
        if archives['val'] is not None:
            from nnabla.logger import logger
            logger.info('StartCreatingCache')
            archive, names = archives['val']
            output = os.path.join(args.output, 'val')
            if not os.path.isdir(output):
                os.makedirs(output)
            _create_validation_cache(archive, output, names,
                                     validation_ground_truth, args)
    except KeyboardInterrupt:
        shutil.rmtree(tmpdir, ignore_errors=True)

        # Even if CTRL-C is pressed, the process does not stop while a thread
        # is still running, so send a signal to itself.
        os.kill(os.getpid(), 9)

    ############################################################################
    # Finish
    _finish = True
    shutil.rmtree(tmpdir, ignore_errors=True)
Beispiel #37
0
def save_snapshot(save_dir: str) -> None:
    logger.info("Save network parameters")
    os.makedirs(save_dir, exist_ok=True)
    model_file_path = os.path.join(save_dir, "pointnet_classification.h5")
    nn.save_parameters(path=model_file_path)
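A hedged usage sketch combining save_snapshot() above with load_snapshot() from Beispiel #35. It assumes that some parameters already exist in the current nnabla parameter scope; "tmp_result" is a placeholder directory:

import nnabla as nn

save_snapshot("tmp_result")   # writes tmp_result/pointnet_classification.h5
nn.clear_parameters()         # drop the in-memory parameters
load_snapshot("tmp_result")   # restores them from the saved .h5 file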
Beispiel #38
0
def save(filename, contents, include_params=False, variable_batch_size=True):
    '''Save network definition, inference/training execution
    configurations etc.

    Args:
        filename (str): Filename to store information. The file
            extension is used to determine the saving file format.
            ``.nnp``: (Recommended) Creating a zip archive with nntxt (network
            definition etc.) and h5 (parameters).
            ``.nntxt``: Protobuf in text format.
            ``.protobuf``: Protobuf in binary format (unsafe in terms of
             backward compatibility).
        contents (dict): Information to store.
        include_params (bool): Includes parameter into single file. This is
            ignored when the extension of filename is nnp.
        variable_batch_size (bool):
            If ``True``, the first dimension of all variables is considered
            the batch size, and is left as a placeholder
            (more specifically ``-1``). The placeholder dimension will be
            filled during/after loading.

    Example:
        The following example creates a two-input, two-output MLP, and
        saves the network structure and the initialized parameters.

        .. code-block:: python

            import nnabla as nn
            import nnabla.functions as F
            import nnabla.parametric_functions as PF
            from nnabla.utils.save import save

            batch_size = 16
            x0 = nn.Variable([batch_size, 100])
            x1 = nn.Variable([batch_size, 100])
            h1_0 = PF.affine(x0, 100, name='affine1_0')
            h1_1 = PF.affine(x1, 100, name='affine1_0')
            h1 = F.tanh(h1_0 + h1_1)
            h2 = F.tanh(PF.affine(h1, 50, name='affine2'))
            y0 = PF.affine(h2, 10, name='affiney_0')
            y1 = PF.affine(h2, 10, name='affiney_1')

            contents = {
                'networks': [
                    {'name': 'net1',
                     'batch_size': batch_size,
                     'outputs': {'y0': y0, 'y1': y1},
                     'names': {'x0': x0, 'x1': x1}}],
                'executors': [
                    {'name': 'runtime',
                     'network': 'net1',
                     'data': ['x0', 'x1'],
                     'output': ['y0', 'y1']}]}
            save('net.nnp', contents)


        To get a trainable model, use the following code instead.

        .. code-block:: python

            contents = {
            'global_config': {'default_context': ctx},
            'training_config':
                {'max_epoch': args.max_epoch,
                 'iter_per_epoch': args_added.iter_per_epoch,
                 'save_best': True},
            'networks': [
                {'name': 'training',
                 'batch_size': args.batch_size,
                 'outputs': {'loss': loss_t},
                 'names': {'x': x, 'y': t, 'loss': loss_t}},
                {'name': 'validation',
                 'batch_size': args.batch_size,
                 'outputs': {'loss': loss_v},
                 'names': {'x': x, 'y': t, 'loss': loss_v}}],
            'optimizers': [
                {'name': 'optimizer',
                 'solver': solver,
                 'network': 'training',
                 'dataset': 'mnist_training',
                 'weight_decay': 0,
                 'lr_decay': 1,
                 'lr_decay_interval': 1,
                 'update_interval': 1}],
            'datasets': [
                {'name': 'mnist_training',
                 'uri': 'MNIST_TRAINING',
                 'cache_dir': args.cache_dir + '/mnist_training.cache/',
                 'variables': {'x': x, 'y': t},
                 'shuffle': True,
                 'batch_size': args.batch_size,
                 'no_image_normalization': True},
                {'name': 'mnist_validation',
                 'uri': 'MNIST_VALIDATION',
                 'cache_dir': args.cache_dir + '/mnist_test.cache/',
                 'variables': {'x': x, 'y': t},
                 'shuffle': False,
                 'batch_size': args.batch_size,
                 'no_image_normalization': True
                 }],
            'monitors': [
                {'name': 'training_loss',
                 'network': 'validation',
                 'dataset': 'mnist_training'},
                {'name': 'validation_loss',
                 'network': 'validation',
                 'dataset': 'mnist_validation'}],
            }


    '''
    _, ext = os.path.splitext(filename)
    if ext == '.nntxt' or ext == '.prototxt':
        logger.info("Saving {} as prototxt".format(filename))
        proto = create_proto(contents, include_params, variable_batch_size)
        with open(filename, 'w') as file:
            text_format.PrintMessage(proto, file)
    elif ext == '.protobuf':
        logger.info("Saving {} as protobuf".format(filename))
        proto = create_proto(contents, include_params, variable_batch_size)
        with open(filename, 'wb') as file:
            file.write(proto.SerializeToString())
    elif ext == '.nnp':
        logger.info("Saving {} as nnp".format(filename))
        try:
            tmpdir = tempfile.mkdtemp()
            save('{}/network.nntxt'.format(tmpdir),
                 contents,
                 include_params=False,
                 variable_batch_size=variable_batch_size)

            with open('{}/nnp_version.txt'.format(tmpdir), 'w') as file:
                file.write('{}\n'.format(nnp_version()))

            save_parameters('{}/parameter.protobuf'.format(tmpdir))

            with zipfile.ZipFile(filename, 'w') as nnp:
                nnp.write('{}/nnp_version.txt'.format(tmpdir),
                          'nnp_version.txt')
                nnp.write('{}/network.nntxt'.format(tmpdir), 'network.nntxt')
                nnp.write('{}/parameter.protobuf'.format(tmpdir),
                          'parameter.protobuf')
        finally:
            shutil.rmtree(tmpdir)
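After saving with the ".nnp" extension, the result is an ordinary zip archive whose members are nnp_version.txt, network.nntxt and parameter.protobuf. A minimal sketch for inspecting it, assuming "net.nnp" was created by the save() above:

import zipfile

with zipfile.ZipFile("net.nnp", "r") as nnp:
    # Expected members: nnp_version.txt, network.nntxt, parameter.protobuf
    print(nnp.namelist())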
Beispiel #39
0
def save(filename, contents, include_params=False):
    '''Save network definition, inference/training execution
    configurations etc.

    Args:
        filename (str): Filename to store information. The file
            extension is used to determine the saving file format.
            ``.nnp``: (Recommended) Creating a zip archive with nntxt (network
            definition etc.) and h5 (parameters).
            ``.nntxt``: Protobuf in text format.
            ``.protobuf``: Protobuf in binary format (unsafe in terms of
             backward compatibility).
        contents (dict): Information to store.
        include_params (bool): Includes parameter into single file. This is
            ignored when the extension of filename is nnp.

    Example:
        The currently supported fields in ``contents`` are ``networks`` and
        ``executors``. The following example creates a two-input, two-output
        MLP, and saves the network structure and the initialized parameters.

        .. code-block:: python

            import nnabla as nn
            import nnabla.functions as F
            import nnabla.parametric_functions as PF

            x0 = nn.Variable([batch_size, 100])
            x1 = nn.Variable([batch_size, 100])
            h1_0 = PF.affine(x0, 100, name='affine1_0')
            h1_1 = PF.affine(x1, 100, name='affine1_0')
            h1 = F.tanh(h1_0 + h1_1)
            h2 = F.tanh(PF.affine(h1, 50, name='affine2'))
            y0 = PF.affine(h2, 10, name='affiney_0')
            y1 = PF.affine(h2, 10, name='affiney_1')

            contents = {
                'networks': [
                    {'name': 'net1',
                     'batch_size': batch_size,
                     'outputs': {'y0': y0, 'y1': y1},
                     'names': {'x0': x0, 'x1': x1}}],
                'executors': [
                    {'name': 'runtime',
                     'network': 'net1',
                     'data': ['x0', 'x1'],
                     'output': ['y0', 'y1']}]}
            save('net.nnp', contents)
    '''
    _, ext = os.path.splitext(filename)
    print(filename, ext)
    if ext == '.nntxt' or ext == '.prototxt':
        logger.info("Saveing {} as prototxt".format(filename))
        proto = create_proto(contents, include_params)
        with open(filename, 'w') as file:
            text_format.PrintMessage(proto, file)
    elif ext == '.protobuf':
        logger.info("Saveing {} as protobuf".format(filename))
        proto = create_proto(contents, include_params)
        with open(filename, 'wb') as file:
            file.write(proto.SerializeToString())
    elif ext == '.nnp':
        logger.info("Saveing {} as nnp".format(filename))
        tmpdir = tempfile.mkdtemp()
        save('{}/network.nntxt'.format(tmpdir), contents, include_params=False)
        save_parameters('{}/parameter.protobuf'.format(tmpdir))
        with zipfile.ZipFile(filename, 'w') as nnp:
            nnp.write('{}/network.nntxt'.format(tmpdir), 'network.nntxt')
            nnp.write('{}/parameter.protobuf'.format(tmpdir),
                      'parameter.protobuf')
        shutil.rmtree(tmpdir)
Beispiel #40
0
def load(filenames,
         prepare_data_iterator=True,
         batch_size=None,
         exclude_parameter=False,
         parameter_only=False):
    '''Load network information from files.

    Args:
        filenames (list): List of filenames.
    Returns:
        dict: Network information.
    '''
    class Info:
        pass

    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)

        # TODO: There are some known problems here.
        #   - Even when a protobuf file includes the network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with open(filename, 'rt') as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2 byte characters may be used for file name or folder name.'
                        )
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename)
            else:
                logger.info('Skip loading parameter.')

        elif ext == '.nnp':
            try:
                tmpdir = tempfile.mkdtemp()
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        _, ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            nnp.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                pass  # TODO currently do nothing with version.
                        elif ext in ['.nntxt', '.prototxt']:
                            nnp.extract(name, tmpdir)
                            if not parameter_only:
                                with open(os.path.join(tmpdir, name),
                                          'rt') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0:
                                if not exclude_parameter:
                                    nn.load_parameters(
                                        os.path.join(tmpdir, name))
                        elif ext in ['.protobuf', '.h5']:
                            nnp.extract(name, tmpdir)
                            if not exclude_parameter:
                                nn.load_parameters(os.path.join(tmpdir, name))
                            else:
                                logger.info('Skip loading parameter.')
            finally:
                shutil.rmtree(tmpdir)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except:
                pass
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    info.datasets = _datasets(
        proto, prepare_data_iterator if prepare_data_iterator is not None else
        info.training_config.max_epoch > 0)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)

    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)

    info.executors = _executors(proto, info.networks)

    return info
Beispiel #41
0
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    Args:
      path : path or file object
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    if ext == '.h5':
        # TODO temporary work around to suppress FutureWarning message.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with get_file_handle_load(path, ext) as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))
            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]

                var = get_parameter_or_create(
                    key, ds.shape, need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]

                if needs_proto:
                    if proto is None:
                        proto = nnabla_pb2.NNablaProtoBuf()
                    parameter = proto.parameter.add()
                    parameter.variable_name = key
                    parameter.shape.dim.extend(ds.shape)
                    parameter.data.extend(
                        numpy.array(ds[...]).flatten().tolist())
                    parameter.need_grad = False
                    if ds.attrs['need_grad']:
                        parameter.need_grad = True

    else:
        if proto is None:
            proto = nnabla_pb2.NNablaProtoBuf()

        if ext == '.protobuf':
            with get_file_handle_load(path, ext) as f:
                proto.MergeFromString(f.read())
                set_parameter_from_proto(proto)
        elif ext == '.nntxt' or ext == '.prototxt':
            with get_file_handle_load(path, ext) as f:
                text_format.Merge(f.read(), proto)
                set_parameter_from_proto(proto)

        elif ext == '.nnp':
            try:
                tmpdir = tempfile.mkdtemp()
                with get_file_handle_load(path, ext) as nnp:
                    for name in nnp.namelist():
                        nnp.extract(name, tmpdir)
                        _, ext = os.path.splitext(name)
                        if ext in ['.protobuf', '.h5']:
                            proto = load_parameters(os.path.join(
                                tmpdir, name), proto, needs_proto)
            finally:
                shutil.rmtree(tmpdir)
                logger.info("Parameter load ({}): {}".format(format, path))
        else:
            logger.error("Invalid parameter file '{}'".format(path))
    return proto
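A hedged round-trip sketch for load_parameters(): when needs_proto=True, the returned protobuf mirrors the loaded parameters, which makes it easy to list their names and shapes. "params.h5" is a placeholder file saved beforehand:

# Hypothetical parameter file saved earlier in .h5 format.
proto = load_parameters("params.h5", needs_proto=True)
for p in proto.parameter:
    print(p.variable_name, list(p.shape.dim), p.need_grad)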
Beispiel #42
0
    def create(self,
               output_cache_dirname,
               normalize=True,
               cache_file_name_prefix='cache'):

        self._normalize = normalize
        self._cache_file_name_prefix = cache_file_name_prefix

        self._cache_file_format = nnabla_config.get('DATA_ITERATOR',
                                                    'cache_file_format')
        logger.info('Cache file format is {}'.format(self._cache_file_format))

        self._cache_dir = output_cache_dirname

        progress(None)

        self._cache_file_order = []
        self._cache_file_data_orders = []
        self._cache_file_names = []

        self._cache_data = []
        progress('Create cache', 0)
        last_time = time.time()
        for self._position in range(self._size):
            if time.time() >= last_time + 1.0:
                progress('Create cache', self._position / self._size)
                last_time = time.time()
            self._file.seek(self._line_positions[self._order[self._position]])
            line = self._file.readline().decode('utf-8')
            csvreader = csv.reader([line])
            row = next(csvreader)
            self._cache_data.append(tuple(self._process_row(row)))

            if len(self._cache_data) >= self._cache_size:
                self._save_cache()
                self._cache_data = []

        self._save_cache()
        progress('Create cache', 1.0)

        # Adjust data size to the reset position. In most cases this means a
        # multiple of the bunch (mini-batch) size.
        num_of_cache_files = int(
            numpy.ceil(float(self._size) / self._cache_size))
        self._cache_file_order = self._cache_file_order[0:num_of_cache_files]
        self._cache_file_data_orders = self._cache_file_data_orders[
            0:num_of_cache_files]
        if self._size % self._cache_size != 0:
            self._cache_file_data_orders[num_of_cache_files -
                                         1] = self._cache_file_data_orders[
                                             num_of_cache_files -
                                             1][0:self._size %
                                                self._cache_size]

        # Create Index
        index_filename = os.path.join(self._cache_dir, "cache_index.csv")
        with open(index_filename, 'w') as f:
            writer = csv.writer(f, lineterminator='\n')
            for fn, orders in zip(self._cache_file_names,
                                  self._cache_file_data_orders):
                writer.writerow((os.path.basename(fn), len(orders)))
        # Create Info
        if self._cache_file_format == ".npy":
            info_filename = os.path.join(self._cache_dir, "cache_info.csv")
            with open(info_filename, 'w') as f:
                writer = csv.writer(f, lineterminator='\n')
                for variable in self._variables:
                    writer.writerow((variable, ))

        # Create original.csv
        if self._original_source_uri is not None:
            shutil.copy(self._original_source_uri,
                        os.path.join(self._cache_dir, "original.csv"))

        # Create order.csv
        if self._order is not None and \
                self._original_order is not None:
            with open(os.path.join(self._cache_dir, "order.csv"), 'w') as o:
                writer = csv.writer(o, lineterminator='\n')
                for orders in zip(self._original_order, self._order):
                    writer.writerow(list(orders))
Beispiel #43
0
        channel_first (bool): If True, the input image is assumed to have shape (channel, height, width).
    """

    module.imsave(path, img, channel_first=channel_first)


def imresize(img, size, interpolate="bilinear", channel_first=False):
    """
    Resize ``img`` to ``size``.
    By default, the shape of the input image has to be (height, width, channel).

    Args:
        img (numpy.ndarray): Input image.
        size (tuple of int): Output shape. The order is (width, height).
        channel_first (bool): If True, the shape of the output array is (channel, height, width) for RGB image. Default is False.
    Returns:
         numpy.ndarray
    """

    return module.imresize(img,
                           size,
                           interpolate=interpolate,
                           channel_first=channel_first)


# alias
imwrite = imsave
imload = imread

logger.info("use {} for the backend of image utils".format(backend))
Beispiel #44
0
def main():
    args = get_args()
    rng = np.random.RandomState(1223)

    # Get context
    from nnabla.ext_utils import get_extension_context, import_extension_module
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    ext = import_extension_module(args.context)

    # read label file
    with open(args.label_file_path, "r") as f:
        labels_dict = f.readlines()

    # Load parameters
    _ = nn.load_parameters(args.model_load_path)

    # Build a Deeplab v3+ network
    x = nn.Variable((1, 3, args.image_height, args.image_width),
                    need_grad=False)
    y = net.deeplabv3plus_model(x,
                                args.output_stride,
                                args.num_class,
                                test=True)

    # preprocess image
    image = imageio.imread(args.test_image_file, as_gray=False, pilmode="RGB")
    #image = imread(args.test_image_file).astype('float32')
    orig_h, orig_w, orig_c = image.shape
    old_size = (orig_h, orig_w)

    input_array = image_preprocess.preprocess_image_and_label(
        image,
        label=None,
        target_width=args.image_width,
        target_height=args.image_height,
        train=False)
    print('Input', input_array.shape)
    input_array = np.transpose(input_array, (2, 0, 1))
    input_array = np.reshape(
        input_array,
        (1, input_array.shape[0], input_array.shape[1], input_array.shape[2]))

    # Compute inference and inference time
    t = time.time()

    x.d = input_array
    y.forward(clear_buffer=True)
    print("done")
    available_devices = ext.get_devices()
    ext.device_synchronize(available_devices[0])
    ext.clear_memory_cache()

    elapsed = time.time() - t
    print('Inference time: %s seconds' % elapsed)

    output = np.argmax(y.d, axis=1)  # (batch,h,w)

    # Apply post processing
    post_processed = post_process(output[0], old_size,
                                  (args.image_height, args.image_width))

    # Get the classes predicted
    predicted_classes = np.unique(post_processed)
    for i in range(predicted_classes.shape[0]):
        print('Classes Segmented:', labels[predicted_classes[i]].strip())

    # Visualize inference result
    visualize(post_processed)
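post_process and visualize are not defined in this snippet. Below is a minimal sketch, assuming post-processing only maps the (height, width) argmax label map back to the original image size with nearest-neighbor indexing; the actual helper may also undo padding applied during preprocessing.

import numpy as np

def post_process_sketch(label_map, old_size, model_size):
    # label_map: (model_h, model_w) integer class map from np.argmax.
    orig_h, orig_w = old_size
    model_h, model_w = model_size
    rows = np.arange(orig_h) * model_h // orig_h
    cols = np.arange(orig_w) * model_w // orig_w
    # Nearest-neighbor lookup back to the original resolution.
    return label_map[rows[:, None], cols[None, :]]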
Beispiel #45
0
def load_checkpoint(path, solvers):
    """Given the checkpoint file, loads the parameters and solver states.

        Args:
            path: Path to the checkpoint file.
            solvers: A dictionary of solvers to restore, which looks like;
                     solvers = {"identifier_for_solver_0": solver_0,
                                "identifier_for_solver_1": solver_1, ...}
                     The keys are used to retrieve the corresponding info from the
                     checkpoint, so they must be the same as the ones used when it
                     was saved.
                     If only one solver exists, you can pass the solver object
                     directly; the empty string "" is then used as its identifier.

        Returns:
            current_iter: The number of iterations that the training resumes from.
                          Note that this assumes that the number of updates for
                          each solver is the same.

        Examples:
            # Create computation graph with parameters.
            pred = construct_pred_net(input_Variable, ...)

            # Create solver and set parameters.
            solver = S.Adam(learning_rate)
            solver.set_parameters(nn.get_parameters())

            # AFTER setting parameters.
            start_point = load_checkpoint(path, solver)

            # Training loop.

        Notes:
            It requires a checkpoint file created by save_checkpoint. For details, refer to save_checkpoint; the file looks like:
            checkpoint_1000 = {
                    "":{
                        "states_path": <path to the states file>
                        "params_names":["conv1/conv/W", ...],
                        "num_update":1000
                       },
                    "current_iter": 1000
                    }

            If you have multiple solvers:
            checkpoint_1000 = {
                    "generator":{
                        "states_path": <path to the states file>,
                        "params_names":["deconv1/conv/W", ...],
                        "num_update":1000
                       },
                    "discriminator":{
                        "states_path": <path to the states file>,
                        "params_names":["conv1/conv/W", ...],
                        "num_update":1000
                       },
                    "current_iter": 1000
                    }

    """

    assert os.path.isfile(path), "checkpoint file not found"

    with open(path, 'r') as f:
        checkpoint_info = json.load(f)

    if isinstance(solvers, nn.solver.Solver):
        solvers = {"": solvers}

    logger.info("Checkpoint load (.json): {}".format(path))

    # Load parameters (stored in the global parameter scope).
    params_path = checkpoint_info["params_path"]
    assert os.path.isfile(params_path), "parameters file not found."

    nn.parameter.load_parameters(params_path)

    for solvername, solver_obj in solvers.items():
        partial_info = checkpoint_info[solvername]
        if set(solver_obj.get_parameters().keys()) != set(partial_info["params_names"]):
            logger.warning("Detected parameters do not match.")

        # load solver states.
        states_path = partial_info["states_path"]
        assert os.path.isfile(states_path), "states file not found."

        # set solver states.
        solver_obj.load_states(states_path)

    # Get current iteration. Note that this might differ from the number of updates.
    current_iter = checkpoint_info["current_iter"]

    return current_iter
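A minimal resume sketch using the function above; the checkpoint path, the data iterator, the input/label variables, the loss variable, and max_iter are all hypothetical and assumed to be built as in the docstring example.

# Solver must already have its parameters set, as noted in the docstring.
start_point = load_checkpoint("checkpoint_1000.json", solver)

for i in range(start_point, max_iter):
    x.d, t.d = data.next()      # hypothetical data iterator and input variables
    loss.forward()
    solver.zero_grad()
    loss.backward()
    solver.update()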