Ejemplo n.º 1
0
    def __init__(self, option, src_model_file):
        """Load a TorchScript model and prepare the state used to convert it.

        Steps, in order:
          1. install ``_node_getitem`` as ``torch._C.Node.__getitem__``
             (a process-wide patch, not limited to this instance);
          2. create an empty MACE ``NetDef`` tagged with OIHW filter format,
             NCHW data format and the PyTorch framework type;
          3. build the node-kind -> converter dispatch table;
          4. load ``src_model_file`` with ``torch.jit.load``, put it in eval
             mode, and obtain ``(graph, params)`` from ``self.model_to_graph()``;
          5. record the first graph output's debug name and type, and reject
             types that are not tensor, list or tuple via ``mace_check``.

        Args:
            option: converter options; stored unchanged on ``self._option``.
            src_model_file: anything ``torch.jit.load`` accepts (e.g. a path).
        """
        # Global monkey-patch: affects every torch._C.Node in the process.
        torch._C.Node.__getitem__ = _node_getitem

        # Node kinds kept apart from _op_converters; presumably nodes that
        # yield parameters rather than real ops -- TODO confirm at use sites.
        self._param_converts = (
            NodeKind.Constant,
            NodeKind.List,
            NodeKind.Size,
            NodeKind.NumToTensor,
            NodeKind.Int,
        )
        self._option = option
        self._converter_info = {}

        self._mace_net_def = mace_pb2.NetDef()
        net_def = self._mace_net_def
        ConverterUtil.set_filter_format(net_def, DataFormat.OIHW)
        ConverterUtil.add_data_format_arg(net_def, DataFormat.NCHW)
        ConverterUtil.set_framework_type(net_def,
                                         FrameworkType.PYTORCH.value)

        # Both in-place and out-of-place variants (Add/Add_, Relu/Relu_,
        # HardTanh/HardTanh_) share a converter.
        self._op_converters = {
            NodeKind.AdaptiveAvgPool2D: self.convert_pool,
            NodeKind.Add: self.convert_add,
            NodeKind.Add_: self.convert_add,
            NodeKind.Addmm: self.convert_addmm,
            NodeKind.AvgPool2D: self.convert_pool,
            NodeKind.BatchNorm: self.convert_batch_norm,
            NodeKind.Cat: self.convert_cat,
            NodeKind.Convolution: self.convert_conv2d,
            NodeKind.Dropout: self.convert_dropout,
            NodeKind.Flatten: self.convert_flatten,
            NodeKind.HardTanh_: self.convert_hardtanh,
            NodeKind.HardTanh: self.convert_hardtanh,
            NodeKind.Matmul: self.convert_matmul,
            NodeKind.MaxPool2D: self.convert_pool,
            NodeKind.Relu: self.convert_relu,
            NodeKind.Relu_: self.convert_relu,
            NodeKind.Reshape: self.convert_reshape,
            NodeKind.T: self.convert_t,
        }

        self._loaded_model = torch.jit.load(src_model_file)
        self._loaded_model.eval()
        self._graph, self._params_dict = self.model_to_graph()

        # Only the first graph output is inspected here.
        first_output = list(self._graph.outputs())[0]
        self._output_node_name = first_output.debugName()
        self._output_value_type = first_output.type()
        accepted_types = (ValueType.TensorType, ValueType.ListType,
                          ValueType.TupleType)
        mace_check(
            isinstance(self._output_value_type, accepted_types),
            'return type {} not supported'.format(self._output_value_type))

        self._node_map = {}
        self.init_output_shape_cache()
