def test_parameter_file_load_save():
    """Parameters saved with a graph can be loaded into a fresh module's scope.

    A graph is built from ``TSTNetNormal`` using proto-variable inputs and
    saved to a temporary file. A newly constructed ``TSTNetNormal`` must not
    reproduce the graph's outputs at first. Once the saved parameters are
    loaded under that module's own parameter scope and ``update_parameter``
    is called, the two sets of outputs must be equal.
    """
    creator = ModuleCreator(
        TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    proto_inputs = creator.get_proto_variable_inputs()
    proto_outputs = creator.module(*proto_inputs)
    graph = nn.graph_def.get_default_graph_by_variable(proto_outputs)

    with create_temp_with_dir(nnp_file) as saved_path:
        graph.save(saved_path)

        fresh_net = TSTNetNormal()
        inputs = creator.get_variable_inputs()
        graph_outputs = graph(*inputs)
        fresh_outputs = fresh_net(*inputs)

        # Before loading, the fresh module must disagree with the graph.
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(graph_outputs, fresh_outputs)

        # Populate the fresh module's local scope from the saved file.
        with nn.parameter_scope('', fresh_net.parameter_scope):
            nn.load_parameters(saved_path)
        fresh_net.update_parameter()

        fresh_outputs = fresh_net(*inputs)
        forward_variable_and_check_equal(graph_outputs, fresh_outputs)
def test_module_load_save_parameter_file_io(extension, file_format):
    """Module.save_parameters / load_parameters round-trip via several targets.

    Args:
        extension: File extension passed to ``save_parameters`` and
            ``load_parameters`` and used in the temporary file name.
        file_format: How the parameters are transported:
            ``'file_io'``: a binary file object opened on a temporary path;
            ``'byte_io'``: an in-memory ``io.BytesIO`` buffer;
            ``'str'``: the temporary path string itself.

    Before the round-trip, a freshly constructed ``TSTNetNormal`` must
    produce different outputs. Afterwards its outputs must match.
    """
    module_creator = ModuleCreator(
        TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    variable_inputs = module_creator.get_variable_inputs()
    a_module = module_creator.module
    outputs = a_module(*variable_inputs)
    another = TSTNetNormal()
    ref_outputs = another(*variable_inputs)

    # Should not equal
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(outputs, ref_outputs)

    if file_format == 'file_io':
        with create_temp_with_dir("tmp{}".format(extension)) as param_file:
            with open(param_file, "wb") as f:
                a_module.save_parameters(f, extension=extension)
            with open(param_file, "rb") as f:
                another.load_parameters(f, extension=extension)
    elif file_format == 'byte_io':
        with io.BytesIO() as param_file:
            a_module.save_parameters(param_file, extension=extension)
            # After writing, the buffer position is at EOF. Rewind explicitly
            # so the load reads the saved bytes instead of relying on
            # load_parameters to seek for us.
            param_file.seek(0)
            another.load_parameters(param_file, extension=extension)
    elif file_format == 'str':
        with create_temp_with_dir("tmp{}".format(extension)) as param_file:
            a_module.save_parameters(param_file, extension=extension)
            another.load_parameters(param_file, extension=extension)
    else:
        # Without this, an unknown format would skip loading and fail later
        # with a misleading "outputs differ" assertion.
        pytest.fail("unsupported file_format: {}".format(file_format))

    ref_outputs = another(*variable_inputs)
    # should equal
    forward_variable_and_check_equal(outputs, ref_outputs)
def test_parameter_file_load_save_for_files(parameter_file):
    """nn.save_parameters / nn.load_parameters round-trip through a file path.

    Args:
        parameter_file: File name (with extension) used to create the
            temporary path the parameters are written to and read from.

    The source module's parameters are saved with ``nn.save_parameters``. A
    separately constructed ``TSTNetNormal`` then loads them under its own
    parameter scope and must reproduce the source module's outputs.
    """
    creator = ModuleCreator(
        TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    inputs = creator.get_variable_inputs()
    source_net = creator.module
    source_outputs = source_net(*inputs)

    target_net = TSTNetNormal()
    target_outputs = target_net(*inputs)

    # The two modules must start out producing different results.
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(source_outputs, target_outputs)

    with create_temp_with_dir(parameter_file) as path:
        nn.save_parameters(path, source_net.get_parameters())

        # Read the file back into the target module's local scope.
        with nn.parameter_scope('', target_net.parameter_scope):
            nn.load_parameters(path)
        target_net.update_parameter()

        target_outputs = target_net(*inputs)
        forward_variable_and_check_equal(source_outputs, target_outputs)
def test_parameter_file_load_save_for_file_object(memory_buffer_format):
    """nn.save_parameters / nn.load_parameters round-trip through io.BytesIO.

    Args:
        memory_buffer_format: Extension passed as ``extension=`` to both
            ``nn.save_parameters`` and ``nn.load_parameters``.

    A freshly constructed ``TSTNetNormal`` must first differ from the source
    module. After loading the in-memory parameters under its own scope, it
    must produce the same outputs.
    """
    module_creator = ModuleCreator(
        TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    variable_inputs = module_creator.get_variable_inputs()
    a_module = module_creator.module
    outputs = a_module(*variable_inputs)
    another = TSTNetNormal()
    ref_outputs = another(*variable_inputs)
    extension = memory_buffer_format

    # Should not equal
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(outputs, ref_outputs)

    with io.BytesIO() as param_file:
        nn.save_parameters(param_file, a_module.get_parameters(),
                           extension=extension)
        # The buffer position is at EOF after saving. Rewind explicitly so
        # loading starts from the first saved byte.
        param_file.seek(0)
        # load from file
        with nn.parameter_scope('', another.parameter_scope):
            nn.load_parameters(param_file, extension=extension)
        another.update_parameter()

    ref_outputs = another(*variable_inputs)
    # should equal
    forward_variable_and_check_equal(outputs, ref_outputs)
def test_parameter_file_load_save_using_global():
    """Parameters loaded into the global scope can be pushed into a module.

    A graph built from ``TSTNetNormal`` is saved to a temporary file. The file
    is then loaded with ``nn.load_parameters``, which is not wrapped in any
    ``nn.parameter_scope`` here, i.e. it targets the global scope. The result
    of ``nn.get_parameters()`` is applied to a fresh ``TSTNetNormal`` via
    ``set_parameters``, and its outputs must then match the graph's.
    """
    module_creator = ModuleCreator(
        TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    proto_variable_inputs = module_creator.get_proto_variable_inputs()
    proto_outputs = module_creator.module(*proto_variable_inputs)
    g = nn.graph_def.get_default_graph_by_variable(proto_outputs)

    # Route the saved file through a temporary directory instead of writing
    # to the bare ``nnp_file`` path, so no artifact is left behind and
    # concurrent tests cannot collide on it.
    with create_temp_with_dir(nnp_file) as tmp_file:
        g.save(tmp_file)

        another = TSTNetNormal()
        variable_inputs = module_creator.get_variable_inputs()
        outputs = g(*variable_inputs)
        ref_outputs = another(*variable_inputs)

        # Should not equal
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(outputs, ref_outputs)

        # load to global scope
        # NOTE(review): these parameters remain registered globally after the
        # test returns; confirm no other test depends on an empty global scope.
        nn.load_parameters(tmp_file)

    params = nn.get_parameters()
    another.set_parameters(params)
    ref_outputs = another(*variable_inputs)
    forward_variable_and_check_equal(outputs, ref_outputs)