Ejemplo n.º 1
0
def test_networkpass_on_generate_function(module):
    '''Check that the legacy and new nnp_graph loaders stay equivalent when
    an NnpNetworkPass modifies functions while the network is generated.

    Two hooks are registered on the shared pass:

    * functions named 'Convolution' get their convolution padding
      overwritten with [1, 1];
    * 'Affine' functions get a freshly sampled weight registered under the
      name of their second input, inside the network's parameter scope.

    The values yielded by ``nnp_check`` for each loader are compared pairwise
    with ``verify_equivalence``.  A differing number of yielded values is
    reported as a failure instead of being silently truncated by ``zip``.
    '''
    import itertools

    verbose = 1
    callback = nnp_graph.NnpNetworkPass(verbose)

    @callback.on_generate_function_by_name('Convolution')
    def change_convolution_param(f):
        # Log the original padding before overwriting it.
        print('{}'.format(f.proto.convolution_param.pad.dim[:]))
        f.proto.convolution_param.pad.dim[:] = [1, 1]
        return f

    @callback.on_function_pass_by_type('Affine')
    def change_affine_param(f, variables, param_scope):
        param_name = f.inputs[1].proto.name
        input_shape = f.inputs[0].proto.shape.dim[:]
        w_shape = f.inputs[1].proto.shape.dim[:]
        # Re-seeded on every call, so each invocation draws the same values
        # for a given weight shape.
        rng = np.random.RandomState(388)
        with nn.parameter_scope('', param_scope):
            W = nn.Variable.from_numpy_array(
                rng.randn(np.prod(input_shape[1:]), w_shape[1]))
            nn.parameter.set_parameter(param_name, W)
            W.need_grad = True

    # Sentinel that marks "this loader ran out of values first".
    missing = object()
    with get_saved_test_model(module) as nnp_file:
        ref_nnp = legacy_nnp_graph.NnpLoader(nnp_file)
        nnp = nnp_graph.NnpLoader(nnp_file)
        # zip_longest keeps the two generators in lockstep (like zip) but
        # exposes a length mismatch instead of dropping the extra values.
        pairs = itertools.zip_longest(nnp_check(ref_nnp, 'left', callback),
                                      nnp_check(nnp, 'right', callback),
                                      fillvalue=missing)
        for ref_v, v in pairs:
            assert ref_v is not missing and v is not missing, \
                'legacy and new nnp_graph yielded a different number of values'
            verify_equivalence(ref_v, v)
Ejemplo n.º 2
0
def test_networkpass_use_up_to(module):
    '''Check that the legacy and new nnp_graph loaders stay equivalent when
    an NnpNetworkPass is configured with ``use_up_to('Tanh_out_1')`` and
    ``use_up_to('Convolution_out_3')``.

    The values yielded by ``nnp_check`` for each loader are compared pairwise
    with ``verify_equivalence``.  A differing number of yielded values is
    reported as a failure instead of being silently truncated by ``zip``.
    '''
    import itertools

    verbose = 1
    callback = nnp_graph.NnpNetworkPass(verbose)
    # NOTE(review): confirm whether a second use_up_to call adds to or
    # replaces the first one inside NnpNetworkPass.
    callback.use_up_to('Tanh_out_1')
    callback.use_up_to('Convolution_out_3')

    # Sentinel that marks "this loader ran out of values first".
    missing = object()
    with get_saved_test_model(module) as nnp_file:
        ref_nnp = legacy_nnp_graph.NnpLoader(nnp_file)
        nnp = nnp_graph.NnpLoader(nnp_file)
        # zip_longest keeps the two generators in lockstep (like zip) but
        # exposes a length mismatch instead of dropping the extra values.
        pairs = itertools.zip_longest(nnp_check(ref_nnp, 'left', callback),
                                      nnp_check(nnp, 'right', callback),
                                      fillvalue=missing)
        for ref_v, v in pairs:
            assert ref_v is not missing and v is not missing, \
                'legacy and new nnp_graph yielded a different number of values'
            verify_equivalence(ref_v, v)
Ejemplo n.º 3
0
def test_networkpass_remove_and_rewire(module):
    '''Check that the legacy and new nnp_graph loaders stay equivalent when
    an NnpNetworkPass removes and rewires the functions 'affine1_1' and
    'c1-c1'.

    The values yielded by ``nnp_check`` for each loader are compared pairwise
    with ``verify_equivalence``.  A differing number of yielded values is
    reported as a failure instead of being silently truncated by ``zip``.
    '''
    import itertools

    verbose = 1
    callback = nnp_graph.NnpNetworkPass(verbose)
    callback.remove_and_rewire('affine1_1')
    callback.remove_and_rewire('c1-c1')

    # Sentinel that marks "this loader ran out of values first".
    missing = object()
    with get_saved_test_model(module) as nnp_file:
        ref_nnp = legacy_nnp_graph.NnpLoader(nnp_file)
        nnp = nnp_graph.NnpLoader(nnp_file)
        # zip_longest keeps the two generators in lockstep (like zip) but
        # exposes a length mismatch instead of dropping the extra values.
        pairs = itertools.zip_longest(nnp_check(ref_nnp, 'left', callback),
                                      nnp_check(nnp, 'right', callback),
                                      fillvalue=missing)
        for ref_v, v in pairs:
            assert ref_v is not missing and v is not missing, \
                'legacy and new nnp_graph yielded a different number of values'
            verify_equivalence(ref_v, v)
