def test_parameter_file_load_save():
    """Round-trip graph parameters through an .nnp file into a module's local scope.

    The graph built from ``TSTNetNormal`` is saved to a temporary ``.nnp`` file.
    A freshly constructed ``TSTNetNormal`` must first produce different outputs.
    After the saved parameters are loaded into that module's own parameter
    scope, its outputs must match the graph's outputs.
    """
    module_creator = ModuleCreator(
        TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    proto_variable_inputs = module_creator.get_proto_variable_inputs()
    proto_outputs = module_creator.module(*proto_variable_inputs)
    g = nn.graph_def.get_default_graph_by_variable(proto_outputs)

    # All work that reads tmp_file stays inside the context so the file
    # still exists when nn.load_parameters() reads it back.
    with create_temp_with_dir(nnp_file) as tmp_file:
        g.save(tmp_file)
        another = TSTNetNormal()
        variable_inputs = module_creator.get_variable_inputs()
        outputs = g(*variable_inputs)
        ref_outputs = another(*variable_inputs)

        # Should not equal: `another` has its own, independently created parameters.
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(outputs, ref_outputs)

        # Load to local scope: the root ('') of `another`'s parameter scope.
        with nn.parameter_scope('', another.parameter_scope):
            nn.load_parameters(tmp_file)
        another.update_parameter()

        ref_outputs = another(*variable_inputs)
        forward_variable_and_check_equal(outputs, ref_outputs)
def test_parameter_file_load_save_using_global():
    """Round-trip graph parameters through an .nnp file via the global parameter scope.

    The graph built from ``TSTNetNormal`` is saved to a temporary ``.nnp`` file.
    A freshly constructed ``TSTNetNormal`` must first produce different outputs.
    The saved parameters are then loaded into nnabla's global scope, fetched
    with ``nn.get_parameters()`` and pushed into the module with
    ``set_parameters``; afterwards its outputs must match the graph's outputs.
    """
    module_creator = ModuleCreator(
        TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    proto_variable_inputs = module_creator.get_proto_variable_inputs()
    proto_outputs = module_creator.module(*proto_variable_inputs)
    g = nn.graph_def.get_default_graph_by_variable(proto_outputs)

    # Save into a temporary directory (as the local-scope variant does)
    # instead of a bare relative path, so no .nnp file is left in the CWD.
    with create_temp_with_dir(nnp_file) as tmp_file:
        g.save(tmp_file)
        another = TSTNetNormal()
        variable_inputs = module_creator.get_variable_inputs()
        outputs = g(*variable_inputs)
        ref_outputs = another(*variable_inputs)

        # Should not equal: `another` has its own, independently created parameters.
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(outputs, ref_outputs)

        # Load to global scope.
        # NOTE(review): this mutates nnabla's process-wide parameter scope and
        # is not cleaned up here; confirm other tests do not depend on it.
        nn.load_parameters(tmp_file)
        params = nn.get_parameters()
        another.set_parameters(params)

        ref_outputs = another(*variable_inputs)
        forward_variable_and_check_equal(outputs, ref_outputs)