Exemplo n.º 1
0
def generate_case_from_nntxt_str(nntxt_str,
                                 param_format,
                                 dataset_sample_num,
                                 batch_size=None):
    """Build a temporary ``.nnp`` file from a network definition and yield
    its path.

    A generated CSV/PNG dataset with ``dataset_sample_num`` samples replaces
    the URI of every dataset in the definition. The archive contains
    ``nnp_version.txt``, ``network.nntxt`` and ``parameter<param_format>``.

    Args:
        nntxt_str (str): Network definition in protobuf text format.
        param_format (str): Parameter file extension passed to
            ``nn.parameter.save_parameters`` (e.g. ``'.h5'``/``'.protobuf'``).
        dataset_sample_num (int): Number of samples in the generated dataset.
        batch_size (int, optional): If truthy, overrides the batch size of
            every dataset; otherwise each dataset keeps its own.

    Yields:
        str: Path of the temporary ``.nnp`` file, valid while the generator
        is suspended at the ``yield``.
    """
    proto = proto_from_str(nntxt_str)
    with generate_csv_png(dataset_sample_num,
                          get_input_size(proto)) as dataset_csv_file:
        # To test dataset, we create a randomly generated dataset.
        for ds in proto.dataset:
            ds.batch_size = batch_size if batch_size else ds.batch_size
            ds.uri = dataset_csv_file
            ds.cache_dir = os.path.join(os.path.dirname(dataset_csv_file),
                                        "data.cache")
        nntxt_io = io.StringIO()
        text_format.PrintMessage(proto, nntxt_io)

        version = '{}\n'.format(nnp_version())

        param = io.BytesIO()
        prepare_parameters(nntxt_str)
        nn.parameter.save_parameters(param, extension=param_format)

        with create_temp_with_dir(NNP_FILE) as temp_nnp_file_name:
            with get_file_handle_save(temp_nnp_file_name, ".nnp") as nnp:
                nnp.writestr('nnp_version.txt', version)
                # getvalue() returns the whole buffer regardless of the
                # current stream position, so nothing depends on the writers
                # having rewound the streams.
                nnp.writestr('network.nntxt', nntxt_io.getvalue())
                nnp.writestr('parameter{}'.format(param_format),
                             param.getvalue())
            yield temp_nnp_file_name
Exemplo n.º 2
0
def _save_parameters(tmp_nnp_file, train_config, nntxt_str):
    """Write an ``.nnp`` zip archive with version, network, parameters and
    optimizer states.

    Intermediate files are written next to ``tmp_nnp_file`` with the prefix
    ``<dirname(tmp_nnp_file)>/results``. All of them are removed afterwards,
    including when building the archive raises.

    Args:
        tmp_nnp_file (str): Path of the zip archive to create.
        train_config: Training configuration forwarded to
            ``save_optimizer_states``.
        nntxt_str (str): Network definition stored as ``network.nntxt``.
    """
    base = os.path.join(os.path.dirname(tmp_nnp_file), 'results')

    version_filename = base + '_version.txt'
    # This is for testing, start=>
    nntxt_filename = base + '_network.nntxt'
    # <= End.
    param_filename = base + '_param.h5'

    # Every intermediate file listed here is deleted in ``finally``; the
    # nntxt file was previously never removed.
    temp_filenames = [version_filename, nntxt_filename, param_filename]
    try:
        with open(version_filename, 'w') as file:
            file.write('{}\n'.format(nnp_version()))

        with open(nntxt_filename, "w") as file:
            file.write(nntxt_str)

        nn.parameter.save_parameters(param_filename)

        # Materialize so the names can be iterated for archiving and cleanup.
        opti_filenames = list(
            save_optimizer_states(base, '.h5', train_config))
        temp_filenames.extend(opti_filenames)

        with zipfile.ZipFile(tmp_nnp_file, 'w') as nnp:
            nnp.write(version_filename, 'nnp_version.txt')
            nnp.write(nntxt_filename, "network.nntxt")
            nnp.write(param_filename, 'parameter.h5')
            for f in opti_filenames:
                # Archive name is the file name without the "<base>_" prefix.
                nnp.write(f, f[len(base) + 1:])
    finally:
        for f in temp_filenames:
            if os.path.exists(f):
                os.unlink(f)
    logger.info("{} is saved.".format(tmp_nnp_file))
