def test_caffe_same_name_layer(self):
    """Build a graph from the multi-input prototxt plus the same-name-layers snippet.

    The node count is checked to confirm how the converter handles that snippet
    (presumably layers that share a name — see ``proto_same_name_layers``).
    """
    net = caffe_pb2.NetParameter()
    text_format.Merge(proto_str_multi_input + proto_same_name_layers, net)
    g = Graph()
    caffe_pb_to_nx(g, net, None)
    # 2 inputs + 2 convolutions + 2 identity nodes used as fake outputs
    expected_node_count = 6
    np.testing.assert_equal(len(g.nodes()), expected_node_count)
def test_caffe_pb_to_nx_one_input(self):
    """The shape returned for 'Input0' of the single-input prototxt is [1, 3, 224, 224]."""
    net = caffe_pb2.NetParameter()
    text_format.Merge(proto_str_one_input, net)
    shapes = caffe_pb_to_nx(Graph(), net, None)
    expected = {'Input0': np.array([1, 3, 224, 224])}
    for name, shape in expected.items():
        np.testing.assert_array_equal(shapes[name], shape)
def test_caffe_pb_to_multi_input(self):
    """Shapes returned for 'data' and 'data1' of the multi-input prototxt match the expected arrays."""
    net = caffe_pb2.NetParameter()
    text_format.Merge(proto_str_multi_input + layer_proto_str, net)
    shapes = caffe_pb_to_nx(Graph(), net, None)
    expected = dict(
        data=np.array([1, 3, 224, 224]),
        data1=np.array([1, 3]),
    )
    for name, shape in expected.items():
        np.testing.assert_array_equal(shapes[name], shape)
def load(self, graph: Graph):
    """Load a Caffe model into ``graph`` and extract per-node attributes.

    Reads the command-line parameters from ``graph.graph['cmd_params']``,
    imports ``caffe_pb2`` from ``argv.caffe_parser_path``, and parses
    ``argv.input_proto`` / ``argv.input_model``. It then builds the graph
    nodes with ``loader.caffe_pb_to_nx`` and records model metadata on the
    graph: proto/caffemodel paths, name, layout ``'NCHW'``, framework
    ``'caffe'``, original shapes and the ``caffe_pb2`` module. Finally it
    registers custom-layer extractors and runs attribute extraction on every
    node.

    Args:
        graph: Graph to populate; ``graph.graph['cmd_params']`` must hold the
            parsed command-line arguments.

    Raises:
        Error: if ``loader.caffe_pb_to_nx`` raises ``ValueError``, which is
            reported as an invalid prototxt file.
    """
    argv = graph.graph['cmd_params']
    caffe_pb2 = loader.import_caffe_pb2(argv.caffe_parser_path)

    proto, model = loader.load_caffe_proto_model(caffe_pb2, argv.input_proto, argv.input_model)

    # Optional CLI flags may be absent from argv; missing ones are treated as False.
    # NOTE(review): this call reads 'disable_flattening_optional_params' while the
    # custom-layer registration below reads 'enable_flattening_nested_params'.
    # They look like they control the same flattening behaviour with opposite
    # polarity — confirm which attribute the CLI parser actually defines.
    update_extractors_with_extensions(
        caffe_type_extractors,
        getattr(argv, 'disable_omitting_optional', False),
        getattr(argv, 'disable_flattening_optional_params', False))

    try:
        original_shapes = loader.caffe_pb_to_nx(graph, proto, model)
    except ValueError as e:
        raise Error(
            'Invalid prototxt file: value error {}. ' + refer_to_faq_msg(11), str(e)) from e
    # Presumably raises if conversion produced no nodes — confirm in Graph.check_empty_graph.
    graph.check_empty_graph('load_caffe_proto_model')

    graph.proto_path = argv.input_proto
    graph.caffemodel_path = argv.input_model
    # Prefer the name declared in the prototxt; fall back to the CLI model name.
    graph.name = getattr(proto, 'name', None) or argv.model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['fw'] = 'caffe'
    graph.graph['original_shapes'] = original_shapes
    graph.graph['caffe_pb2'] = caffe_pb2

    # argv.k is presumably the path to the custom layers XML — verify against the CLI parser.
    custom_layers_map = custom_layers_mapping.load_layers_xml(argv.k)
    custom_layers_mapping.update_extractors(
        caffe_type_extractors,
        custom_layers_map,
        getattr(argv, 'disable_omitting_optional', False),
        getattr(argv, 'enable_flattening_nested_params', False))

    # The duplicate check depends only on the (now final) extractor collection.
    # Run it once here instead of once per node inside the extractor callback.
    checked_extractors = check_for_duplicates(caffe_type_extractors)
    extract_node_attrs(graph, lambda node: caffe_extractor(node, checked_extractors))

    send_op_names_info('caffe', graph)
    send_shapes_info('caffe', graph)