コード例 #1
0
def test_parameter_file_load_save():
    """Round-trip parameters through a saved graph file into a module scope.

    A graph is obtained from a ``TSTNetNormal`` module called with
    proto-variable inputs and saved to the path yielded by
    ``create_temp_with_dir``.  A second ``TSTNetNormal`` instance is expected
    to disagree with the graph's outputs at first; once the saved parameters
    are loaded into that instance's parameter scope, both must agree.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    proto_variable_inputs = module_creator.get_proto_variable_inputs()
    outputs = module_creator.module(*proto_variable_inputs)
    g = nn.graph_def.get_default_graph_by_variable(outputs)

    with create_temp_with_dir(nnp_file) as tmp_file:
        g.save(tmp_file)
        another = TSTNetNormal()
        variable_inputs = module_creator.get_variable_inputs()
        outputs = g(*variable_inputs)
        ref_outputs = another(*variable_inputs)

        # Before any parameters are loaded the equality check itself must
        # fail; the raised exception is not inspected further.
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(outputs, ref_outputs)

        # Load the saved parameters into `another`'s own (local) scope
        # instead of the global one.
        with nn.parameter_scope('', another.parameter_scope):
            nn.load_parameters(tmp_file)

        another.update_parameter()

        # Rebuild the reference outputs with the loaded parameters; now the
        # two networks must produce equal results.
        ref_outputs = another(*variable_inputs)
        forward_variable_and_check_equal(outputs, ref_outputs)
コード例 #2
0
def test_module_load_save_parameter_file_io(extension, file_format):
    """Round-trip module parameters via ``save_parameters``/``load_parameters``.

    Args:
        extension: Parameter file extension, passed through to
            ``save_parameters``/``load_parameters`` and used as the suffix of
            the temporary file name.
        file_format: How the destination is handed to the module API:
            ``'file_io'`` (opened binary file objects), ``'byte_io'`` (an
            in-memory ``io.BytesIO``) or ``'str'`` (a path string).

    Raises:
        ValueError: If ``file_format`` is not one of the supported values.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    variable_inputs = module_creator.get_variable_inputs()
    a_module = module_creator.module
    outputs = a_module(*variable_inputs)
    another = TSTNetNormal()
    ref_outputs = another(*variable_inputs)

    # Before any parameters are transferred the equality check must fail.
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(outputs, ref_outputs)

    tmp_name = "tmp{}".format(extension)
    if file_format == 'file_io':
        with create_temp_with_dir(tmp_name) as param_file:
            with open(param_file, "wb") as f:
                a_module.save_parameters(f, extension=extension)
            with open(param_file, "rb") as f:
                another.load_parameters(f, extension=extension)
    elif file_format == 'byte_io':
        with io.BytesIO() as param_file:
            a_module.save_parameters(param_file, extension=extension)
            another.load_parameters(param_file, extension=extension)
    elif file_format == 'str':
        with create_temp_with_dir(tmp_name) as param_file:
            a_module.save_parameters(param_file, extension=extension)
            another.load_parameters(param_file, extension=extension)
    else:
        # Without this branch an unknown format would skip the save/load
        # step and only surface later as a confusing equality failure.
        raise ValueError(
            "unsupported file_format: {!r}".format(file_format))

    ref_outputs = another(*variable_inputs)

    # After loading, both modules must produce equal outputs.
    forward_variable_and_check_equal(outputs, ref_outputs)
コード例 #3
0
def test_parameter_file_load_save_for_files(parameter_file):
    """Round-trip parameters through ``nn.save_parameters``/``nn.load_parameters``.

    Args:
        parameter_file: File name handed to ``create_temp_with_dir``; the
            path it yields is used for both saving and loading.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    variable_inputs = module_creator.get_variable_inputs()
    a_module = module_creator.module
    outputs = a_module(*variable_inputs)
    another = TSTNetNormal()
    ref_outputs = another(*variable_inputs)

    # Before any parameters are transferred the equality check must fail.
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(outputs, ref_outputs)

    with create_temp_with_dir(parameter_file) as tmp_file:
        # Save the first module's parameters to the file.
        nn.save_parameters(tmp_file, a_module.get_parameters())

        # Load them back into `another`'s own parameter scope.
        with nn.parameter_scope('', another.parameter_scope):
            nn.load_parameters(tmp_file)
    another.update_parameter()

    ref_outputs = another(*variable_inputs)

    # After loading, both modules must produce equal outputs.
    forward_variable_and_check_equal(outputs, ref_outputs)
コード例 #4
0
def test_parameter_file_load_save_for_file_object(memory_buffer_format):
    """Round-trip parameters through an in-memory ``io.BytesIO`` buffer.

    Args:
        memory_buffer_format: Value passed as ``extension`` to both
            ``nn.save_parameters`` and ``nn.load_parameters``.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    variable_inputs = module_creator.get_variable_inputs()
    a_module = module_creator.module
    outputs = a_module(*variable_inputs)
    another = TSTNetNormal()
    ref_outputs = another(*variable_inputs)

    # Before any parameters are transferred the equality check must fail.
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(outputs, ref_outputs)

    with io.BytesIO() as param_file:
        # Save the first module's parameters into the buffer.
        nn.save_parameters(param_file,
                           a_module.get_parameters(),
                           extension=memory_buffer_format)
        # Load them back from the same buffer into `another`'s own scope.
        with nn.parameter_scope('', another.parameter_scope):
            nn.load_parameters(param_file, extension=memory_buffer_format)
        another.update_parameter()

    ref_outputs = another(*variable_inputs)

    # After loading, both modules must produce equal outputs.
    forward_variable_and_check_equal(outputs, ref_outputs)
コード例 #5
0
def test_parameter_file_load_save_using_global():
    """Round-trip parameters through a saved graph file via the global scope.

    A graph is obtained from a ``TSTNetNormal`` module called with
    proto-variable inputs and saved to a file.  The file is then loaded with
    ``nn.load_parameters`` outside any ``nn.parameter_scope`` block, and the
    result of ``nn.get_parameters()`` is handed to a second ``TSTNetNormal``
    through ``set_parameters``; afterwards both must produce equal outputs.
    """
    module_creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32),
                                                    (4, 3, 32, 32)])
    proto_variable_inputs = module_creator.get_proto_variable_inputs()
    outputs = module_creator.module(*proto_variable_inputs)
    g = nn.graph_def.get_default_graph_by_variable(outputs)

    # Save under the path yielded by create_temp_with_dir, as
    # test_parameter_file_load_save does, instead of writing `nnp_file`
    # directly and leaving it behind after the test.
    with create_temp_with_dir(nnp_file) as tmp_file:
        g.save(tmp_file)
        another = TSTNetNormal()
        variable_inputs = module_creator.get_variable_inputs()
        outputs = g(*variable_inputs)
        ref_outputs = another(*variable_inputs)

        # Before any parameters are loaded the equality check must fail.
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(outputs, ref_outputs)

        # Load to the global scope, then copy those parameters into
        # `another`.
        nn.load_parameters(tmp_file)
        params = nn.get_parameters()
        another.set_parameters(params)

        ref_outputs = another(*variable_inputs)
        forward_variable_and_check_equal(outputs, ref_outputs)