Exemplo n.º 3
0
def _save_parameters(args, suffix, epoch, force=False):
    """Save a throttled ``.nnp`` snapshot of the current parameters.

    A snapshot is written when ``force`` is set, or when the target file does
    not exist yet and either more than 180 seconds or more than 10 epochs
    have passed since the last snapshot for the same ``suffix``. Existing
    ``results_<suffix>_*.nnp`` files in ``args.outdir`` are removed before
    the new archive is written. The archive holds ``nnp_version.txt``, the
    file at ``_save_parameter_info['config']`` and ``parameter.protobuf``.

    Args:
        args: Object providing ``outdir``, the output directory.
        suffix: Snapshot kind. When ``callback.result_base`` returns ``None``,
            ``None`` and ``'best'`` save to ``results.nnp``, and any other
            value saves to ``results_<suffix>_<epoch>.nnp``.
        epoch (int): Current epoch.
        force (bool): Save regardless of the time/epoch throttling.
    """
    global _save_parameter_info

    if suffix not in _save_parameter_info:
        _save_parameter_info[suffix] = {}
        _save_parameter_info[suffix]['epoch'] = 0
        _save_parameter_info[suffix]['time'] = 0

    current_time = time.time()
    timediff = current_time - _save_parameter_info[suffix]['time']
    epochdiff = epoch - _save_parameter_info[suffix]['epoch']

    # Escape the literal parts of the pattern. A '[', '*' or '?' in outdir or
    # suffix would otherwise act as a wildcard, and every match is deleted
    # below.
    globname = os.path.join(glob.escape(args.outdir),
                            glob.escape('results_{}_'.format(suffix)) +
                            '*.nnp')
    exists = glob.glob(globname)

    base = os.path.join(args.outdir, 'results_{}_{}'.format(suffix, epoch))
    base_candidate = callback.result_base(base, suffix, args.outdir)
    if base_candidate is None:
        if suffix is None or suffix == 'best':
            base = os.path.join(args.outdir, 'results')
    else:
        base = base_candidate

    filename = base + '.nnp'

    if force or (not os.path.exists(filename) and
                 (timediff > 180.0 or epochdiff > 10)):

        # Remove existing nnp before saving new file.
        # (With force=True the target itself may be among them.)
        for exist in exists:
            os.unlink(exist)

        version_filename = base + '_version.txt'

        with open(version_filename, 'w') as file:
            file.write('{}\n'.format(nnp_version()))

        param_filename = base + '_param.protobuf'
        save_parameters(param_filename)

        with zipfile.ZipFile(filename, 'w') as nnp:
            nnp.write(version_filename, 'nnp_version.txt')
            nnp.write(_save_parameter_info['config'],
                      os.path.basename(_save_parameter_info['config']))
            nnp.write(param_filename, 'parameter.protobuf')

        os.unlink(version_filename)
        os.unlink(param_filename)

        _save_parameter_info[suffix]['epoch'] = epoch
        _save_parameter_info[suffix]['time'] = current_time

        callback.save_train_snapshot()
