Exemplo n.º 1
0
def test_parameter_file_load_save():
    """Round-trip parameters through an .nnp file saved from a graph_def.

    A graph is traced from ``TSTNetNormal`` using proto-variable inputs and
    saved to a temporary file. A freshly constructed ``TSTNetNormal`` must
    first produce outputs that differ from the graph's, and then matching
    outputs once the saved parameters have been loaded into its own
    parameter scope.
    """
    creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                             (4, 3, 32, 32)])
    traced = creator.module(*creator.get_proto_variable_inputs())
    graph = nn.graph_def.get_default_graph_by_variable(traced)

    with create_temp_with_dir(nnp_file) as saved_path:
        graph.save(saved_path)
        fresh = TSTNetNormal()
        inputs = creator.get_variable_inputs()
        graph_outputs = graph(*inputs)
        fresh_outputs = fresh(*inputs)

        # Before loading, the fresh module is expected to disagree.
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(graph_outputs, fresh_outputs)

        # Load the saved parameters into the fresh module's own scope,
        # then refresh the module via update_parameter().
        with nn.parameter_scope('', fresh.parameter_scope):
            nn.load_parameters(saved_path)

        fresh.update_parameter()

        fresh_outputs = fresh(*inputs)
        forward_variable_and_check_equal(graph_outputs, fresh_outputs)
Exemplo n.º 2
0
def test_parameter_file_load_save_for_file_object(memory_buffer_format):
    """Round-trip module parameters through an in-memory ``io.BytesIO``.

    ``memory_buffer_format`` is forwarded as the ``extension`` argument of
    both ``nn.save_parameters`` and ``nn.load_parameters``.
    """
    creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                             (4, 3, 32, 32)])
    inputs = creator.get_variable_inputs()
    source = creator.module
    expected = source(*inputs)
    target = TSTNetNormal()
    actual = target(*inputs)

    # Before loading, the two modules are expected to disagree.
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(expected, actual)

    with io.BytesIO() as buf:
        nn.save_parameters(buf,
                           source.get_parameters(),
                           extension=memory_buffer_format)
        # Read the buffer back into the target module's parameter scope.
        with nn.parameter_scope('', target.parameter_scope):
            nn.load_parameters(buf, extension=memory_buffer_format)
        target.update_parameter()

    actual = target(*inputs)

    # After loading, the outputs must match.
    forward_variable_and_check_equal(expected, actual)
