Esempio n. 1
0
def check_nbla_infer(tmpdir, x, y, batch_size, on_memory):
    """Check that ``nbla infer`` reproduces the Python forward pass of a graph.

    The graph from input ``x`` to output ``y`` is saved to an ``.nnp`` file.
    The file is then reloaded with ``nnp_graph`` and run forward on random
    float32 input. The same input is written to a raw binary file and fed to
    the ``nbla infer`` executable. Both outputs must agree (``np.allclose``).
    The test is skipped when ``nbla`` is not found on the path.

    Args:
        tmpdir: directory object used for all intermediate files. Presumably
            the pytest ``tmpdir`` fixture; only ``ensure``, ``join`` and
            ``strpath`` are used.
        x: graph input variable, registered under the name ``'x'``.
        y: graph output variable, registered under the name ``'y'``.
        batch_size: batch size requested from ``NnpLoader.get_network`` and
            passed to ``nbla infer`` via ``-b``.
        on_memory: when true, ``-O`` is added to the ``nbla infer`` command
            line (the "on memory" mode suggested by the parameter name;
            confirm against ``nbla infer --help``).

    Raises:
        subprocess.CalledProcessError: if ``nbla infer`` exits non-zero.
        AssertionError: if the two outputs differ.
    """
    if not command_exists('nbla'):
        pytest.skip('An executable `nbla` is not in path.')

    # A. save a created graph to nnp.
    # The network is stored with batch_size 1; section B asks the loader
    # for `batch_size` instead.
    contents = {
        'networks': [{
            'name': 'graph',
            'batch_size': 1,
            'outputs': {
                'y': y
            },
            'names': {
                'x': x
            }
        }],
        'executors': [{
            'name': 'runtime',
            'network': 'graph',
            'data': ['x'],
            'output': ['y']
        }]
    }

    from nnabla.utils.save import save
    tmpdir.ensure(dir=True)
    tmppath = tmpdir.join('tmp.nnp')
    nnp_file = tmppath.strpath
    save(nnp_file, contents)

    # B. Get result with nnp_graph
    from nnabla.utils import nnp_graph
    nnp = nnp_graph.NnpLoader(nnp_file)
    graph = nnp.get_network('graph', batch_size=batch_size)
    x2 = graph.inputs['x']
    y2 = graph.outputs['y']
    x2.d = np.random.randn(*x2.shape).astype(np.float32)
    y2.forward()

    # C. Get result with nbla
    input_bin = tmpdir.join('tmp_in.bin')
    input_bin_file = input_bin.strpath
    x2.d.tofile(input_bin_file)

    output_bin = tmpdir.join('tmp_out')
    output_prefix = output_bin.strpath
    # Build the command once so both modes share the same arguments;
    # only the optional -O flag differs between them.
    cmd = ['nbla', 'infer']
    if on_memory:
        cmd.append('-O')
    cmd += ['-e', 'runtime', '-b', str(batch_size),
            '-o', output_prefix, nnp_file, input_bin_file]
    check_call(cmd)

    # D. Compare
    # Output index 0 is read from "<-o prefix>_0.bin" as raw float32.
    y3 = np.fromfile(output_prefix + '_0.bin',
                     dtype=np.float32).reshape(y2.shape)
    assert np.allclose(y2.d, y3)
Esempio n. 2
0
 def save_model(self):
     """Save the current Q-network description to an ``.nnp`` file.

     The file name is ``qnet_`` + ``self.update_count`` zero-padded to eight
     digits + ``.nnp``. It is resolved through
     ``self.save_path.get_filepath``. The network is registered as ``'qnet'``:
     input ``'s'`` is ``self.v.s``, output ``'q'`` is ``self.v.q``, and the
     batch size is the first dimension of ``self.v.s``. The whole save runs
     inside ``nn.parameter_scope(self.name_q)``.
     """
     from nnabla.utils.save import save
     with nn.parameter_scope(self.name_q):
         # Resolve the destination first, then describe the graph to store.
         target = self.save_path.get_filepath(
             'qnet_{:08d}.nnp'.format(self.update_count))
         qnet_spec = {
             'name': 'qnet',
             'batch_size': self.v.s.shape[0],
             'outputs': {'q': self.v.q},
             'names': {'s': self.v.s},
         }
         save(target, {'networks': [qnet_spec]})