Exemplo n.º 4
0
def _save_parameters(args, suffix, epoch, force=False):
    """Save a throttled ``.nnp`` snapshot of the current parameters.

    Nothing is written if the target file already exists. Otherwise a
    snapshot is written when ``force`` is set, or when more than 180 seconds
    or more than 10 epochs have passed since the last snapshot for the same
    ``suffix``. After a successful save, the ``results_<suffix>_*.nnp`` files
    found before saving are removed.

    Args:
        args: Object providing ``outdir``, the output directory.
        suffix: Snapshot kind; ``'best'`` saves to ``results.nnp``, anything
            else to ``results_<suffix>_<epoch>.nnp``.
        epoch (int): Current epoch.
        force (bool): Bypass the time/epoch throttling (but not the
            existing-file check).
    """
    global _save_parameter_info

    if suffix not in _save_parameter_info:
        _save_parameter_info[suffix] = {}
        _save_parameter_info[suffix]['epoch'] = 0
        _save_parameter_info[suffix]['time'] = 0

    current_time = time.time()
    timediff = current_time - _save_parameter_info[suffix]['time']
    epochdiff = epoch - _save_parameter_info[suffix]['epoch']

    # Escape the literal parts of the pattern. A '[', '*' or '?' in outdir or
    # suffix would otherwise act as a wildcard, and every match is deleted
    # below.
    globname = os.path.join(glob.escape(args.outdir),
                            glob.escape('results_{}_'.format(suffix)) +
                            '*.nnp')
    exists = glob.glob(globname)

    base = os.path.join(args.outdir, 'results_{}_{}'.format(suffix, epoch))
    if suffix == 'best':
        base = os.path.join(args.outdir, 'results')
    filename = base + '.nnp'

    if not os.path.exists(filename) and \
       (force or timediff > 180.0 or epochdiff > 10):

        version_filename = base + '_version.txt'

        with open(version_filename, 'w') as file:
            file.write('{}\n'.format(nnp_version()))

        param_filename = base + '_param.protobuf'
        save_parameters(param_filename)

        with zipfile.ZipFile(filename, 'w') as nnp:
            nnp.write(version_filename, 'nnp_version.txt')
            nnp.write(_save_parameter_info['config'],
                      os.path.basename(_save_parameter_info['config']))
            nnp.write(param_filename, 'parameter.protobuf')

        os.unlink(version_filename)
        os.unlink(param_filename)

        # Old snapshots are removed only once the new one has been written.
        for exist in exists:
            os.unlink(exist)

        _save_parameter_info[suffix]['epoch'] = epoch
        _save_parameter_info[suffix]['time'] = current_time
Exemplo n.º 5
0
def _nnp_file_saver(ctx, filename, ext):
    """Save ``ctx`` as an NNP archive (version, network, parameters).

    Args:
        ctx: Object providing the network (consumed by ``_nntxt_file_saver``)
            and ``parameters``. If ``parameters`` is ``None``,
            ``save_parameters`` is called without an explicit parameter set.
        filename: Destination path or file object for the archive.
        ext (str): Extension passed to ``get_file_handle_save``.
    """
    logger.info("Saving {} as nnp".format(filename))
    nntxt = io.StringIO()
    _nntxt_file_saver(ctx, nntxt, ".nntxt")

    version = '{}\n'.format(nnp_version())

    param = io.BytesIO()
    if ctx.parameters is None:
        nn.parameter.save_parameters(param, extension='.protobuf')
    else:
        nn.parameter.save_parameters(
            param, ctx.parameters, extension='.protobuf')

    with get_file_handle_save(filename, ext) as nnp:
        nnp.writestr('nnp_version.txt', version)
        # getvalue() returns the complete buffer whatever the stream position
        # is, so this no longer depends on the writers rewinding the buffers.
        nnp.writestr('network.nntxt', nntxt.getvalue())
        nnp.writestr('parameter.protobuf', param.getvalue())