Exemplo n.º 3
0
def test_module_load_save_parameter_file_io(extension, file_format):
    """Save one module's parameters and load them into a second module.

    Args:
        extension: Extension forwarded to ``save_parameters`` /
            ``load_parameters``; also used to name the temporary file.
        file_format: How the parameter file is handed to the module:
            ``'file_io'`` -- an opened binary file object,
            ``'byte_io'`` -- an in-memory ``io.BytesIO`` buffer,
            ``'str'`` -- a file path string.

    Raises:
        ValueError: if ``file_format`` is not one of the values above. The
            test then fails loudly, instead of silently skipping the
            save/load and failing later on the output comparison.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    variable_inputs = module_creator.get_variable_inputs()
    a_module = module_creator.module
    outputs = a_module(*variable_inputs)
    another = TSTNetNormal()
    ref_outputs = another(*variable_inputs)

    # Should not equal
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(outputs, ref_outputs)

    if file_format == 'file_io':
        with create_temp_with_dir("tmp{}".format(extension)) as param_file:
            with open(param_file, "wb") as f:
                a_module.save_parameters(f, extension=extension)
            with open(param_file, "rb") as f:
                another.load_parameters(f, extension=extension)
    elif file_format == 'byte_io':
        with io.BytesIO() as param_file:
            a_module.save_parameters(param_file, extension=extension)
            another.load_parameters(param_file, extension=extension)
    elif file_format == 'str':
        with create_temp_with_dir("tmp{}".format(extension)) as param_file:
            a_module.save_parameters(param_file, extension=extension)
            another.load_parameters(param_file, extension=extension)
    else:
        # Guard against a mis-parametrized case: without this branch,
        # nothing would be loaded and the failure below would be misleading.
        raise ValueError(
            "Unsupported file_format: {!r}".format(file_format))

    ref_outputs = another(*variable_inputs)

    # should equal
    forward_variable_and_check_equal(outputs, ref_outputs)
Exemplo n.º 4
0
def test_parameter_file_load_save_for_files(parameter_file):
    """Round-trip module parameters through a temporary file on disk.

    ``parameter_file`` names the temporary file. No ``extension`` argument
    is passed, so presumably the file's own extension selects the format
    (confirm against ``nn.save_parameters``).
    """
    creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                             (4, 3, 32, 32)])
    inputs = creator.get_variable_inputs()
    source = creator.module
    expected = source(*inputs)
    target = TSTNetNormal()
    actual = target(*inputs)

    # Before loading, the freshly built target is expected to disagree.
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(expected, actual)

    with create_temp_with_dir(parameter_file) as path:
        # Write the source parameters, then read them into the target's scope.
        nn.save_parameters(path, source.get_parameters())
        with nn.parameter_scope('', target.parameter_scope):
            nn.load_parameters(path)
    target.update_parameter()

    actual = target(*inputs)

    # After loading, the outputs must match.
    forward_variable_and_check_equal(expected, actual)
Exemplo n.º 5
0
def test_multiple_network():
    """Trace several modules into one ProtoGraph without naming them.

    Each module is called with proto-variable inputs inside a single
    ``nn.graph_def.graph()`` context, and ``graph_name()`` is never used.
    The entries of ``g.networks`` are then paired, in iteration order, with
    the module creators. Each network is run on fresh random inputs and
    compared with the module run directly on those inputs.

    NOTE(review): this pairs each module with its own ``g.networks`` entry.
    Whether unnamed networks are really kept separate, rather than merged
    into one network fed with concatenated inputs, is decided by
    ``nn.graph_def``; confirm. ``zip`` stops at the shorter sequence, so a
    count mismatch would go unreported.
    """
    creators = [
        ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
        ModuleCreator(ResUnit(16), [(4, 3, 32, 32)]),
        ModuleCreator(NestedTestNet(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
    ]

    # Build the graph_def by feeding proto variables to every module.
    with nn.graph_def.graph() as g:
        for creator in creators:
            module = creator.module
            proto_inputs = [nn.ProtoVariable(s) for s in creator.input_shape]
            traced = module(*proto_inputs)

    for creator, network in zip(creators, g.networks.values()):
        # Random-valued inputs shared by the network and the reference.
        inputs = creator.get_variable_inputs()
        got = network(*inputs)
        expected = creator.module(*inputs)
        forward_variable_and_check_equal(got, expected)
Exemplo n.º 6
0
def test_get_default_graph_def_by_name():
    """Fetch named networks back from the default graph.

    Each module is traced under its own ``nn.graph_def.graph_name(name)``
    context, with no explicit ``graph()`` context. It is then retrieved with
    ``nn.graph_def.get_default_graph(name)`` and compared against the module
    run directly on the same random inputs.
    """
    creators = [
        ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
        ModuleCreator(ResUnit(16), [(4, 3, 32, 32)]),
        ModuleCreator(NestedTestNet(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
    ]
    names = ['network1', 'network2', 'network3']

    # Trace every module under its own network name.
    for creator, name in zip(creators, names):
        module = creator.module
        proto_inputs = [nn.ProtoVariable(s) for s in creator.input_shape]
        with nn.graph_def.graph_name(name):
            traced = module(*proto_inputs)

    # Compare each retrieved network with its source module.
    for creator, name in zip(creators, names):
        inputs = creator.get_variable_inputs()
        graph = nn.graph_def.get_default_graph(name)
        got = graph(*inputs)
        expected = creator.module(*inputs)
        forward_variable_and_check_equal(got, expected)
Exemplo n.º 7
0
def test_get_graph_def_by_name():
    """Index a ProtoGraph by user-chosen network names.

    All modules are traced inside one ``nn.graph_def.graph()`` context, each
    under its own ``graph_name(name)``. Each network is then looked up as
    ``g[name]`` and compared with its source module on random inputs.
    """
    creators = [
        ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
        ModuleCreator(ResUnit(16), [(4, 3, 32, 32)]),
        ModuleCreator(NestedTestNet(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
    ]
    names = ['network1', 'network2', 'network3']
    pairs = list(zip(creators, names))

    # Build one graph_def holding every module as a separately named network.
    with nn.graph_def.graph() as g:
        for creator, name in pairs:
            module = creator.module
            proto_inputs = [nn.ProtoVariable(s) for s in creator.input_shape]
            with nn.graph_def.graph_name(name):
                traced = module(*proto_inputs)

    for creator, name in pairs:
        inputs = creator.get_variable_inputs()
        got = g[name](*inputs)
        expected = creator.module(*inputs)
        forward_variable_and_check_equal(got, expected)
Exemplo n.º 8
0
def test_parameter_file_load_save_using_global():
    """Save a traced graph to .nnp and reload it through the global scope.

    The saved parameters are loaded with ``nn.load_parameters`` into the
    global parameter scope and copied into a fresh ``TSTNetNormal`` with
    ``set_parameters``. Its outputs must then match the graph's.

    The file is written through ``create_temp_with_dir`` rather than to the
    bare ``nnp_file`` path, so no artifact is left in the working directory
    and concurrent test runs cannot collide on the same file.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    proto_variable_inputs = module_creator.get_proto_variable_inputs()
    outputs = module_creator.module(*proto_variable_inputs)
    g = nn.graph_def.get_default_graph_by_variable(outputs)

    with create_temp_with_dir(nnp_file) as tmp_file:
        g.save(tmp_file)
        another = TSTNetNormal()
        variable_inputs = module_creator.get_variable_inputs()
        outputs = g(*variable_inputs)
        ref_outputs = another(*variable_inputs)

        # Should not equal
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(outputs, ref_outputs)

        # load to global scope (must happen while the temp file still exists)
        nn.load_parameters(tmp_file)

    params = nn.get_parameters()
    another.set_parameters(params)

    ref_outputs = another(*variable_inputs)
    forward_variable_and_check_equal(outputs, ref_outputs)
