def test_multiple_network():
    """Compare unnamed networks built in one graph context with their modules.

    Three modules are called on ``ProtoVariable`` inputs inside a single
    ``nn.graph_def.graph()`` context; no ``graph_name()`` is used to name
    the individual networks. Afterwards each value of ``g.networks`` is run
    on the matching creator's variable inputs and compared with the output
    of the module itself.

    NOTE(review): ``zip`` stops at the shorter sequence, so if ``g.networks``
    holds fewer entries than there are modules, the remaining modules are
    silently left unchecked. The pairing also presumes creation order is
    preserved -- confirm against the graph_def implementation.
    """
    creators = [
        ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
        ModuleCreator(ResUnit(16), [(4, 3, 32, 32)]),
        ModuleCreator(NestedTestNet(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
    ]

    # Record every module into the same graph_def by feeding proto variables.
    with nn.graph_def.graph() as g:
        for creator in creators:
            net_module = creator.module
            proto_inputs = [
                nn.ProtoVariable(shape) for shape in creator.input_shape
            ]
            # The returned proto outputs are not needed in this test; the
            # networks are fetched from ``g`` below.
            net_module(*proto_inputs)

    for creator, network in zip(creators, g.networks.values()):
        # Variable inputs supplied by the creator for this module.
        inputs = creator.get_variable_inputs()

        # Output of the network taken from the graph_def.
        actual = network(*inputs)

        # Reference output from calling the module directly.
        expected = creator.module(*inputs)

        forward_variable_and_check_equal(actual, expected)
def test_parameter_file_load_save():
    """Save a graph to a temp file and load its parameters into another module.

    The graph is obtained from the proto outputs of a ``TSTNetNormal`` module
    and saved to a path provided by ``create_temp_with_dir(nnp_file)``. A
    second ``TSTNetNormal`` must first disagree with the graph's outputs;
    after loading the saved parameters inside that module's parameter scope
    and calling ``update_parameter()``, the outputs must agree.
    """
    creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    proto_inputs = creator.get_proto_variable_inputs()
    proto_outputs = creator.module(*proto_inputs)
    graph = nn.graph_def.get_default_graph_by_variable(proto_outputs)

    with create_temp_with_dir(nnp_file) as saved_path:
        graph.save(saved_path)

        fresh_module = TSTNetNormal()
        inputs = creator.get_variable_inputs()
        graph_outputs = graph(*inputs)
        fresh_outputs = fresh_module(*inputs)

        # Before loading, the comparison is required to fail.
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(graph_outputs, fresh_outputs)

        # Load the saved parameters into the second module's own scope.
        with nn.parameter_scope('', fresh_module.parameter_scope):
            nn.load_parameters(saved_path)

        fresh_module.update_parameter()

        fresh_outputs = fresh_module(*inputs)
        forward_variable_and_check_equal(graph_outputs, fresh_outputs)
def test_module_load_save_parameter_file_io(extension, file_format):
    """Round-trip module parameters through ``save_parameters``/``load_parameters``.

    Args:
        extension: Passed as ``extension=`` to both save and load, and used
            as the suffix of the temporary file name (``"tmp{}"``).
        file_format: Selects how the parameters travel between modules:
            ``'file_io'`` uses binary file objects opened on a temp path,
            ``'byte_io'`` uses an in-memory ``io.BytesIO`` buffer, and
            ``'str'`` passes the temp path itself. Any other value transfers
            nothing, so the final equality check then runs against the
            unmodified second module.
    """
    creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    inputs = creator.get_variable_inputs()
    source_module = creator.module
    source_outputs = source_module(*inputs)

    target_module = TSTNetNormal()
    target_outputs = target_module(*inputs)

    # Before any parameter transfer the comparison is required to fail.
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(source_outputs, target_outputs)

    if file_format == 'file_io':
        with create_temp_with_dir("tmp{}".format(extension)) as path:
            with open(path, "wb") as writer:
                source_module.save_parameters(writer, extension=extension)
            with open(path, "rb") as reader:
                target_module.load_parameters(reader, extension=extension)
    elif file_format == 'byte_io':
        with io.BytesIO() as buffer:
            source_module.save_parameters(buffer, extension=extension)
            target_module.load_parameters(buffer, extension=extension)
    elif file_format == 'str':
        with create_temp_with_dir("tmp{}".format(extension)) as path:
            source_module.save_parameters(path, extension=extension)
            target_module.load_parameters(path, extension=extension)

    # After the transfer both modules must produce equal outputs.
    target_outputs = target_module(*inputs)
    forward_variable_and_check_equal(source_outputs, target_outputs)
def test_parameter_file_load_save_for_file_object(memory_buffer_format):
    """Round-trip parameters through an in-memory buffer with the nn-level API.

    Args:
        memory_buffer_format: Forwarded as ``extension=`` to both
            ``nn.save_parameters`` and ``nn.load_parameters``.

    One module's parameters are written to an ``io.BytesIO`` buffer and then
    loaded inside a second module's parameter scope. The two modules must
    disagree before the transfer and agree after ``update_parameter()``.
    """
    creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    inputs = creator.get_variable_inputs()
    source_module = creator.module
    source_outputs = source_module(*inputs)

    target_module = TSTNetNormal()
    target_outputs = target_module(*inputs)

    # Before any parameter transfer the comparison is required to fail.
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(source_outputs, target_outputs)

    with io.BytesIO() as buffer:
        nn.save_parameters(
            buffer,
            source_module.get_parameters(),
            extension=memory_buffer_format,
        )
        # Read the same buffer back into the second module's scope.
        with nn.parameter_scope('', target_module.parameter_scope):
            nn.load_parameters(buffer, extension=memory_buffer_format)

    target_module.update_parameter()

    # After the transfer both modules must produce equal outputs.
    target_outputs = target_module(*inputs)
    forward_variable_and_check_equal(source_outputs, target_outputs)
def test_parameter_file_load_save_for_files(parameter_file):
    """Round-trip parameters through a temporary file with the nn-level API.

    Args:
        parameter_file: File name handed to ``create_temp_with_dir``; the
            path yielded by that context manager is used for both
            ``nn.save_parameters`` and ``nn.load_parameters``.

    The two ``TSTNetNormal`` modules must disagree before the transfer and
    agree once the parameters have been loaded into the second module's
    scope and ``update_parameter()`` has been called.
    """
    creator = ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    inputs = creator.get_variable_inputs()
    source_module = creator.module
    source_outputs = source_module(*inputs)

    target_module = TSTNetNormal()
    target_outputs = target_module(*inputs)

    # Before any parameter transfer the comparison is required to fail.
    with pytest.raises(AssertionError):
        forward_variable_and_check_equal(source_outputs, target_outputs)

    with create_temp_with_dir(parameter_file) as path:
        # Write the first module's parameters to the file ...
        nn.save_parameters(path, source_module.get_parameters())
        # ... and read them back inside the second module's scope.
        with nn.parameter_scope('', target_module.parameter_scope):
            nn.load_parameters(path)

    target_module.update_parameter()

    # After the transfer both modules must produce equal outputs.
    target_outputs = target_module(*inputs)
    forward_variable_and_check_equal(source_outputs, target_outputs)
def test_get_default_graph_def_by_name():
    """Retrieve graphs by name through ``nn.graph_def.get_default_graph()``.

    Each module is called on ``ProtoVariable`` inputs inside a
    ``nn.graph_def.graph_name(name)`` context, without an explicit
    ``nn.graph_def.graph()`` block. The graph fetched afterwards with
    ``get_default_graph(name)`` must produce the same outputs as the module
    it was built from.
    """
    creators = [
        ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
        ModuleCreator(ResUnit(16), [(4, 3, 32, 32)]),
        ModuleCreator(NestedTestNet(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
    ]
    names = ['network1', 'network2', 'network3']
    named_creators = list(zip(creators, names))

    for creator, name in named_creators:
        net_module = creator.module

        # Proto inputs are created before entering the named scope.
        proto_inputs = [
            nn.ProtoVariable(shape) for shape in creator.input_shape
        ]

        with nn.graph_def.graph_name(name):
            # The returned proto outputs are not needed in this test; the
            # graph is looked up by name below.
            net_module(*proto_inputs)

    for creator, name in named_creators:
        # Variable inputs supplied by the creator for this module.
        inputs = creator.get_variable_inputs()

        # Look the graph up in the default graph by its name.
        graph = nn.graph_def.get_default_graph(name)
        actual = graph(*inputs)

        # Reference output from calling the module directly.
        expected = creator.module(*inputs)

        forward_variable_and_check_equal(actual, expected)
def test_get_graph_def_by_name():
    """Look up named networks in an explicitly created graph_def.

    Several modules are recorded inside one ``nn.graph_def.graph()`` context,
    each within its own ``nn.graph_def.graph_name(name)`` scope. Every
    network is then fetched as ``g[name]`` and must produce the same outputs
    as the module it was built from.
    """
    creators = [
        ModuleCreator(TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
        ModuleCreator(ResUnit(16), [(4, 3, 32, 32)]),
        ModuleCreator(NestedTestNet(), [(4, 3, 32, 32), (4, 3, 32, 32)]),
    ]
    names = ['network1', 'network2', 'network3']
    named_creators = list(zip(creators, names))

    # Record every module into the same graph_def under its own name.
    with nn.graph_def.graph() as g:
        for creator, name in named_creators:
            net_module = creator.module

            # Proto inputs are created before entering the named scope.
            proto_inputs = [
                nn.ProtoVariable(shape) for shape in creator.input_shape
            ]

            with nn.graph_def.graph_name(name):
                # The returned proto outputs are not needed in this test;
                # the networks are fetched from ``g`` by name below.
                net_module(*proto_inputs)

    for creator, name in named_creators:
        # Variable inputs supplied by the creator for this module.
        inputs = creator.get_variable_inputs()

        # Output of the named network taken from the graph_def.
        actual = g[name](*inputs)

        # Reference output from calling the module directly.
        expected = creator.module(*inputs)

        forward_variable_and_check_equal(actual, expected)
def test_parameter_file_load_save_using_global():
    """Save a graph, load its parameters into the global scope, then reuse them.

    The graph is obtained from the proto outputs of a ``TSTNetNormal`` module
    and saved to disk. A second ``TSTNetNormal`` must first disagree with the
    graph's outputs; after ``nn.load_parameters`` (called outside any
    ``nn.parameter_scope``) the result of ``nn.get_parameters()`` is handed
    to the second module with ``set_parameters`` and the outputs must agree.

    The file is written through ``create_temp_with_dir(nnp_file)``, the same
    helper ``test_parameter_file_load_save`` uses, instead of the bare
    ``nnp_file`` path, so the test does not leave the saved file behind in
    the current working directory.
    """
    module_creator = ModuleCreator(
        TSTNetNormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])
    proto_variable_inputs = module_creator.get_proto_variable_inputs()
    outputs = module_creator.module(*proto_variable_inputs)
    g = nn.graph_def.get_default_graph_by_variable(outputs)

    with create_temp_with_dir(nnp_file) as tmp_file:
        g.save(tmp_file)

        another = TSTNetNormal()
        variable_inputs = module_creator.get_variable_inputs()
        outputs = g(*variable_inputs)
        ref_outputs = another(*variable_inputs)

        # Before loading, the comparison is required to fail.
        with pytest.raises(AssertionError):
            forward_variable_and_check_equal(outputs, ref_outputs)

        # load to global scope
        nn.load_parameters(tmp_file)

    # Hand the globally loaded parameters to the second module.
    params = nn.get_parameters()
    another.set_parameters(params)

    ref_outputs = another(*variable_inputs)
    forward_variable_and_check_equal(outputs, ref_outputs)
# Create another graph with the same architecture e2 = Example() assert not e2.get_parameters(), "It doesn't have any parameters so far." # Setting parameters from an existing model) e2.set_parameters(e.get_parameters()) assert e.get_parameters() == e2.get_parameters( ), "They have the same parameters." assert '@cb/conv/W' in e.get_parameters() assert '@cb/conv/W' in e2.get_parameters() @pytest.mark.parametrize( "module_creator", [ModuleCreator(TSTNetAbnormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])]) def test_unsupported_module_definition(module_creator): # Since we use global graphdef, we should reset it beforehand # This is already done in test fixture. # nn.graph_def.reset_default_graph() # get module from test parameters module = module_creator.module # create variable inputs and initialized by random value variable_inputs = module_creator.get_variable_inputs() # create reference network by passing in variable inputs ref_outputs = module(*variable_inputs) # create proto variable inputs
e = Example() h = e(x) # Create another graph with the same architecture e2 = Example() assert not e2.get_parameters(), "It doesn't have any parameters so far." # Setting parameters from an existing model) e2.set_parameters(e.get_parameters()) assert e.get_parameters() == e2.get_parameters(), "They have the same parameters." assert '@cb/conv/W' in e.get_parameters() assert '@cb/conv/W' in e2.get_parameters() @pytest.mark.parametrize("module_creator", [ModuleCreator(TSTNetAbnormal(), [(4, 3, 32, 32), (4, 3, 32, 32)])]) def test_unsupported_module_definition(module_creator): # Since we use global graphdef, we should reset it beforehand # This is already done in test fixture. # nn.graph_def.reset_default_graph() # get module from test parameters module = module_creator.module # create variable inputs and initialized by random value variable_inputs = module_creator.get_variable_inputs() # create reference network by passing in variable inputs ref_outputs = module(*variable_inputs) # create proto variable inputs