Ejemplo n.º 2
0
    def __init__(self, option, src_model_file):
        """Load a frozen TensorFlow graph and prepare it for MACE conversion.

        Builds the op-type -> converter dispatch table, creates an empty MACE
        ``NetDef`` tagged with HWIO filter format, NHWC data format and the
        TensorFlow framework type, then:
          1. parses ``src_model_file`` as a binary serialized ``GraphDef``;
          2. optimizes it with ``TransformGraph`` using
             ``TFTransformGraphOptions``, falling back to the untransformed
             graph if the tool raises;
          3. calls ``add_shape_info`` and imports the graph once so that
             ``update_output_shapes`` can run in a live session;
          4. resets the default graph and re-imports the transformed graph
             into ``self._tf_graph``.

        Args:
            option: converter options. ``option.input_nodes`` and
                ``option.output_nodes`` are read through ``.keys()``, so they
                are presumably mappings keyed by node name -- TODO confirm.
            src_model_file: path to a binary ``GraphDef`` file.
        """
        # Keep in lexicographical order (case-insensitive, so that e.g.
        # Max < Maximum < MaxPool). Sorted entries make missing or duplicate
        # op types easy to spot.
        self._op_converters = {
            TFOpType.Abs.name: self.convert_elementwise,
            TFOpType.Add.name: self.convert_add,
            TFOpType.AddV2.name: self.convert_add,
            TFOpType.ArgMax.name: self.convert_argmax,
            TFOpType.AvgPool.name: self.convert_pooling,
            TFOpType.BatchMatMul.name: self.convert_matmul,
            TFOpType.BatchMatMulV2.name: self.convert_matmul,
            TFOpType.BatchToSpaceND.name: self.convert_space_batch,
            TFOpType.BiasAdd.name: self.convert_biasadd,
            TFOpType.Cast.name: self.convert_cast,
            TFOpType.ConcatV2.name: self.convert_concat,
            TFOpType.Const.name: self.convert_nop,
            TFOpType.Conv2D.name: self.convert_conv2d,
            TFOpType.Conv2DBackpropInput.name: self.convert_conv2d,
            TFOpType.Cumsum.name: self.convert_cumsum,
            TFOpType.DepthToSpace.name: self.convert_space_depth,
            TFOpType.DepthwiseConv2dNative.name: self.convert_conv2d,
            TFOpType.Div.name: self.convert_elementwise,
            TFOpType.Elu.name: self.convert_activation,
            TFOpType.Equal.name: self.convert_elementwise,
            TFOpType.ExpandDims.name: self.convert_expand_dims,
            TFOpType.ExtractImagePatches.name:
            self.convert_extract_image_patches,
            TFOpType.FakeQuantWithMinMaxArgs.name: self.convert_fake_quantize,
            TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
            TFOpType.Fill.name: self.convert_fill,
            TFOpType.FloorDiv.name: self.convert_elementwise,
            TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm,
            TFOpType.FusedBatchNormV2.name: self.convert_fused_batchnorm,
            TFOpType.FusedBatchNormV3.name: self.convert_fused_batchnorm,
            TFOpType.Gather.name: self.convert_gather,
            TFOpType.GatherV2.name: self.convert_gather,
            TFOpType.Identity.name: self.convert_identity,
            TFOpType.LeakyRelu.name: self.convert_activation,
            TFOpType.MatMul.name: self.convert_matmul,
            TFOpType.Max.name: self.convert_reduce,
            TFOpType.Maximum.name: self.convert_elementwise,
            TFOpType.MaxPool.name: self.convert_pooling,
            TFOpType.Mean.name: self.convert_reduce,
            TFOpType.Min.name: self.convert_reduce,
            TFOpType.Minimum.name: self.convert_elementwise,
            TFOpType.MirrorPad.name: self.convert_pad,
            TFOpType.Mul.name: self.convert_elementwise,
            TFOpType.Neg.name: self.convert_elementwise,
            TFOpType.NotEqual.name: self.convert_elementwise,
            TFOpType.OneHot.name: self.convert_one_hot,
            TFOpType.Pack.name: self.convert_stack,
            TFOpType.Pad.name: self.convert_pad,
            TFOpType.PadV2.name: self.convert_pad,
            TFOpType.Placeholder.name: self.convert_nop,
            TFOpType.Pow.name: self.convert_elementwise,
            TFOpType.Prod.name: self.convert_reduce,
            TFOpType.RealDiv.name: self.convert_elementwise,
            TFOpType.Relu.name: self.convert_activation,
            TFOpType.Relu6.name: self.convert_activation,
            TFOpType.Reshape.name: self.convert_reshape,
            TFOpType.ResizeBicubic.name: self.convert_resize_bicubic,
            TFOpType.ResizeBilinear.name: self.convert_resize_bilinear,
            TFOpType.ResizeNearestNeighbor.name:
            self.convert_resize_nearest_neighbor,
            TFOpType.ReverseV2.name: self.convert_reverse,
            TFOpType.Rsqrt.name: self.convert_elementwise,
            TFOpType.Select.name: self.convert_select,
            TFOpType.Shape.name: self.convert_shape,
            TFOpType.Sigmoid.name: self.convert_activation,
            TFOpType.Sign.name: self.convert_elementwise,
            TFOpType.Slice.name: self.convert_slice,
            TFOpType.Softmax.name: self.convert_softmax,
            TFOpType.SpaceToBatchND.name: self.convert_space_batch,
            TFOpType.SpaceToDepth.name: self.convert_space_depth,
            TFOpType.Split.name: self.convert_split,
            TFOpType.SplitV.name: self.convert_splitv,
            TFOpType.Sqrt.name: self.convert_elementwise,
            TFOpType.Square.name: self.convert_elementwise,
            TFOpType.SquaredDifference.name: self.convert_elementwise,
            TFOpType.Squeeze.name: self.convert_squeeze,
            TFOpType.Stack.name: self.convert_stack,
            TFOpType.StridedSlice.name: self.convert_stridedslice,
            TFOpType.Sub.name: self.convert_elementwise,
            TFOpType.Sum.name: self.convert_reduce,
            TFOpType.Tanh.name: self.convert_activation,
            TFOpType.Tile.name: self.convert_tile,
            TFOpType.Transpose.name: self.convert_transpose,
            TFOpType.Unpack.name: self.convert_unstack,
            TFOpType.Unstack.name: self.convert_unstack,
        }
        self._option = option
        self._mace_net_def = mace_pb2.NetDef()
        ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO)
        ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC)
        ConverterUtil.set_framework_type(self._mace_net_def,
                                         FrameworkType.TENSORFLOW.value)

        # import tensorflow graph (binary-serialized GraphDef)
        tf_graph_def = tf.GraphDef()
        with tf.gfile.Open(src_model_file, 'rb') as f:
            tf_graph_def.ParseFromString(f.read())

        # Containers populated by later helper methods.
        self._placeholders = {}
        self._skip_tensor = set()
        self._output_shape = {}

        print("Run transform_graph: %s" % TFTransformGraphOptions)
        try:
            print("output keys: ", option.output_nodes.keys())
            transformed_graph_def = TransformGraph(tf_graph_def,
                                                   option.input_nodes.keys(),
                                                   option.output_nodes.keys(),
                                                   TFTransformGraphOptions)
        except Exception as ex:
            # Optimization is best-effort: keep converting the original graph.
            print("Failed to transform graph using tf tool: %s" % ex)
            transformed_graph_def = tf_graph_def

        # To check optimized model, uncomment following code.
        # tf.io.write_graph(
        #     transformed_graph_def,
        #     ".",
        #     os.path.basename(src_model_file)[:-3] + "_opt.pb",
        #     as_text=False
        # )

        self.add_shape_info(transformed_graph_def)

        # reset default graph to clear earlier import
        tf.reset_default_graph()
        with tf.Session() as session:
            with session.graph.as_default() as graph:
                tf.import_graph_def(transformed_graph_def, name='')
                self._tf_graph = graph
                self.update_output_shapes(session)

        # Collecting output shapes polluted the graph with 'shape' ops, so
        # reset it and re-import a clean copy for the actual conversion.
        tf.reset_default_graph()
        with tf.Session() as session:
            with session.graph.as_default() as graph:
                tf.import_graph_def(transformed_graph_def, name='')
                self._tf_graph = graph