Exemplo n.º 9
0
    # Create another graph with the same architecture
    e2 = Example()
    assert not e2.get_parameters(), "It doesn't have any parameters so far."

    # Setting parameters from an existing model)
    e2.set_parameters(e.get_parameters())
    assert e.get_parameters() == e2.get_parameters(
    ), "They have the same parameters."

    assert '@cb/conv/W' in e.get_parameters()
    assert '@cb/conv/W' in e2.get_parameters()


@pytest.mark.parametrize(
    "module_creator",
    [ModuleCreator(TSTNetAbnormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])])
def test_unsupported_module_definition(module_creator):
    # Since we use global graphdef, we should reset it beforehand
    # This is already done in test fixture.
    # nn.graph_def.reset_default_graph()

    # get module from test parameters
    module = module_creator.module

    # create variable inputs and initialized by random value
    variable_inputs = module_creator.get_variable_inputs()

    # create reference network by passing in variable inputs
    ref_outputs = module(*variable_inputs)

    # create proto variable inputs
Exemplo n.º 10
0
    e = Example()
    h = e(x)

    # Create another graph with the same architecture
    e2 = Example()
    assert not e2.get_parameters(), "It doesn't have any parameters so far."

    # Setting parameters from an existing model)
    e2.set_parameters(e.get_parameters())
    assert e.get_parameters() == e2.get_parameters(), "They have the same parameters."

    assert '@cb/conv/W' in e.get_parameters()
    assert '@cb/conv/W' in e2.get_parameters()


@pytest.mark.parametrize("module_creator", [ModuleCreator(TSTNetAbnormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])])
def test_unsupported_module_definition(module_creator):
    # Since we use global graphdef, we should reset it beforehand
    # This is already done in test fixture.
    # nn.graph_def.reset_default_graph()

    # get module from test parameters
    module = module_creator.module

    # create variable inputs and initialized by random value
    variable_inputs = module_creator.get_variable_inputs()

    # create reference network by passing in variable inputs
    ref_outputs = module(*variable_inputs)

    # create proto variable inputs