Exemplo n.º 6
0
def save(filename, contents, include_params=False, variable_batch_size=True):
    '''Save network definition, inference/training execution
    configurations etc.

    Args:
        filename (str): Filename to store information. The file
            extension is used to determine the saving file format.
            ``.nnp``: (Recommended) Creating a zip archive with the NNP
            version, nntxt (network definition etc.) and protobuf
            (parameters).
            ``.nntxt`` or ``.prototxt``: Protobuf in text format.
            ``.protobuf``: Protobuf in binary format (unsafe in terms of
            backward compatibility).
            Any other extension is ignored and nothing is written.
        contents (dict): Information to store.
        include_params (bool): Includes parameter into single file. This is
            ignored when the extension of filename is nnp.
        variable_batch_size (bool):
            By ``True``, the first dimension of all variables is considered
            as batch size, and left as a placeholder
            (more specifically ``-1``). The placeholder dimension will be
            filled during/after loading.

    Example:
        The following example creates an MLP with two inputs and two
        outputs, and saves the network structure and the initialized
        parameters.

        .. code-block:: python

            import nnabla as nn
            import nnabla.functions as F
            import nnabla.parametric_functions as PF
            from nnabla.utils.save import save

            batch_size = 16
            x0 = nn.Variable([batch_size, 100])
            x1 = nn.Variable([batch_size, 100])
            h1_0 = PF.affine(x0, 100, name='affine1_0')
            h1_1 = PF.affine(x1, 100, name='affine1_0')
            h1 = F.tanh(h1_0 + h1_1)
            h2 = F.tanh(PF.affine(h1, 50, name='affine2'))
            y0 = PF.affine(h2, 10, name='affiney_0')
            y1 = PF.affine(h2, 10, name='affiney_1')

            contents = {
                'networks': [
                    {'name': 'net1',
                     'batch_size': batch_size,
                     'outputs': {'y0': y0, 'y1': y1},
                     'names': {'x0': x0, 'x1': x1}}],
                'executors': [
                    {'name': 'runtime',
                     'network': 'net1',
                     'data': ['x0', 'x1'],
                     'output': ['y0', 'y1']}]}
            save('net.nnp', contents)


        To get a trainable model, use the following code instead.

        .. code-block:: python

            contents = {
                'global_config': {'default_context': ctx},
                'training_config':
                    {'max_epoch': args.max_epoch,
                     'iter_per_epoch': args_added.iter_per_epoch,
                     'save_best': True},
                'networks': [
                    {'name': 'training',
                     'batch_size': args.batch_size,
                     'outputs': {'loss': loss_t},
                     'names': {'x': x, 'y': t, 'loss': loss_t}},
                    {'name': 'validation',
                     'batch_size': args.batch_size,
                     'outputs': {'loss': loss_v},
                     'names': {'x': x, 'y': t, 'loss': loss_v}}],
                'optimizers': [
                    {'name': 'optimizer',
                     'solver': solver,
                     'network': 'training',
                     'dataset': 'mnist_training',
                     'weight_decay': 0,
                     'lr_decay': 1,
                     'lr_decay_interval': 1,
                     'update_interval': 1}],
                'datasets': [
                    {'name': 'mnist_training',
                     'uri': 'MNIST_TRAINING',
                     'cache_dir': args.cache_dir + '/mnist_training.cache/',
                     'variables': {'x': x, 'y': t},
                     'shuffle': True,
                     'batch_size': args.batch_size,
                     'no_image_normalization': True},
                    {'name': 'mnist_validation',
                     'uri': 'MNIST_VALIDATION',
                     'cache_dir': args.cache_dir + '/mnist_test.cache/',
                     'variables': {'x': x, 'y': t},
                     'shuffle': False,
                     'batch_size': args.batch_size,
                     'no_image_normalization': True}],
                'monitors': [
                    {'name': 'training_loss',
                     'network': 'validation',
                     'dataset': 'mnist_training'},
                    {'name': 'validation_loss',
                     'network': 'validation',
                     'dataset': 'mnist_validation'}],
            }


    '''
    _, ext = os.path.splitext(filename)
    if ext == '.nntxt' or ext == '.prototxt':
        logger.info("Saving {} as prototxt".format(filename))
        proto = create_proto(contents, include_params, variable_batch_size)
        with open(filename, 'w') as file:
            text_format.PrintMessage(proto, file)
    elif ext == '.protobuf':
        logger.info("Saving {} as protobuf".format(filename))
        proto = create_proto(contents, include_params, variable_batch_size)
        with open(filename, 'wb') as file:
            file.write(proto.SerializeToString())
    elif ext == '.nnp':
        logger.info("Saving {} as nnp".format(filename))
        # TemporaryDirectory removes the directory even on error. Previously
        # a failing mkdtemp() inside ``try`` made ``finally`` raise NameError
        # and mask the real error.
        with tempfile.TemporaryDirectory() as tmpdir:
            nntxt_path = os.path.join(tmpdir, 'network.nntxt')
            version_path = os.path.join(tmpdir, 'nnp_version.txt')
            param_path = os.path.join(tmpdir, 'parameter.protobuf')

            # Parameters go into their own archive member, never the nntxt.
            save(nntxt_path,
                 contents,
                 include_params=False,
                 variable_batch_size=variable_batch_size)

            with open(version_path, 'w') as file:
                file.write('{}\n'.format(nnp_version()))

            save_parameters(param_path)

            with zipfile.ZipFile(filename, 'w') as nnp:
                nnp.write(version_path, 'nnp_version.txt')
                nnp.write(nntxt_path, 'network.nntxt')
                nnp.write(param_path, 'parameter.protobuf')
