def check_nbla_infer(tmpdir, x, y, batch_size, on_memory):
    """Check that ``nbla infer`` reproduces the Python-side forward pass.

    The graph ``x -> y`` is saved to an NNP file. It is evaluated in Python
    via ``nnp_graph.NnpLoader`` on random input. The ``nbla`` executable then
    evaluates it on the same input, and the two outputs must be close.

    Args:
        tmpdir: Scratch directory object (uses ``ensure``/``join``/``strpath``;
            presumably pytest's ``tmpdir`` fixture -- confirm at call sites).
        x: Graph input variable, registered in the NNP under the name ``'x'``.
        y: Graph output variable, registered in the NNP under the name ``'y'``.
        batch_size: Batch size used both when re-building the network from
            the NNP file and for ``nbla infer -b``.
        on_memory: When true, ``-O`` is added to the ``nbla infer`` command.

    The test is skipped when no ``nbla`` executable is available on the path.
    """
    if not command_exists('nbla'):
        pytest.skip('An executable `nbla` is not in path.')

    # A. Save the created graph to an NNP file. The network is stored with
    # batch_size 1; the batch size under test is applied on re-load (B) and
    # via ``-b`` on the command line (C).
    contents = {
        'networks': [{
            'name': 'graph',
            'batch_size': 1,
            'outputs': {'y': y},
            'names': {'x': x},
        }],
        'executors': [{
            'name': 'runtime',
            'network': 'graph',
            'data': ['x'],
            'output': ['y'],
        }],
    }
    from nnabla.utils.save import save
    tmpdir.ensure(dir=True)
    tmppath = tmpdir.join('tmp.nnp')
    nnp_file = tmppath.strpath
    save(nnp_file, contents)

    # B. Get the reference result with nnp_graph.
    from nnabla.utils import nnp_graph
    nnp = nnp_graph.NnpLoader(nnp_file)
    graph = nnp.get_network('graph', batch_size=batch_size)
    x2 = graph.inputs['x']
    y2 = graph.outputs['y']
    x2.d = np.random.randn(*x2.shape).astype(np.float32)
    y2.forward()

    # C. Get the result with the nbla CLI, feeding it the same input as a
    # raw binary dump.
    input_bin = tmpdir.join('tmp_in.bin')
    input_bin_file = input_bin.strpath
    x2.d.tofile(input_bin_file)
    output_bin = tmpdir.join('tmp_out')
    # Build the command once; the only difference between the two modes
    # is the optional '-O' flag placed right after the subcommand.
    cmd = ['nbla', 'infer']
    if on_memory:
        cmd.append('-O')
    cmd += [
        '-e', 'runtime',
        '-b', str(batch_size),
        '-o', output_bin.strpath,
        nnp_file,
        input_bin_file,
    ]
    check_call(cmd)

    # D. Compare. The CLI output is read back from '<output prefix>_0.bin'
    # as float32 and reshaped to the Python-side output shape.
    y3 = np.fromfile(output_bin.strpath + '_0.bin',
                     dtype=np.float32).reshape(y2.shape)
    assert np.allclose(y2.d, y3)
def save_model(self):
    """Write the ``qnet`` network to an ``.nnp`` file.

    The file name is ``qnet_<update_count>.nnp`` with the counter zero-padded
    to eight digits. It is resolved through ``self.save_path.get_filepath``.
    The network's batch size is taken from the first dimension of
    ``self.v.s``. The network is registered with input ``'s'`` (``self.v.s``)
    and output ``'q'`` (``self.v.q``).
    """
    from nnabla.utils.save import save

    # Everything runs inside the Q-network's parameter scope; presumably so
    # that ``save`` picks up the parameters under ``self.name_q`` -- confirm
    # against nnabla's ``save`` implementation.
    with nn.parameter_scope(self.name_q):
        nnp_path = self.save_path.get_filepath(
            'qnet_{:08d}.nnp'.format(self.update_count))
        qnet_spec = {
            'name': 'qnet',
            'batch_size': self.v.s.shape[0],
            'outputs': {'q': self.v.q},
            'names': {'s': self.v.s},
        }
        save(nnp_path, {'networks': [qnet_spec]})