Exemplo n.º 1
0
def test_networkpass_multiple_ops(ops):
    """Compare two NNP topologies node by node.

    The topology of ``ops.s_module`` (obtained with the extra argument
    ``ops.modifier(x)``) is walked in lockstep with the topology of
    ``ops.d_module``, which serves as the reference.  For every pair of
    nodes, the values yielded by ``forward_variable`` on each side must be
    allclose.  Pairing stops at the shorter of the two topologies (``zip``).
    """
    def fill_random(var):
        # A fresh generator seeded with 389 on every call, so every filled
        # variable receives the same deterministic sequence of values.
        var.d = np.random.RandomState(389).randn(*var.d.shape)

    x = nn.Variable((4, 3))
    fill_random(x)

    modified_topo = nnp_get_topo(
        prepare_model(ops.s_module), 'left', ops.modifier(x))
    expected_topo = nnp_get_topo(prepare_model(ops.d_module), 'right')

    for node, ref_node in zip(modified_topo, expected_topo):
        ref_in = list(ref_node.inputs.values())
        ref_out = list(ref_node.outputs.values())
        node_in = list(node.inputs.values())
        node_out = list(node.outputs.values())

        # Only the reference node's first input is re-filled here.
        fill_random(ref_in[0])

        ref_results = forward_variable(ref_in, ref_out, 'left', True)
        results = forward_variable(node_in, node_out, 'right', True)
        for ref_value, value in zip(ref_results, results):
            assert_allclose(value, ref_value)
Exemplo n.º 2
0
def verify_equivalence(ref_v, v):
    """Assert that ``v`` is equivalent to the reference ``ref_v``.

    The comparison strategy depends on the type of ``ref_v``:

    * ``legacy_nnp_graph.NnpNetwork``: the values yielded by
      ``forward_variable`` for both networks are compared pairwise with
      ``assert_allclose``, then ``assert_topology`` is applied to the
      output variables.
    * ``nn.Variable``: delegates to ``compare_nn_variable_metadata``.
    * ``tuple``: delegates to ``compare_nn_variable_with_name``.
    * anything else: compared with ``==``.

    Raises:
        AssertionError: if any comparison fails.
    """
    if isinstance(ref_v, legacy_nnp_graph.NnpNetwork):
        ref_inputs = list(ref_v.inputs.values())
        inputs = list(v.inputs.values())

        ref_outputs = list(ref_v.outputs.values())
        outputs = list(v.outputs.values())
        for ref_d, d in zip(forward_variable(ref_inputs, ref_outputs, 'left'),
                            forward_variable(inputs, outputs, 'right')):
            # assert_allclose(actual, desired) scales rtol by ``desired``, so
            # the reference goes second; this also matches the argument order
            # used by test_networkpass_multiple_ops.
            assert_allclose(d, ref_d)
        assert_topology(ref_outputs, outputs)

    elif isinstance(ref_v, nn.Variable):
        compare_nn_variable_metadata(ref_v, v)
    elif isinstance(ref_v, tuple):
        compare_nn_variable_with_name(ref_v, v)
    else:
        # Report both sides only when they differ, instead of printing on
        # every comparison.
        assert ref_v == v, "compare {} <==> {}".format(ref_v, v)
Exemplo n.º 3
0
def load_model_from_nnp_graph_and_forward(nnp_file, batch_size):
    """Load the first network of an NNP file and run it forward.

    Args:
        nnp_file: file passed to ``nnp_graph.NnpLoader``.
        batch_size: batch size used when building the network.

    Returns:
        The last value yielded by ``forward_variable`` for the network's
        inputs and outputs.  This is the same value the previous
        implementation returned via its leaked loop variable.

    Raises:
        ValueError: if ``forward_variable`` yields nothing.  Previously this
            case surfaced as an opaque ``UnboundLocalError``.
    """
    nnp = nnp_graph.NnpLoader(nnp_file)
    network = nnp.get_network(nnp.get_network_names()[0],
                              batch_size=batch_size)
    inputs = list(network.inputs.values())
    outputs = list(network.outputs.values())
    results = list(forward_variable(inputs, outputs, 'left'))
    if not results:
        raise ValueError(
            "forward_variable produced no output for {}".format(nnp_file))
    return results[-1]