Ejemplo n.º 4
0
def test_networkpass_fix_parameter(module):
    '''Check that building networks with a ``fix_parameters()`` pass leaves
    the parameter scope empty.

    Only the new nnp_graph implementation is exercised.
    ``assert_parameter_scope_empty`` is checked once before the first network
    is built and again after every network.
    '''
    verbose = 1
    callback = nnp_graph.NnpNetworkPass(verbose)
    callback.fix_parameters()
    with get_saved_test_model(module) as nnp_file:
        nnp = nnp_graph.NnpLoader(nnp_file)
        assert_parameter_scope_empty()
        for network_name in sorted(nnp.get_network_names()):
            # Only the effect on the parameter scope matters here; the
            # returned network object itself is not inspected.
            nnp.get_network(network_name, batch_size=32, callback=callback)
            assert_parameter_scope_empty()
Ejemplo n.º 5
0
def test_networkpass_set_variable(module):
    '''Check that the legacy and new nnp_graph loaders stay equivalent when
    every network input is replaced with ``set_variable``.

    Each input listed by the ``module`` fixture is replaced by a fresh
    ``nn.Variable`` whose leading dimension is forced to 1.  The legacy loader
    gets its own legacy ``NnpNetworkPass`` configured the same way.

    The values yielded by ``nnp_check`` for each loader are compared pairwise
    with ``verify_equivalence``.  A differing number of yielded values is
    reported as a failure instead of being silently truncated by ``zip``.
    '''
    import itertools

    _, inputs = module
    verbose = 1
    callback = nnp_graph.NnpNetworkPass(verbose)
    ref_callback = legacy_nnp_graph.NnpNetworkPass(verbose)
    for inp_name, inp_shape in inputs:
        # Replace the leading (presumably batch) dimension with 1.
        inp_shape = (1, *inp_shape[1:])
        callback.set_variable(inp_name, nn.Variable(inp_shape))
        ref_callback.set_variable(inp_name, nn.Variable(inp_shape))

    # Sentinel that marks "this loader ran out of values first".
    missing = object()
    with get_saved_test_model(module) as nnp_file:
        ref_nnp = legacy_nnp_graph.NnpLoader(nnp_file)
        nnp = nnp_graph.NnpLoader(nnp_file)
        # zip_longest keeps the two generators in lockstep (like zip) but
        # exposes a length mismatch instead of dropping the extra values.
        pairs = itertools.zip_longest(nnp_check(ref_nnp, 'left', ref_callback),
                                      nnp_check(nnp, 'right', callback),
                                      fillvalue=missing)
        for ref_v, v in pairs:
            assert ref_v is not missing and v is not missing, \
                'legacy and new nnp_graph yielded a different number of values'
            verify_equivalence(ref_v, v)
Ejemplo n.º 6
0
def test_nnp_load_parameter_scope(module):
    '''Check that the legacy and new nnp_graph loaders end up with equal
    parameters when the same NnpNetworkPass is applied.

    The pass overwrites 'Convolution' padding with [1, 1] and registers a
    freshly sampled weight for every 'Affine' function.  Every network is
    first built with the legacy loader, and ``nn.get_parameters()`` is
    snapshotted after each one.  The parameters are then cleared and the
    same is done with the new loader.  The per-network snapshots are
    compared with ``assert_parameters_equal``.
    '''
    verbose = 1
    callback = nnp_graph.NnpNetworkPass(verbose)

    @callback.on_generate_function_by_name('Convolution')
    def change_convolution_param(f):
        # Log the original padding before overwriting it.
        print('{}'.format(f.proto.convolution_param.pad.dim[:]))
        f.proto.convolution_param.pad.dim[:] = [1, 1]
        return f

    @callback.on_function_pass_by_type('Affine')
    def change_affine_param(f, variables, param_scope):
        param_name = f.inputs[1].proto.name
        input_shape = f.inputs[0].proto.shape.dim[:]
        w_shape = f.inputs[1].proto.shape.dim[:]
        # Re-seeded on every call, so each invocation draws the same values
        # for a given weight shape.
        rng = np.random.RandomState(388)
        with nn.parameter_scope('', param_scope):
            W = nn.Variable.from_numpy_array(
                rng.randn(np.prod(input_shape[1:]), w_shape[1]))
            W.need_grad = True
            nn.parameter.set_parameter(param_name, W)

    ref_params = {}
    with get_saved_test_model(module) as nnp_file:
        nnp = legacy_nnp_graph.NnpLoader(nnp_file)
        for network_name in sorted(nnp.get_network_names()):
            nnp.get_network(network_name,
                            batch_size=32,
                            callback=callback)
            # NOTE(review): parameters are only cleared after this loop, so
            # if the legacy loader registers parameters globally, later
            # snapshots also contain earlier networks' parameters. Confirm
            # this is intended.
            ref_params[network_name] = nn.get_parameters().copy()
        nn.clear_parameters()

        params = {}
        nnp = nnp_graph.NnpLoader(nnp_file)
        assert_parameter_scope_empty()
        for network_name in sorted(nnp.get_network_names()):
            nnp.get_network(network_name,
                            batch_size=32,
                            callback=callback)
            params[network_name] = nn.get_parameters()

    assert_parameters_equal(ref_params, params)
Ejemplo n.º 7
0
 def modifier(x):
     """Build an NnpNetworkPass (verbosity 1) that calls
     ``remove_and_rewire('Affine')`` and binds the variable 'Affine_in' to *x*.

     Returns the configured pass so it can be handed to a loader as its
     callback.
     """
     network_pass = nnp_graph.NnpNetworkPass(1)
     network_pass.remove_and_rewire('Affine')
     network_pass.set_variable('Affine_in', x)
     return network_pass