Example #1
@contextmanager  # from contextlib; the helper is meant to be used via ``with ... as nnp_file``
def generate_case_from_nntxt_str(nntxt_str,
                                 param_format,
                                 dataset_sample_num,
                                 batch_size=None):
    proto = proto_from_str(nntxt_str)
    with generate_csv_png(dataset_sample_num,
                          get_input_size(proto)) as dataset_csv_file:
        # To test dataset, we create a randomly generated dataset.
        for ds in proto.dataset:
            ds.batch_size = batch_size if batch_size else ds.batch_size
            ds.uri = dataset_csv_file
            ds.cache_dir = os.path.join(os.path.dirname(dataset_csv_file),
                                        "data.cache")
        nntxt_io = io.StringIO()
        text_format.PrintMessage(proto, nntxt_io)
        nntxt_io.seek(0)

        version = io.StringIO()
        version.write('{}\n'.format(nnp_version()))
        version.seek(0)

        param = io.BytesIO()
        prepare_parameters(nntxt_str)
        nn.parameter.save_parameters(param, extension=param_format)
        # Rewind so the full parameter payload is read into the archive below.
        param.seek(0)

        with create_temp_with_dir(NNP_FILE) as temp_nnp_file_name:
            with get_file_handle_save(temp_nnp_file_name, ".nnp") as nnp:
                nnp.writestr('nnp_version.txt', version.read())
                nnp.writestr('network.nntxt', nntxt_io.read())
                nnp.writestr('parameter{}'.format(param_format), param.read())
            yield temp_nnp_file_name
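
A minimal usage sketch of the helper above (the nntxt string and the
``nnabla.utils.load`` round-trip are assumptions for illustration; the helper
is used as a context manager so the temporary dataset and .nnp file are
cleaned up on exit):

from nnabla.utils import load

# `some_nntxt_str` is a hypothetical network definition string supplied by the caller.
with generate_case_from_nntxt_str(some_nntxt_str,
                                  param_format='.protobuf',
                                  dataset_sample_num=64,
                                  batch_size=8) as nnp_file:
    # Load the generated .nnp back to check that it round-trips.
    info = load.load([nnp_file])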
Example #2
def save_parameters(path, params=None, extension=None):
    """Save all parameters into a file with the specified format.

    Currently hdf5 and protobuf formats are supported.

    Args:
      path (str or file object): Path to the output file, or a writable file object.
      params (dict, optional): Parameters to be saved. A dictionary mapping a parameter name (:obj:`str`) to :obj:`~nnabla.Variable`. If not given, all current parameters are saved.
      extension (str, optional): File format such as ``'.h5'`` or ``'.protobuf'``, used when ``path`` is a file object.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension
    params = get_parameters(grad_only=False) if params is None else params
    if ext == '.h5':
        # TODO temporary work around to suppress FutureWarning message.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with get_file_handle_save(path, ext) as hd:
            for i, (k, v) in enumerate(iteritems(params)):
                hd[k] = v.d
                hd[k].attrs['need_grad'] = v.need_grad
                # To preserve order of parameters
                hd[k].attrs['index'] = i
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        for variable_name, variable in params.items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad

        with get_file_handle_save(path, ext) as f:
            f.write(proto.SerializeToString())
    else:
        logger.critical('Only hdf5 and protobuf formats are supported.')
        assert False
    logger.info("Parameter save ({}): {}".format(ext, path))
Example #3
def test_file_close_exception(extension):
    with pytest.raises(ZeroDivisionError) as excinfo:
        with create_temp_with_dir("tmp{}".format(extension)) as filename:
            with get_file_handle_save(filename, ext=extension) as f:
                file_handler = f
                1 / 0
    assert file_handler.closed

    with pytest.raises(ZeroDivisionError) as excinfo:
        with create_temp_with_dir("tmp{}".format(extension)) as filename:
            # create the file first so there is something to open for loading
            with open(filename, "w") as f:
                f.write("\n")

            with get_file_handle_load(None, filename, ext=extension) as f:
                file_handler = f
                1 / 0
    assert file_handler.closed
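
The assertions above rely on the file-handle helpers closing their handle even
when the body raises. A simplified sketch of that pattern (not the actual
nnabla implementation, just the behavior the test verifies):

from contextlib import contextmanager

@contextmanager
def _file_handle_sketch(path, mode='w'):
    f = open(path, mode)
    try:
        yield f
    finally:
        # Runs whether or not the body raised, so the handle is always closed.
        f.close()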
Example #4
def save_optimizer_states(filebase, ext, train_config):
    filelist = []
    if ext == '.protobuf':
        filename = filebase + '_optimizer.protobuf.optimizer'
        proto = nnabla_pb2.NNablaProtoBuf()
        proto_optimizers = []
        for o in train_config.optimizers.values():
            proto_optimizers.append(_create_optimizer_lite(o.optimizer))
        proto.optimizer.extend(proto_optimizers)
        with get_file_handle_save(filename, '.protobuf') as f:
            f.write(proto.SerializeToString())
            filelist.append(filename)
    else:
        for o in train_config.optimizers.values():
            f_name = '{}_{}_optimizer.h5'.format(
                o.optimizer.name,
                # Strip a trailing 'Cuda' suffix from the solver class name
                # (e.g. 'AdamCuda' -> 'Adam').
                re.sub(r'(|Cuda)$', '', str(o.optimizer.solver.name)))
            filename = '{}_{}'.format(filebase, f_name)
            o.optimizer.solver.save_states(filename)
            name_ext = '{}.optimizer'.format(filename)
            os.rename(filename, name_ext)
            filelist.append(name_ext)
    return filelist
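
As a rough check of the protobuf branch, the file written above can be parsed
back with the same message type (a sketch; the filename is a placeholder
following the naming pattern used above):

from nnabla.utils import nnabla_pb2

proto = nnabla_pb2.NNablaProtoBuf()
with open('snapshot_optimizer.protobuf.optimizer', 'rb') as f:
    proto.ParseFromString(f.read())
print(len(proto.optimizer))  # number of optimizer states stored in the file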
Example #5
def save(filename,
         contents,
         include_params=False,
         variable_batch_size=True,
         extension=".nnp"):
    '''Save network definition, inference/training execution
    configurations etc.

    Args:
        filename (str or file object): Filename to store information. The file
            extension is used to determine the saving file format.
            ``.nnp``: (Recommended) Creates a zip archive with nntxt (network
            definition etc.) and h5 (parameters).
            ``.nntxt``: Protobuf in text format.
            ``.protobuf``: Protobuf in binary format (unsafe in terms of
            backward compatibility).
        contents (dict): Information to store.
        include_params (bool): Includes parameters in a single file. This is
            ignored when the extension of filename is nnp.
        variable_batch_size (bool):
            When ``True``, the first dimension of all variables is considered
            the batch size and is left as a placeholder
            (more specifically ``-1``). The placeholder dimension will be
            filled during/after loading.
        extension (str): Used when ``filename`` is a file-like object;
            one of ``".nntxt"``, ``".prototxt"``, ``".protobuf"``, ``".h5"``, ``".nnp"``.

    Example:
        The following example creates an MLP with two inputs and two
        outputs, and saves the network structure and the initialized
        parameters.

        .. code-block:: python

            import nnabla as nn
            import nnabla.functions as F
            import nnabla.parametric_functions as PF
            from nnabla.utils.save import save

            batch_size = 16
            x0 = nn.Variable([batch_size, 100])
            x1 = nn.Variable([batch_size, 100])
            h1_0 = PF.affine(x0, 100, name='affine1_0')
            h1_1 = PF.affine(x1, 100, name='affine1_1')
            h1 = F.tanh(h1_0 + h1_1)
            h2 = F.tanh(PF.affine(h1, 50, name='affine2'))
            y0 = PF.affine(h2, 10, name='affiney_0')
            y1 = PF.affine(h2, 10, name='affiney_1')

            contents = {
                'networks': [
                    {'name': 'net1',
                     'batch_size': batch_size,
                     'outputs': {'y0': y0, 'y1': y1},
                     'names': {'x0': x0, 'x1': x1}}],
                'executors': [
                    {'name': 'runtime',
                     'network': 'net1',
                     'data': ['x0', 'x1'],
                     'output': ['y0', 'y1']}]}
            save('net.nnp', contents)


        To get a trainable model, use the following code instead.

        .. code-block:: python

            contents = {
            'global_config': {'default_context': ctx},
            'training_config':
                {'max_epoch': args.max_epoch,
                 'iter_per_epoch': args_added.iter_per_epoch,
                 'save_best': True},
            'networks': [
                {'name': 'training',
                 'batch_size': args.batch_size,
                 'outputs': {'loss': loss_t},
                 'names': {'x': x, 'y': t, 'loss': loss_t}},
                {'name': 'validation',
                 'batch_size': args.batch_size,
                 'outputs': {'loss': loss_v},
                 'names': {'x': x, 'y': t, 'loss': loss_v}}],
            'optimizers': [
                {'name': 'optimizer',
                 'solver': solver,
                 'network': 'training',
                 'dataset': 'mnist_training',
                 'weight_decay': 0,
                 'lr_decay': 1,
                 'lr_decay_interval': 1,
                 'update_interval': 1}],
            'datasets': [
                {'name': 'mnist_training',
                 'uri': 'MNIST_TRAINING',
                 'cache_dir': args.cache_dir + '/mnist_training.cache/',
                 'variables': {'x': x, 'y': t},
                 'shuffle': True,
                 'batch_size': args.batch_size,
                 'no_image_normalization': True},
                {'name': 'mnist_validation',
                 'uri': 'MNIST_VALIDATION',
                 'cache_dir': args.cache_dir + '/mnist_test.cache/',
                 'variables': {'x': x, 'y': t},
                 'shuffle': False,
                 'batch_size': args.batch_size,
                 'no_image_normalization': True
                 }],
            'monitors': [
                {'name': 'training_loss',
                 'network': 'validation',
                 'dataset': 'mnist_training'},
                {'name': 'validation_loss',
                 'network': 'validation',
                 'dataset': 'mnist_validation'}],
            }


    '''
    if isinstance(filename, str):
        _, ext = os.path.splitext(filename)
    else:
        ext = extension
    if ext == '.nntxt' or ext == '.prototxt':
        logger.info("Saving {} as prototxt".format(filename))
        proto = create_proto(contents, include_params, variable_batch_size)
        with get_file_handle_save(filename, ext) as file:
            text_format.PrintMessage(proto, file)
    elif ext == '.protobuf':
        logger.info("Saving {} as protobuf".format(filename))
        proto = create_proto(contents, include_params, variable_batch_size)
        with get_file_handle_save(filename, ext) as file:
            file.write(proto.SerializeToString())
    elif ext == '.nnp':
        logger.info("Saving {} as nnp".format(filename))

        nntxt = io.StringIO()
        save(nntxt,
             contents,
             include_params=False,
             variable_batch_size=variable_batch_size,
             extension='.nntxt')
        nntxt.seek(0)

        version = io.StringIO()
        version.write('{}\n'.format(nnp_version()))
        version.seek(0)

        param = io.BytesIO()
        save_parameters(param, extension='.protobuf')
        param.seek(0)

        with get_file_handle_save(filename, ext) as nnp:
            nnp.writestr('nnp_version.txt', version.read())
            nnp.writestr('network.nntxt', nntxt.read())
            nnp.writestr('parameter.protobuf', param.read())
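
A saved ``.nnp`` can be read back for inference with the NnpLoader utility; a
minimal sketch, reusing the network and variable names from the first example
in the docstring above:

import numpy as np
from nnabla.utils.nnp_graph import NnpLoader

nnp = NnpLoader('net.nnp')
net = nnp.get_network('net1', batch_size=16)
x0, x1 = net.inputs['x0'], net.inputs['x1']
y0 = net.outputs['y0']
x0.d = np.random.randn(*x0.shape)
x1.d = np.random.randn(*x1.shape)
y0.forward()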