Ejemplo n.º 3
0
    def __init__(self, option, src_model_file, src_weight_file):
        """Parse a Caffe network definition and its weights for conversion.

        Builds the Caffe layer-type -> converter dispatch table and creates an
        empty MACE ``NetDef`` tagged with OIHW filter format, NCHW data format
        and the Caffe framework type. It then reads:
          * ``src_model_file`` as a text-format ``NetParameter``; after
            ``filter_test_layers`` runs, its layers are registered via
            ``CaffeNet.add_layer``;
          * ``src_weight_file`` as a binary ``NetParameter``; after the same
            filtering, its layers are registered via ``CaffeNet.add_blob``.

        Args:
            option: converter options; stored unchanged on ``self._option``.
            src_model_file: path to the text ``.prototxt`` network definition.
            src_weight_file: path to the binary ``.caffemodel`` weights.
        """
        self._op_converters = {
            'Input': self.convert_nop,
            'Convolution': self.convert_conv2d,
            'Deconvolution': self.convert_deconv2d,
            'Eltwise': self.convert_elementwise,
            'Add': self.convert_add,
            'ReLU': self.convert_activation,
            'ReLU6': self.convert_activation,
            'TanH': self.convert_activation,
            'Sigmoid': self.convert_activation,
            'PReLU': self.convert_activation,
            'Clip': self.convert_activation,
            'ELU': self.convert_activation,
            'Pooling': self.convert_pooling,
            'Concat': self.convert_concat,
            'Slice': self.convert_slice,
            'Softmax': self.convert_softmax,
            'InnerProduct': self.convert_fully_connected,
            'Interp': self.convert_interp,
            'BatchNorm': self.convert_folded_batchnorm,
            'GroupNorm': self.convert_group_norm,
            'Crop': self.convert_crop,
            'Scale': self.convert_scale,
            'ShuffleChannel': self.convert_channel_shuffle,
            'Permute': self.convert_permute,
            'Flatten': self.convert_flatten,
            'PriorBox': self.convert_prior_box,
            'Reshape': self.convert_reshape,
            'L2Normalization': self.convert_lpnorm,
            'L1Normalization': self.convert_lpnorm,
            'MVN': self.convert_MVN,
            'Bias': self.convert_bias,
            'ArgMax': self.convert_argmax,
            'ResizeNearest': self.convert_resize_nearest,
            'NonlocalReshape': self.convert_nonlocal_reshape,
            'MatMul': self.convert_matmul,
        }
        self._option = option

        self._mace_net_def = mace_pb2.NetDef()
        net_def = self._mace_net_def
        ConverterUtil.set_filter_format(net_def, DataFormat.OIHW)
        ConverterUtil.add_data_format_arg(net_def, DataFormat.NCHW)
        ConverterUtil.set_framework_type(net_def, FrameworkType.CAFFE.value)

        self._caffe_net = CaffeNet()
        self._caffe_layers = caffe_pb2.NetParameter()
        weight_params = caffe_pb2.NetParameter()

        # Network topology: text-format protobuf, merged into _caffe_layers.
        with open(src_model_file, 'r') as prototxt:
            prototxt_text = str(prototxt.read())
        google.protobuf.text_format.Merge(prototxt_text, self._caffe_layers)
        self.filter_test_layers(self._caffe_layers)
        for proto_layer in self._caffe_layers.layer:
            self._caffe_net.add_layer(proto_layer)

        # Trained weights: binary-serialized protobuf.
        with open(src_weight_file, 'rb') as caffemodel:
            serialized_weights = caffemodel.read()
        weight_params.ParseFromString(serialized_weights)
        self.filter_test_layers(weight_params)
        for weight_layer in weight_params.layer:
            self._caffe_net.add_blob(weight_layer)

        self._skip_ops = []