Exemplo n.º 7
0
def _save_parameters(args, suffix, epoch, train_config, force=False):
    """Save a throttled ``.nnp`` snapshot of parameters (and, periodically,
    optimizer states).

    A snapshot is written when ``force`` is set, or when the target file does
    not exist yet and either more than 180 seconds or more than 10 epochs
    have passed since the last snapshot for the same ``suffix``. Existing
    ``results_<suffix>_*.nnp`` files in ``args.outdir`` are removed before
    the new archive is written. Optimizer states are included only when
    ``train_config.optimizers`` is non-empty and ``epoch`` is a multiple of
    ``_OPTIMIZER_CHECKPOINT_INTERVAL``.

    Args:
        args: Object providing ``outdir``, the output directory.
        suffix: Snapshot kind. When ``callback.result_base`` returns ``None``,
            ``None`` and ``'best'`` save to ``results.nnp``, and any other
            value saves to ``results_<suffix>_<epoch>.nnp``.
        epoch (int): Current epoch.
        train_config: Training configuration providing ``optimizers``; also
            forwarded to ``save_optimizer_states``.
        force (bool): Save regardless of the time/epoch throttling.
    """
    global _save_parameter_info

    if suffix not in _save_parameter_info:
        _save_parameter_info[suffix] = {}
        _save_parameter_info[suffix]['epoch'] = 0
        _save_parameter_info[suffix]['time'] = 0

    current_time = time.time()
    timediff = current_time - _save_parameter_info[suffix]['time']
    epochdiff = epoch - _save_parameter_info[suffix]['epoch']

    # Escape the literal parts of the pattern. A '[', '*' or '?' in outdir or
    # suffix would otherwise act as a wildcard, and every match is deleted
    # below.
    globname = os.path.join(glob.escape(args.outdir),
                            glob.escape('results_{}_'.format(suffix)) +
                            '*.nnp')
    exists = glob.glob(globname)

    base = os.path.join(args.outdir, 'results_{}_{}'.format(suffix, epoch))
    base_candidate = callback.result_base(base, suffix, args.outdir)
    if base_candidate is None:
        if suffix is None or suffix == 'best':
            base = os.path.join(args.outdir, 'results')
    else:
        base = base_candidate

    filename = base + '.nnp'

    if force or (not os.path.exists(filename) and
                 (timediff > 180.0 or epochdiff > 10)):

        # Remove existing nnp before saving new file.
        # (With force=True the target itself may be among them.)
        for exist in exists:
            os.unlink(exist)

        version_filename = base + '_version.txt'

        with open(version_filename, 'w') as file:
            file.write('{}\n'.format(nnp_version()))

        param_filename = base + '_param.h5'
        save_parameters(param_filename)

        need_save_opti = (bool(train_config.optimizers) and
                          epoch % _OPTIMIZER_CHECKPOINT_INTERVAL == 0)
        # Always bound, so the archive/cleanup loops need no extra guard.
        opti_filenames = (save_optimizer_states(base, '.h5', train_config)
                          if need_save_opti else [])

        with zipfile.ZipFile(filename, 'w') as nnp:
            nnp.write(version_filename, 'nnp_version.txt')
            nnp.write(_save_parameter_info['config'],
                      os.path.basename(_save_parameter_info['config']))
            nnp.write(param_filename, 'parameter.h5')
            for f in opti_filenames:
                # Archive name is the file name without the "<base>_" prefix.
                nnp.write(f, f[len(base) + 1:])

        os.unlink(version_filename)
        os.unlink(param_filename)
        for f in opti_filenames:
            os.unlink(f)

        _save_parameter_info[suffix]['epoch'] = epoch
        _save_parameter_info[suffix]['time'] = current_time

        callback.save_train_snapshot()