Example #1
class MXNetEmitter(Emitter):

    dtype_map = {
        graph_pb2.DT_FLOAT16    : "float16",
        graph_pb2.DT_FLOAT32    : "float32",
        graph_pb2.DT_FLOAT64    : "float64",
        graph_pb2.DT_INT32      : "int32",
        graph_pb2.DT_UINT8      : "uint8"
    }

    activation_map = {
        "relu"      : "Relu",
        "sigmoid"   : "Sigmoid",
        "tanh"      : "Tanh",
        "elu"       : "Elu"
    }

    transpose_map = {
        1 : 2,
        2 : 3,
       -1 : 1
    }
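
    # transpose_map translates IR (NHWC) axis indices to their MXNet (NCHW)
    # positions (H: 1 -> 2, W: 2 -> 3, C: -1 -> 1); used by emit_ReduceMean.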

    channels_last = ['NDHWC', 'NHWC']

    def __init__(self, model):
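        # `model` is either a path to the IR network file, or a 3-tuple of
        # (IR network path, IR weights .npy path, output weights file path).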
        super(MXNetEmitter, self).__init__()
        from six import string_types as _string_types

        if isinstance(model, _string_types):
            network_path = model
            self.weight_loaded = False
        elif len(model) == 3:
            network_path = model[0]
            weight_path = model[1]
            self.output_weights_file = model[2]
            self.weights = np.load(weight_path).item()
            self.weight_loaded = True
            self.output_weights = dict()
        else:
            raise ValueError("the # of input arguments [{}] is not supported" % len(model))

        self.IR_graph = IRGraph(network_path)
        self.IR_graph.build()


    @property
    def header_code(self):
        return """import mxnet as mx
import numpy as np
import math

# MXNet on CPU only supports channels-first, so by default the model and weights are converted to channels-first

def RefactorModel():
"""


    def gen_code(self, phase):
        self.IR_layer_map = dict()
        self.add_body(0, self.header_code)
        for layer in self.IR_graph.topological_sort:
            self.IR_layer_map[layer] = self.IR_graph.get_node(layer)

        shape = dict()
        for layer in self.IR_graph.topological_sort:
            current_node = self.IR_graph.get_node(layer)
            node_type = current_node.type


            if len(current_node.in_edges) == 0:
                current_node.in_edges.append('data')

            if node_type.lower() in MXNetEmitter.activation_map:
                func = getattr(self, "emit_Activation")
                line = func(current_node, MXNetEmitter.activation_map[node_type.lower()].lower())
                self.add_body(1, line)

            elif hasattr(self, "emit_" + node_type):
                func = getattr(self, "emit_" + node_type)
                line = func(current_node)
                if line is not None:
                    self.add_body(1, line)
            else:
                print("MXNet Emitter has not supported operator [%s]." % (node_type))
                self.emit_UNKNOWN(current_node)

            if node_type == "DataInput":
                cur_shape = list()
                first = True
                for dim in current_node.IR_layer.attr["shape"].shape.dim:
                    if dim.size == -1 and first:
                        cur_shape.append(1)
                        print("Detect input layer [{}] using infer batch size, set it as default value [1]".format(current_node.name))
                    else:
                        if dim.size == -1:
                            print("Warning: user should change input size manually")
                        cur_shape.append(dim.size)
                    first = False

                cur_shape.insert(1, cur_shape.pop())  # move channels to position 1 (NHWC -> NCHW)
                shape[current_node.name] = ', '.join('%s' % i for i in cur_shape)


        if self.weight_loaded:
            fullpath = os.path.abspath(self.output_weights_file)
            dirname = os.path.dirname(fullpath)
            if not os.path.exists(dirname):
                os.makedirs(dirname)
            with open(self.output_weights_file, 'wb') as outfile:
                np.save(outfile, self.output_weights)

        comment = "\n    # if a GPU is available, change mx.cpu() to mx.gpu()"
        last_line = "{:<15} = mx.mod.Module(symbol = {}, context = mx.cpu(), data_names = ['{}'])".format(
            "model",
            ', '.join([self.IR_graph.get_node(name).real_variable_name for name in self.IR_graph.output_layers if self.IR_graph.get_node(name).type != 'Pack']),
            ', '.join([self.IR_graph.get_node(name).real_variable_name for name in self.IR_graph.input_layers if self.IR_graph.get_node(name).type != 'Const']))

        self.add_body(1, comment)
        self.add_body(1, last_line)
        self.add_body(1, "return model")

        weight_code = ""
        if not self.weight_loaded:
            weight_code += "# emitter does not detect any import weights, you may generate weights file manually\n"

        weight_code += self.gen_weight_code(shape, phase)

        main_code = "if __name__ == '__main__':\n    model = RefactorModel()\n"
        if self.weight_loaded:
            main_code += "    # remember to adjust params path\n    model = deploy_weight(model, '{}')\n".format(self.output_weights_file)

        if phase == 'train':
            train_code = """def train(model):
    import logging
    logging.getLogger().setLevel(logging.DEBUG)
    model.fit(train_iter, # train data
            eval_data = val_iter, # validation data
            optimizer = 'sgd', # Defaults to 'sgd'
            optimizer_params = {'learning_rate':0.01}, # use fixed learning rate
            eval_metric = 'acc', # report accuracy during training, other possible predefined metrics are: 'ce', 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'
            batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches
            num_epoch = 10) # train for at most 10 dataset passes\n\n
"""
            code = self.body_code + weight_code + train_code + main_code
        else:
            test_code = """from collections import namedtuple
Batch = namedtuple('Batch', ['data'])


def get_image(url, show=False):
    import cv2
    # download and show the image
    fname = mx.test_utils.download(url)
    img = cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB)
    if img is None:
        return None
    if show:
        import matplotlib.pyplot as plt
        plt.imshow(img)
        plt.axis('off')
    # convert into format (batch, RGB, width, height)
    img = cv2.resize(img, (224, 224))
    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)
    img = img[np.newaxis, :]
    return img


def predict(model, labels, url):
    # to show the image, set the argument show to True
    img = get_image(url, show = False)
    # compute the predict probabilities
    model.forward(Batch([mx.nd.array(img)]))
    prob = model.get_outputs()[0].asnumpy()
    # print the top-5
    prob = np.squeeze(prob)
    a = np.argsort(prob)[::-1]
    for i in a[0:5]:
        print('probability = %f, class = %s' %(prob[i], labels[i]))\n\n
"""

            main_code += """
    # # call function predict
    # with open('synset.txt', 'r') as f:
    #     labels = [l.rstrip() for l in f]
    # predict(model, labels, 'http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg')
"""

            code = self.body_code + weight_code + test_code + main_code

        return code


    def gen_weight_code(self, shape, phase):
        str = "def deploy_weight(model, weight_file):\n"
        str += """
    if weight_file is None:
        return

    try:
        weights_dict = np.load(weight_file).item()
    except:
        weights_dict = np.load(weight_file, encoding='bytes').item()

    arg_params = dict()
    aux_params = dict()
    for weight_name, weight_data in weights_dict.items():
        weight_name = str(weight_name)
        if "moving" in weight_name:
            aux_params[weight_name] = mx.nd.array(weight_data)
        else:
            arg_params[weight_name] = mx.nd.array(weight_data)

"""
        if phase == 'train':
            code += "    model.bind(for_training = True, data_shapes = ["
        else:
            code += "    model.bind(for_training = False, data_shapes = ["
        first = True
        for k, v in shape.items():
            if not first:
                code += ", "
            code += "('" + k + "', " + "(" + v + "))"
            first = False
        code += "])\n"
        code += "    model.set_params(arg_params = arg_params, aux_params = aux_params, allow_missing = True)\n\n    return model\n\n\n"
        return code


    @staticmethod
    def calculate_same_pad(data_shape, kernel, stride):
        if (data_shape % stride == 0):
            pad = max(kernel - stride, 0)
        else:
            pad = max(kernel - (data_shape % stride), 0)
        if pad % 2 == 0:
            return False, pad
        else:
            return True, pad
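
    # Worked example: data_shape=7, kernel=3, stride=2 gives
    # pad = max(3 - (7 % 2), 0) = 2; 2 is even, so the result is (False, 2)
    # and the padding can be split evenly between both sides.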


    @staticmethod
    def transfer_pad(pad_list):
        defuse_pad = False
        pad = list()

        assert len(pad_list) % 2 == 0
        mid = int(len(pad_list)/2)
        pad_first = pad_list[1:mid-1]
        pad_second = pad_list[mid+1:-1]

        for i in range(0, mid-2):
            if not pad_first[i] == pad_second[i]:
                defuse_pad = True

        if defuse_pad:
            pad.extend([0] * 4)
            for i in range(0, mid-2):
                pad.extend([pad_first[i], pad_second[i]])
        else:
            pad = pad_first

        return defuse_pad, pad
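
    # Worked example: the IR pads list [0, 1, 1, 0, 0, 2, 2, 0] (N, H, W, C
    # begin values, then N, H, W, C end values) pads H asymmetrically (1 vs 2),
    # so transfer_pad returns (True, [0, 0, 0, 0, 1, 2, 1, 2]) in mx.sym.pad's
    # pad_width layout; the symmetric list [0, 1, 1, 0, 0, 1, 1, 0] yields
    # (False, [1, 1]).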


    @staticmethod
    def transpose(data, dim):
        if dim == 1:
            data = data.transpose((2, 1, 0))
        elif dim == 2:
            data = data.transpose((3, 2, 0, 1))
        elif dim == 3:
            data = data.transpose((4, 3, 0, 1, 2))
        else:
            raise ValueError("The weight of dim {} cannot transpose" % dim)

        return data
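
    # Converts channels-last kernels to MXNet's channels-first layout, e.g.
    # dim=2 maps a (KH, KW, C_in, C_out) array to (C_out, C_in, KH, KW).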


    def set_pad(self, IR_node, code, pad, _max_pool):
        if _max_pool:
            constant_value = "float('-inf')"
        else:
            constant_value = "0.0"

        code = "{:<15} = mx.sym.pad(data = {}, mode = 'constant', pad_width={}, constant_value = {}, name = '{}')".format(
                IR_node.variable_name + "_pad",
                self.parent_variable_name(IR_node),
                tuple(pad),
                constant_value,
                IR_node.name + "_pad")

        for e in IR_node.in_edges:
            if e == 'data':
                continue
            self.IR_layer_map[e].out_edges = [
                x if not self.IR_layer_map[x].name == IR_node.variable_name
                else IR_node.variable_name + "_pad"
                for x in self.IR_layer_map[e].out_edges]

        return code


    def emit_UNKNOWN(self, IR_node):
        print(IR_node.name)


    def emit_FullyConnected(self, IR_node):
        if self.weight_loaded:
            weight_dict = self.weights[IR_node.name]
            parent = self.IR_graph.get_parent(IR_node.name, [0])
            while parent.type == "Flatten" or parent.type == 'Dropout':
                parent = self.IR_graph.get_parent(parent.name, [0])
            dim = len(parent.layer.attr['_output_shapes'].list.shape[0].dim)
            if dim > 2:
                original_dims = weight_dict['weights'].shape
                dims = [i.size for i in parent.layer.attr['_output_shapes'].list.shape[0].dim[1:]] + [-1]
                weight_dict['weights'] = np.reshape(weight_dict['weights'], dims)
                weight_dict['weights'] = np.transpose(weight_dict['weights'], [dim - 2] + list(range(0, dim - 2)) + [dim - 1])
                weight_dict['weights'] = np.reshape(weight_dict['weights'], original_dims)
            self.output_weights[IR_node.name + "_weight"] = weight_dict['weights'].transpose((1, 0))

        num_hidden = IR_node.IR_layer.attr["units"].i
        no_bias = not IR_node.IR_layer.attr["use_bias"].b
        if not no_bias and self.weight_loaded:
            self.output_weights[IR_node.name + "_bias"] = weight_dict['bias']

        code = "{:<15} = mx.sym.FullyConnected(data = {}, num_hidden = {}, no_bias = {}, name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                num_hidden,
                no_bias,
                IR_node.name)
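        # e.g. (hypothetical names):
        # fc1             = mx.sym.FullyConnected(data = flatten0, num_hidden = 1000, no_bias = False, name = 'fc1')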

        return code


    def _emit_convolution(self, IR_node, pattern):
        if self.weight_loaded:
            weight_dict = self.weights[IR_node.name]
            weights = weight_dict['weights']

        dim = len(IR_node.IR_layer.attr["kernel_shape"].list.i) - 2

        kernel = list()
        for idx in range(0, dim):
            kernel.append(IR_node.IR_layer.attr["kernel_shape"].list.i[idx])

        stride = list()
        for e in IR_node.IR_layer.attr["strides"].list.i[1:-1]:
            stride.append(e)

        dilate = list()
        for e in IR_node.IR_layer.attr["dilations"].list.i[1:-1]:
            dilate.append(e)
        if dilate == []:
            dilate = [1, 1]
        dilate = ', '.join('%s' % i for i in dilate)

        defuse_pad = False
        pad = list()
        if "pads" in IR_node.IR_layer.attr:
            output_shape = list()
            for e in IR_node.IR_layer.attr["_output_shapes"].list.shape[0].dim:
                output_shape.append(e.size)

            # print("Warning: MXNet Convolution Layer pad does not match IR Convolution Layer pad")
            defuse_pad, pad = MXNetEmitter.transfer_pad(IR_node.IR_layer.attr["pads"].list.i)

        num_filter = 0
        if pattern == "Deconvolution":
            num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-2]
        else:
            num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-1]

        use_bias = IR_node.get_attr('use_bias', False)
        if use_bias and self.weight_loaded:
            self.output_weights[IR_node.name + "_bias"] = weight_dict['bias']

        if pattern == "DepthwiseConv":
            num_group = IR_node.IR_layer.attr["kernel_shape"].list.i[-2]
            num_filter = num_filter * num_group
            pattern = "Convolution"
            if self.weight_loaded:
                weights = np.swapaxes(weights, -1, -2)

        else:
            num_group = IR_node.get_attr('group', 1)

        # layout = IR_node.IR_layer.attr["data_format"].s
        if dim == 1:
            layout = 'NCW'
        elif dim == 2:
            layout = 'NCHW'
        elif dim == 3:
            layout = 'NCDHW'

        if self.weight_loaded:
            # if layout not in MXNetEmitter.channels_last:
            weights = MXNetEmitter.transpose(weights, dim)
            self.output_weights[IR_node.name + "_weight"] = weights

        code = ""
        if not defuse_pad:
            code += "{:<15} = mx.sym.{}(data={}, kernel={}, stride={}, dilate = ({}), pad={}, num_filter = {}, num_group = {}, no_bias = {}, layout = '{}', name = '{}')".format(
                IR_node.variable_name,
                pattern,
                self.parent_variable_name(IR_node),
                tuple(kernel),
                tuple(stride),
                dilate,
                tuple(pad),
                num_filter,
                num_group,
                not use_bias,
                layout,
                IR_node.name)
        else:
            code += self.set_pad(IR_node, code, pad, False)
            code += "\n    {:<15} = mx.sym.{}(data={}, kernel={}, stride={}, dilate = ({}), num_filter = {}, num_group = {}, no_bias = {}, layout = '{}', name = '{}')".format(
                IR_node.variable_name,
                pattern,
                IR_node.variable_name + "_pad",
                tuple(kernel),
                tuple(stride),
                dilate,
                num_filter,
                num_group,
                not use_bias,
                layout,
                IR_node.name)

        return code


    def emit_Conv(self, IR_node):
        return self._emit_convolution(IR_node, "Convolution")


    def emit_DepthwiseConv(self, IR_node):
        return self._emit_convolution(IR_node, "DepthwiseConv")


    def emit_ConvTranspose(self, IR_node):
        return self._emit_convolution(IR_node, "Deconvolution")


    def emit_DataInput(self, IR_node):
        shape = list()
        shape.extend(IR_node.IR_layer.attr["shape"].list.i)  # collected but unused; gen_code derives the input shape itself

        code = "{:<15} = mx.sym.var('{}')".format(IR_node.variable_name, IR_node.name)
        return code


    # Elu is emitted through LeakyReLU (a separate slope is not supported)
    def emit_Activation(self, IR_node, act_type):

        if act_type == "elu":
            func_name = "LeakyReLU"
        else:
            func_name = "Activation"

        code = "{:<15} = mx.sym.{}(data = {}, act_type = '{}', name = '{}')".format(
                IR_node.variable_name,
                func_name,
                self.parent_variable_name(IR_node),
                act_type,
                IR_node.name)

        return code


    def emit_BatchNorm(self, IR_node):
        IR_node_after = self.IR_graph.get_son(IR_node.name, [0])
        if IR_node_after.type == 'Scale':
            if self.weight_loaded:
                weight_dict = self.weights[IR_node.name]
                weight_dict_scale = self.weights[IR_node_after.name]

            # axis = IR_node.IR_layer.attr["axis"].i
            axis = 1
            eps = IR_node.IR_layer.attr["epsilon"].f
            momentum = IR_node.IR_layer.attr["momentum"].f

            fix_gamma = not IR_node.IR_layer.attr["scale"].b

            if self.weight_loaded:
                if not fix_gamma:
                    self.output_weights[IR_node.name + "_gamma"] = np.multiply(weight_dict['scale'], weight_dict_scale['scale'])
                self.output_weights[IR_node.name + "_beta"] = np.multiply(weight_dict['bias'], weight_dict_scale['scale']) + weight_dict_scale['bias']

            # not supported yet
            use_global_stats = "False"
            if self.weight_loaded:
                self.output_weights[IR_node.name + "_moving_var"] = weight_dict['var']
                self.output_weights[IR_node.name + "_moving_mean"] = weight_dict['mean']

            code = "{:<15} = mx.sym.BatchNorm(data = {}, axis = {}, eps = {}, momentum = {}, fix_gamma = {}, use_global_stats = {}, name = '{}')".format(
                    IR_node.variable_name,
                    self.parent_variable_name(IR_node),
                    axis,
                    eps,
                    momentum,
                    fix_gamma,
                    use_global_stats,
                    IR_node.name)

            return code

        else:
            if self.weight_loaded:
                weight_dict = self.weights[IR_node.name]

            # axis = IR_node.IR_layer.attr["axis"].i
            axis = 1
            eps = IR_node.IR_layer.attr["epsilon"].f
            momentum = IR_node.IR_layer.attr["momentum"].f

            fix_gamma = not IR_node.IR_layer.attr["scale"].b

            if self.weight_loaded:
                if not fix_gamma:
                    self.output_weights[IR_node.name + "_gamma"] = weight_dict['scale']
                self.output_weights[IR_node.name + "_beta"] = weight_dict['bias']

            # not supported yet
            use_global_stats = "False"
            if self.weight_loaded:
                self.output_weights[IR_node.name + "_moving_var"] = weight_dict['var']
                self.output_weights[IR_node.name + "_moving_mean"] = weight_dict['mean']

            code = "{:<15} = mx.sym.BatchNorm(data = {}, axis = {}, eps = {}, momentum = {}, fix_gamma = {}, use_global_stats = {}, name = '{}')".format(
                    IR_node.variable_name,
                    self.parent_variable_name(IR_node),
                    axis,
                    eps,
                    momentum,
                    fix_gamma,
                    use_global_stats,
                    IR_node.name)

            return code

    def emit_Scale(self, IR_node):
        if self.weight_loaded:
            weight_dict = self.weights[IR_node.name]

        # axis = IR_node.IR_layer.attr["axis"].i
        axis = 1
        eps = 0.0
        momentum = 0.0

        fix_gamma = not IR_node.IR_layer.attr["scale"].b

        if self.weight_loaded:
            if not fix_gamma:
                self.output_weights[IR_node.name + "_gamma"] = weight_dict['scale']
            self.output_weights[IR_node.name + "_beta"] = weight_dict['bias']

        # not supported yet
        use_global_stats = "False"
        if self.weight_loaded:
            self.output_weights[IR_node.name + "_moving_var"] = weight_dict['scale_var']
            self.output_weights[IR_node.name + "_moving_mean"] = weight_dict['scale_mean']

        code = "{:<15} = mx.sym.BatchNorm(data = {}, axis = {}, eps = {}, momentum = {}, fix_gamma = {}, use_global_stats = {}, name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                axis,
                eps,
                momentum,
                fix_gamma,
                use_global_stats,
                IR_node.name)

        return code



    def emit_Pool(self, IR_node):

        global_pool = IR_node.IR_layer.attr["global_pooling"].b

        kernel = list()
        if global_pool:
            kernel = [1] * (len(IR_node.IR_layer.attr["strides"].list.i) - 2)
        else:
            for e in IR_node.IR_layer.attr["kernel_shape"].list.i[1:-1]:
                kernel.append(e)

        pool_type = IR_node.get_attr('pooling_type').lower()

        stride = list()
        for e in IR_node.IR_layer.attr["strides"].list.i[1:-1]:
            stride.append(e)

        defuse_pad = False
        pad = list()
        if "pads" in IR_node.IR_layer.attr:
            output_shape = list()
            for e in IR_node.IR_layer.attr["_output_shapes"].list.shape[0].dim:
                output_shape.append(e.size)

            # print("Warning: MXNet Pooling Layer pad does not match IR Pooling Layer pad")
            defuse_pad, pad = MXNetEmitter.transfer_pad(IR_node.IR_layer.attr["pads"].list.i)
        code = ""
        if not defuse_pad:
            code += "{:<15} = mx.sym.Pooling(data = {}, global_pool = {}, kernel={}, pool_type = '{}', stride={}, pad={}, name = '{}')".format(
                    IR_node.variable_name,
                    self.parent_variable_name(IR_node),
                    global_pool,
                    tuple(kernel),
                    pool_type,
                    tuple(stride),
                    tuple(pad),
                    IR_node.name)
        else:
            code += self.set_pad(IR_node, code, pad, pool_type == "max")
            code += "\n    {:<15} = mx.sym.Pooling(data = {}, global_pool = {}, kernel={}, pool_type = '{}', stride={}, name = '{}')".format(
                    IR_node.variable_name,
                    IR_node.variable_name + "_pad",
                    global_pool,
                    tuple(kernel),
                    pool_type,
                    tuple(stride),
                    IR_node.name)

        return code


    def emit_SoftmaxOutput(self, IR_node):

        code = "{:<15} = mx.sym.SoftmaxOutput(data = {}, name = 'softmax')".format(
            IR_node.variable_name,
            self.parent_variable_name(IR_node)
        )

        return code


    def emit_Softmax(self, IR_node):

        code = ""

        if len(IR_node.out_edges) == 0:
            code = "{:<15} = mx.sym.SoftmaxOutput(data = {}, name = 'softmax')".format(
                    IR_node.variable_name,
                    self.parent_variable_name(IR_node))
        else:
            axis = IR_node.IR_layer.attr["dim"].i
            code = "{:<15} = mx.sym.softmax(data = {}, axis = {}, name = '{}')".format(
                    IR_node.variable_name,
                    self.parent_variable_name(IR_node),
                    axis,
                    IR_node.name)

        return code


    def emit_Squeeze(self, IR_node):
        return self.emit_Flatten(IR_node)


    # def emit_ConvTranspose(self, IR_node):
    #     if self.weight_loaded:
    #         weight_dict = self.weights[IR_node.name]
    #         weights = weight_dict['weights']

    #     dim = len(IR_node.IR_layer.attr["kernel_shape"].list.i) - 2

    #     kernel = list()
    #     for idx in range(0, dim):
    #         kernel.append(IR_node.IR_layer.attr["kernel_shape"].list.i[idx])

    #     stride = list()
    #     for e in IR_node.IR_layer.attr["strides"].list.i[1:-1]:
    #         stride.append(e)

    #     dilate = list()
    #     for e in IR_node.IR_layer.attr["dilations"].list.i[1:-1]:
    #         dilate.append(e)
    #     dilate = ', '.join('%s' % i for i in dilate)

    #     defuse_pad = False
    #     pad = list()
    #     if "pads" in IR_node.IR_layer.attr:
    #         output_shape = list()
    #         for e in IR_node.IR_layer.attr["_output_shapes"].list.shape[0].dim:
    #             output_shape.append(e.size)

    #         # print("Warning: MXNet Deconvolution Layer pad does not match IR Deconvolution Layer pad")
    #         defuse_pad, pad = MXNetEmitter.transfer_pad(IR_node.IR_layer.attr["pads"].list.i)
    #     pad = ', '.join('%s' % i for i in pad)

    #     kernel = ', '.join('%s' % i for i in kernel)
    #     stride = ', '.join('%s' % i for i in stride)

    #     num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-2]
    #     no_bias = not IR_node.IR_layer.attr["use_bias"].b
    #     if not no_bias and self.weight_loaded:
    #         self.output_weights[IR_node.replace_scope(IR_node.name) + "_bias"] = weight_dict['bias']

    #     # layout = IR_node.IR_layer.attr["data_format"].s
    #     if dim == 1:
    #         layout = 'NCW'
    #     elif dim == 2:
    #         layout = 'NCHW'
    #     elif dim == 3:
    #         layout = 'NCDHW'

    #     if self.weight_loaded:
    #         # if layout not in MXNetEmitter.channels_last:
    #         weights = MXNetEmitter.transpose(weights, dim)
    #         self.output_weights[IR_node.replace_scope(IR_node.name) + "_weight"] = weights

    #     code = ""
    #     if not defuse_pad:
    #         code = "{:<15} = mx.sym.Deconvolution(data = {}, kernel = ({}), stride = ({}), dilate = ({}), pad = ({}), num_filter = {}, no_bias = {}, layout = '{}', name = '{}')".format(
    #                 IR_node.replace_scope(IR_node.name),
    #                 IR_node.replace_scope(IR_node.in_edges[0]),
    #                 kernel,
    #                 stride,
    #                 dilate,
    #                 pad,
    #                 num_filter,
    #                 no_bias,
    #                 layout,
    #                 IR_node.replace_scope(IR_node.name))
    #     else:
    #         code = self.set_pad(IR_node, code, pad)
    #         code += "\n    {:<15} = mx.sym.Deconvolution(data = {}, kernel = ({}), stride = ({}), dilate = ({}), num_filter = {}, no_bias = {}, layout = '{}', name = '{}')".format(
    #                 IR_node.replace_scope(IR_node.name), IR_node.replace_scope(IR_node.name) + "_pad", kernel, stride, dilate, num_filter, no_bias, layout, IR_node.replace_scope(IR_node.name))

    #     return code


    def emit_Embedding(self, IR_node):

        input_dim = IR_node.IR_layer.attr["input_dim"].i
        output_dim = IR_node.IR_layer.attr["output_dim"].i
        dtype = MXNetEmitter.dtype_map.get(IR_node.layer.attr["dtype"].type, "float32")

        code = "{:<15} = mx.sym.Embedding(data = {}, input_dim = {}, output_dim = {}, dtype = {}, name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                input_dim,
                output_dim,
                dtype,
                IR_node.name)

        return code


    def emit_LeakyRelu(self, IR_node):
        alpha = IR_node.IR_layer.attr['alpha'].f
        code = "{:<15} = mx.sym.LeakyReLU(data = {}, slope = {}, name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                alpha,
                IR_node.name
        )
        return code

    def emit_Elu(self, IR_node):
        alpha = IR_node.IR_layer.attr['alpha'].f
        code = "{:<15} = mx.sym.LeakyReLU(data = {}, slope = {}, act_type = {}, name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                alpha,
                'elu',
                IR_node.name
        )
        return code

    def emit_Dropout(self, IR_node):
        p = IR_node.IR_layer.attr["keep_prob"].f
        mode = IR_node.IR_layer.attr["mode"].s.lower().decode() if 'mode' in IR_node.layer.attr else 'training'
        code = "{:<15} = mx.sym.Dropout(data = {}, p = {}, mode = '{}', name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                p,
                mode,
                IR_node.name)

        return code


    # reverse cannot support yet
    def emit_Reshape(self, IR_node):
        shape = list()
        for e in IR_node.IR_layer.attr["shape"].list.i:
            shape.append(e)
        shape = ', '.join('%s' % i for i in shape)
        reverse = False

        code = "{:<15} = mx.sym.reshape(data = {}, shape = ({}), reverse = {}, name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                shape,
                reverse,
                IR_node.name)

        return code


    def emit_Flatten(self, IR_node):
        # code = "{:<15} = mx.sym.transpose(data = {}, axes = (0, 2, 3, 1))\n".format("trans", self.parent_variable_name(IR_node))
        code = "{:<15} = mx.sym.flatten(data = {}, name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                IR_node.name)

        return code


    @staticmethod
    def _convert_axis(IR_node, axis):
        ndim = len(IR_node.layer.attr['_output_shapes'].list.shape[0].dim)
        if axis == 0:
            return 0
        elif axis == ndim - 1:
            return 1
        else:
            return axis + 1
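
    # Maps an IR (NHWC) axis to its NCHW position for a 4-D tensor: the batch
    # axis stays 0, the last (channel) axis becomes 1, and any other axis
    # shifts up by one (e.g. H at 1 -> 2, W at 2 -> 3).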


    def emit_Concat(self, IR_node):
        dim = MXNetEmitter._convert_axis(IR_node, IR_node.IR_layer.attr["axis"].i)
        code = "{:<15} = mx.sym.concat({}, dim = {}, name = '{}')".format(
                IR_node.variable_name,
                ', '.join(self.IR_graph.get_node(s).real_variable_name for s in IR_node.in_edges),
                dim,
                IR_node.name)

        return code


    def emit_Cast(self, IR_node):

        dtype = MXNetEmitter.dtype_map.get(IR_node.IR_layer.attr["dtype"].type, "float32")

        code = "{:<15} = mx.sym.cast(data = {}, dtype = '{}', name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                dtype,
                IR_node.name)

        return code


    def emit_Expand_dims(self, IR_node):

        axis = IR_node.IR_layer.attr["axis"].i

        code = "{:<15} = mx.sym.expand_dims(data = {}, axis = {}, name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                axis,
                IR_node.name)

        return code


    def emit_Pad(self, IR_node):
        mode = IR_node.IR_layer.attr["mode"].s.lower().decode()
        pad_width = list()
        pad_width.extend([0]*4)
        padding = convert_onnx_pad_to_tf(IR_node.get_attr("pads"))[1:-1]
        for padding_pair in padding:
            pad_width.extend(padding_pair)

        pad_width = ', '.join('%s' % i for i in pad_width)

        code = "{:<15} = mx.sym.pad(data = {}, mode = '{}', pad_width = ({}), name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                mode,
                pad_width,
                IR_node.name)

        return code


    def emit_Add(self, IR_node):
        code = "{:<15} = mx.sym.broadcast_add({}, {})".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                self.parent_variable_name(IR_node, [1]))

        return code


    def emit_Mul(self, IR_node):

        code = "{:<15} = mx.sym.broadcast_mul({}, {})".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                self.parent_variable_name(IR_node, [1]))

        return code


    def emit_ReduceMean(self, IR_node):
        axes = IR_node.layer.attr['axes'].list.i[:]
        axes = ','.join('%s' % MXNetEmitter.transpose_map[i] for i in axes)

        code = "{:<15} = mx.sym.mean(data = {}, axis = ({}), keepdims = {})".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                axes,
                IR_node.layer.attr['keepdims'].b)

        return code


    def emit_LRN(self, IR_node):
        code = "{:<15} = mx.sym.LRN(data = {}, alpha = {}, beta = {}, knorm = {}, nsize = {}, name = '{}')".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                IR_node.layer.attr['alpha'].f,
                IR_node.layer.attr['beta'].f,
                IR_node.layer.attr['k'].f,
                IR_node.layer.attr['size'].i * 2 - 1,
                IR_node.name)

        return code

    def emit_Constant(self, IR_node):
        raise NotImplementedError()
        # unreachable: kept as a sketch of the intended implementation
        code = "{:<15} = mx.sym.identity(name='{}')".format(IR_node.variable_name, IR_node.name)
        self.output_weights[IR_node.name + '_data'] = self.weights[IR_node.name]['value']
        return code

    def emit_Sub(self, IR_node):
        code = "{:<15} = mx.sym.broadcast_sub({}, {})".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
                self.parent_variable_name(IR_node, [1]))

        return code


    def emit_Relu6(self, IR_node):
        self.add_body(1, self.emit_Activation(IR_node, 'relu'))
        old_name = IR_node.variable_name
        IR_node.real_name = IR_node.real_name + "_clip"
        self.add_body(1, "{:<15} = mx.sym.clip({}, a_min=0, a_max=6, name='{}')".format(
            IR_node.real_variable_name,
            old_name,
            IR_node.real_name))

        return ""


    # def emit_Slice(self, IR_node):
    #     starts = IR_node.get_attr('starts')
    #     starts = [starts[0], starts[-1]] + starts[1:-1]
    #     ends = IR_node.get_attr('ends')
    #     ends = [ends[0], ends[-1]] + ends[1:-1]
    #     ends = [i if i else None for i in ends]
    #     strides = IR_node.get_attr('strides')
    #     if strides:
    #         strides = [strides[0], strides[-1]] + strides[1:-1]

    #     self.add_body(1, "{:<15} = mx.sym.slice({}, begin={}, end={}, step={}, name='{}')".format(
    #         IR_node.real_variable_name,
    #         self.parent_variable_name(IR_node),
    #         starts,
    #         ends,
    #         strides,
    #         IR_node.name
    #     ))
    #     return ""

    def emit_Slice(self, IR_node):
        pass

    def emit_Const(self, IR_node):
        pass

    def emit_Shape(self, IR_node):
        pass

    def emit_Pack(self, IR_node):
        pass
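
A minimal usage sketch for the emitter above (hypothetical paths; `MXNetEmitter` accepts either an IR network path or a (network, weights, output-weights) triple, as handled in `__init__`):

emitter = MXNetEmitter(('resnet_IR.pb', 'resnet_IR.npy', 'mxnet_weights.npy'))
script = emitter.gen_code(phase='test')  # gen_code returns the generated MXNet script as text
with open('mxnet_resnet.py', 'w') as f:
    f.write(script)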
Example #2
class CaffeEmitter(Emitter):
    def __init__(self, model):
        from six import string_types as _string_types
        super(CaffeEmitter, self).__init__()
        if isinstance(model, _string_types):
            network_path = model
        else:
            network_path = model[0]
            self._load_weights(model[1])

        self.IR_graph = IRGraph(network_path)
        super(CaffeEmitter, self)._build()

    @property
    def header_code(self):
        return """from __future__ import print_function
import numpy as np
import sys, argparse
import caffe
from caffe import layers as L
from caffe import params as P
from caffe import to_proto
from six import text_type as _text_type


__weights_dict = dict()

def load_weights(weight_file):
    if weight_file is None:
        return

    try:
        weights_dict = np.load(weight_file).item()
    except:
        weights_dict = np.load(weight_file, encoding='bytes').item()

    return weights_dict


def KitModel(weight_file = None):
    n = caffe.NetSpec()
"""

    @property
    def end_code(self):
        return """    return n

def make_net(prototxt):
    n = KitModel()
    with open(prototxt, 'w') as fpb:
        print(n.to_proto(), file=fpb)

def gen_weight(weight_file, model, prototxt):
    global __weights_dict
    __weights_dict = load_weights(weight_file)

    net = caffe.Net(prototxt, caffe.TRAIN)

    for key in __weights_dict:
        if 'weights' in __weights_dict[key]:
            net.params[key][0].data.flat = __weights_dict[key]['weights']
        elif 'mean' in __weights_dict[key]:
            net.params[key][0].data.flat = __weights_dict[key]['mean']
            net.params[key][1].data.flat = __weights_dict[key]['var']
            if 'scale' in __weights_dict[key]:
                net.params[key][2].data.flat = __weights_dict[key]['scale']
        elif 'scale' in __weights_dict[key]:
            net.params[key][0].data.flat = __weights_dict[key]['scale']
        if 'bias' in __weights_dict[key]:
            net.params[key][1].data.flat = __weights_dict[key]['bias']
        if 'gamma' in __weights_dict[key]: # used for prelu, not sure if other layers use this too
            net.params[key][0].data.flat = __weights_dict[key]['gamma']
    net.save(model)
    return net



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Generate caffe model and prototxt')
    parser.add_argument('--weight_file', '-w', type=_text_type, help='IR weight file')
    parser.add_argument('--prototxt', '-p', type=_text_type, default='caffe_converted.prototxt')
    parser.add_argument('--model', '-m', type=_text_type, default='caffe_converted.caffemodel')
    args = parser.parse_args()
    # For some reason argparse gives us unicode, so we need to convert to str first
    make_net(str(args.prototxt))
    gen_weight(str(args.weight_file), str(args.model), str(args.prototxt))

"""

    def gen_code(self, phase='test'):
        self.phase = phase
        self.add_body(0, self.header_code)

        #for test
        # with open("graph.txt", 'w') as f:
        #     for layer in self.IR_graph.topological_sort:
        #         current_node = self.IR_graph.get_node(layer)
        #         print("========current_node=========\n{}".format(current_node.layer), file=f)
        #test end

        for layer in self.IR_graph.topological_sort:
            current_node = self.IR_graph.get_node(layer)
            node_type = current_node.type
            #print("========current_node={}".format(current_node.layer))

            if hasattr(self, "emit_" + node_type):
                func = getattr(self, "emit_" + node_type)
                func(current_node)
            else:
                print("CaffeEmitter has not supported operator [%s]." %
                      (node_type))
                self.emit_UNKNOWN(current_node)

        self.add_body(0, "")
        self.add_body(0, self.end_code)

        return self.body_code

    def run(self, dstNetworkPath, dstWeightPath=None, phase='test'):
        super(CaffeEmitter, self).run(dstNetworkPath, dstWeightPath, phase)
        if self.weight_loaded:
            self.save_weights(self.weights_dict, dstWeightPath)

    @staticmethod
    def _shapeToStr(shapes):
        return [dim.size if dim.size > 0 else 1 for dim in shapes.dim]

    def _get_symmetric_padding(self, IR_node):
        stride_h = IR_node.get_attr('strides')[1]
        stride_w = IR_node.get_attr('strides')[2]

        # check if have pad layer
        IR_parent_node = self.IR_graph.get_parent(IR_node.name, [0])
        if IR_parent_node.type == 'Pad':
            pads = IR_parent_node.get_attr('pads')
        else:
            pads = IR_node.get_attr('pads')
        pad_h = pads[1] + (0 if pads[1] == pads[5] else stride_h)
        pad_w = pads[2] + (0 if pads[2] == pads[6] else stride_w)
        return pad_h, pad_w
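
    # Worked example: pads = [0, 1, 1, 0, 0, 2, 2, 0] with strides (1, 2, 2, 1)
    # is asymmetric in both H and W (1 vs 2), so one stride is added on each:
    # pad_h = pad_w = 1 + 2 = 3; symmetric pads are returned unchanged.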

    def check_if_need_transpose(self, IR_node):
        parent = self.IR_graph.get_parent(IR_node.name, [0])
        while parent.type == 'Flatten' or parent.type == 'Dropout' or parent.type == 'Reshape':
            parent = self.IR_graph.get_parent(parent.name, [0])
        dim = len(parent.layer.attr['_output_shapes'].list.shape[0].dim)
        if dim > 2:
            original_dims = self.weights_dict[IR_node.name]['weights'].shape
            dims = [
                i.size for i in
                parent.layer.attr['_output_shapes'].list.shape[0].dim[1:]
            ] + [-1]
            self.weights_dict[IR_node.name]['weights'] = np.reshape(
                self.weights_dict[IR_node.name]['weights'], dims)
            self.weights_dict[IR_node.name]['weights'] = np.transpose(
                self.weights_dict[IR_node.name]['weights'],
                [dim - 2] + list(range(0, dim - 2)) + [dim - 1])
            self.weights_dict[IR_node.name]['weights'] = np.reshape(
                self.weights_dict[IR_node.name]['weights'], original_dims)

    def emit_Conv(self, IR_node):
        # implement asymmetric paddings by applying symmetric padding then cropping
        pad_h, pad_w = self._get_symmetric_padding(IR_node)

        num_output = IR_node.get_attr('kernel_shape')[-1]
        if IR_node.type == "DepthwiseConv":
            num_group = IR_node.get_attr("kernel_shape")[-2]
            num_output = num_group * num_output
        else:
            num_group = IR_node.get_attr("group", 1)

        self.add_body(
            1,
            "n.{:<15} = L.Convolution(n.{}, kernel_h={}, kernel_w={}, stride={}, num_output={}, pad_h={}, pad_w={}, group={}, \
            bias_term={}, ntop=1)".format(IR_node.variable_name,
                                          self.parent_variable_name(IR_node),
                                          IR_node.get_attr('kernel_shape')[0],
                                          IR_node.get_attr('kernel_shape')[1],
                                          IR_node.get_attr('strides')[1],
                                          num_output, pad_h, pad_w, num_group,
                                          IR_node.get_attr('use_bias', False)))

        dim = len(IR_node.get_attr('strides')) - 2
        if self.weight_loaded:
            if IR_node.type == "DepthwiseConv":
                self.weights_dict[IR_node.name]['weights'] = np.swapaxes(
                    self.weights_dict[IR_node.name]['weights'], -1, -2)
            self.weights_dict[IR_node.name]['weights'] = np.transpose(
                self.weights_dict[IR_node.name]['weights'],
                [dim + 1, dim] + list(range(0, dim)))
            self.weights_dict[IR_node.variable_name] = self.weights_dict.pop(
                IR_node.name)

        self.check_if_need_crop(IR_node)
        # keys = []
        # for key in self.weights_dict[IR_node.name].keys():
        #     keys.append(key)
        # print("=======Layer: {}, keys: {}".format(IR_node.name, keys))

    def compute_output_shape(self, IR_node, kernel_h, kernel_w):
        parent_node = self.IR_graph.get_parent(IR_node.name, [0])

        if parent_node.get_attr('_output_shapes'):
            shape = parent_node.get_attr('_output_shapes')[0]
            shape = shape_to_list(shape)
            h_i = shape[1]
            w_i = shape[2]
            pad_h, pad_w = self._get_symmetric_padding(IR_node)
            stride_h = IR_node.get_attr('strides')[1]
            stride_w = IR_node.get_attr('strides')[2]

            if IR_node.type == 'Pool':
                h_o = (h_i + 2 * pad_h - kernel_h + stride_h -
                       1) // stride_h + 1
                w_o = (w_i + 2 * pad_w - kernel_w + stride_w -
                       1) // stride_w + 1
            else:
                h_o = (h_i + 2 * pad_h - kernel_h) // stride_h + 1
                w_o = (w_i + 2 * pad_w - kernel_w) // stride_w + 1
            return h_o, w_o
        else:
            assert False
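
    # Worked example: h_i = 112, pad_h = 1, kernel_h = 3, stride_h = 2 gives
    # (112 + 2 - 3 + 2 - 1) // 2 + 1 = 57 for pooling (ceil division) but
    # (112 + 2 - 3) // 2 + 1 = 56 for convolution (floor division), which is
    # why check_if_need_crop may have to crop the Caffe output.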

    def check_if_need_crop(self, IR_node):
        shape = IR_node.get_attr('_output_shapes')[0]
        shape = shape_to_list(shape)
        ir_ho = shape[1]
        ir_wo = shape[2]
        if ir_ho < 0 or ir_wo < 0:
            return
        if IR_node.type == 'Pool':
            k_h = IR_node.get_attr('kernel_shape')[1]
            k_w = IR_node.get_attr('kernel_shape')[2]
        else:
            k_h = IR_node.get_attr('kernel_shape')[0]
            k_w = IR_node.get_attr('kernel_shape')[1]

        caffe_ho, caffe_wo = self.compute_output_shape(IR_node, k_h, k_w)

        # if asymmetric padding, set offset to 1
        pads = IR_node.get_attr('pads')
        offset = [
            0 if pads[1] == pads[5] else 1, 0 if pads[2] == pads[6] else 1
        ]
        if caffe_ho > ir_ho or caffe_wo > ir_wo:
            crop_layer_variable_name = IR_node.variable_name + "_crop"
            self.add_body(
                1,
                "n.{:<15} = L.Crop(n.{}, L.DummyData(shape=[dict(dim=[1, {}, {}, {}])], \
                ntop=1), ntop=1, offset={})".format(crop_layer_variable_name,
                                                    IR_node.variable_name,
                                                    shape[3], ir_ho, ir_wo,
                                                    offset))
            # Change the layer name
            IR_node.real_name = IR_node.real_name + "_crop"

    def emit_Pool(self, IR_node):
        pooling_type = IR_node.get_attr('pooling_type')
        if pooling_type == 'MAX':
            pooling_type = P.Pooling.MAX
        elif pooling_type == 'AVG':
            pooling_type = P.Pooling.AVE
        elif pooling_type == 'STOCHASTIC':
            pooling_type = P.Pooling.STOCHASTIC
        else:
            raise ValueError()

        if IR_node.layer.attr['global_pooling'].b:
            self.add_body(
                1,
                "n.{:<15} = L.Pooling(n.{}, pool={}, stride={}, global_pooling=True, ntop=1)"
                .format(IR_node.variable_name,
                        self.parent_variable_name(IR_node), pooling_type,
                        IR_node.get_attr('strides')[1]))
        else:
            pad_h, pad_w = self._get_symmetric_padding(IR_node)
            self.add_body(
                1,
                "n.{:<15} = L.Pooling(n.{}, pool={}, kernel_size={}, pad_h={}, pad_w={}, stride={}, ntop=1)"
                .format(IR_node.variable_name,
                        self.parent_variable_name(IR_node), pooling_type,
                        IR_node.get_attr('kernel_shape')[1], pad_h, pad_w,
                        IR_node.get_attr('strides')[1]))

            # check if need crop output shape
            self.check_if_need_crop(IR_node)

    def emit_ResizeBilinear(self, IR_node):
        shape = IR_node.get_attr("_output_shapes")[0]
        shape = shape_to_list(shape)
        self.add_body(
            1,
            "n.{:<15} = L.ResizeBilinear(n.{}, height={}, width={}, ntop=1)".
            format(IR_node.variable_name, self.parent_variable_name(IR_node),
                   shape[1], shape[2]))

    def emit_UNKNOWN(self, IR_node):
        print(IR_node.IR_layer.name)

    def emit_DataInput(self, IR_node):
        shape = self._shapeToStr(IR_node.get_attr('shape'))
        shape = [shape[0], shape[-1]] + shape[1:-1]
        self.add_body(
            1, "n.{:<15} = L.Input(shape=[dict(dim={})], ntop=1)".format(
                IR_node.variable_name, shape))

    def emit_Dropout(self, IR_node):
        in_place = True
        self.add_body(
            1,
            "n.{:<15} = L.Dropout(n.{}, dropout_ratio={} , in_place={}, ntop=1)"
            .format(IR_node.variable_name, self.parent_variable_name(IR_node),
                    1 - IR_node.get_attr('keep_prob'), in_place))

    def emit_FullyConnected(self, IR_node):
        self.add_body(
            1,
            "n.{:<15} = L.InnerProduct(n.{}, num_output={}, bias_term={}, ntop=1)"
            .format(IR_node.variable_name, self.parent_variable_name(IR_node),
                    IR_node.layer.attr["units"].i,
                    IR_node.get_attr('use_bias', False)))
        if self.weight_loaded:
            self.check_if_need_transpose(IR_node)
            self.weights_dict[IR_node.name]['weights'] = np.transpose(
                self.weights_dict[IR_node.name]['weights'], (1, 0))
            self.weights_dict[IR_node.variable_name] = self.weights_dict.pop(
                IR_node.name)

    def emit_BatchNorm(self, IR_node):

        self.add_body(
            1,
            "n.{:<15} = L.BatchNorm(n.{}, eps={}, use_global_stats={}, ntop=1)"
            .format(IR_node.variable_name, self.parent_variable_name(IR_node),
                    IR_node.get_attr('epsilon'), self.phase == 'test'))

        scale_layer_var_name = IR_node.variable_name + "_scale"
        self.add_body(
            1, "n.{:<15} = L.Scale(n.{}, bias_term={}, in_place=True, ntop=1)".
            format(scale_layer_var_name, IR_node.variable_name,
                   IR_node.get_attr('bias', False)))

        if self.weight_loaded:
            self.weights_dict[scale_layer_var_name] = dict()
            if 'scale' in self.weights_dict[IR_node.name]:
                self.weights_dict[scale_layer_var_name][
                    'scale'] = self.weights_dict[IR_node.name]['scale']
            else:
                self.weights_dict[scale_layer_var_name]['scale'] = 1

            self.weights_dict[IR_node.name]['scale'] = 1

            if 'bias' in self.weights_dict[IR_node.name]:
                self.weights_dict[scale_layer_var_name][
                    'bias'] = self.weights_dict[IR_node.name]['bias']
                self.weights_dict[IR_node.name].pop('bias', None)
                # change the key "name" to "variable_name", in case of the layer name has invalid characters

            self.weights_dict[IR_node.variable_name] = self.weights_dict.pop(
                IR_node.name)

        IR_node.real_name = IR_node.name + "_scale"

    def emit_Scale(self, IR_node):
        self.add_body(
            1, "n.{:<15} = L.Scale(n.{}, bias_term={}, in_place=True, ntop=1)".
            format(IR_node.variable_name, self.parent_variable_name(IR_node),
                   IR_node.get_attr('use_bias', False)))
        if self.weight_loaded:
            self.weights_dict[IR_node.variable_name] = self.weights_dict.pop(
                IR_node.name)

    def emit_Constant(self, IR_node):
        IR_node_after = self.IR_graph.get_son(IR_node.name, [0])
        shape = IR_node_after.get_attr("_output_shapes")[0]
        shape = shape_to_list(shape)
        self.add_body(
            1,
            "n.{:<15} = L.DummyData(shape=[dict(dim=[1,{},{},{}])], data_filler=dict(type='constant', value={}), ntop=1)"
            .format(IR_node.variable_name, shape[-1], shape[1], shape[2],
                    self.weights_dict[IR_node.name]['value'][0]))

    def emit_LRN(self, IR_node):
        self.add_body(
            1,
            "n.{:<15} = L.LRN(n.{}, local_size={}, alpha={}, beta={}, k={})".
            format(IR_node.variable_name, self.parent_variable_name(IR_node),
                   IR_node.get_attr('size') * 2 - 1, IR_node.get_attr('alpha'),
                   IR_node.get_attr('beta'), IR_node.get_attr('k')))

    def emit_Add(self, IR_node):
        input_layers = ', '.join(
            ('n.' +
             self.IR_graph.get_parent(IR_node.name, [num]).real_variable_name)
            for num in range(0, len(IR_node.in_edges)))
        self.add_body(
            1, "n.{:<15} = L.Eltwise({}, operation=1, ntop=1)".format(
                IR_node.variable_name,
                input_layers,
            ))

    def emit_Flatten(self, IR_node):
        IR_node.real_name = self.IR_graph.get_parent(IR_node.name,
                                                     [0]).real_name

    def emit_Squeeze(self, IR_node):
        shape = IR_node.get_attr("_output_shapes")[0]
        shape = shape_to_list(shape)
        if shape:
            dim_str = "'dim': {}".format(shape)
            dim_str = " reshape_param={'shape': { " + dim_str + '} }'
            self.add_body(
                1, "n.{:<15} = L.Reshape(n.{}, {})".format(
                    IR_node.variable_name, self.parent_variable_name(IR_node),
                    dim_str))
        else:
            IR_node.real_name = self.IR_graph.get_parent(IR_node.name,
                                                         [0]).real_name

    def emit_Concat(self, IR_node):
        axis_array = (2, 3, 1, 0)
        axis = axis_array.index(IR_node.get_attr('axis'))
        input_layers = ', '.join(
            ('n.' + self.IR_graph.get_node(edge).real_variable_name)
            for edge in IR_node.in_edges)
        self.add_body(
            1,
            "n.{:<15} = L.Concat({}, axis={})".format(IR_node.variable_name,
                                                      input_layers, axis))
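
    # Worked example: an IR (NHWC) concat over axis 3 (channels) maps to
    # Caffe axis 1, since axis_array.index(3) == 1 in the NCHW layout.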

    def emit_Sigmoid(self, IR_node):
        self.add_body(
            1, "n.{:<15} = L.Sigmoid(n.{}, ntop=1)".format(
                IR_node.variable_name, self.parent_variable_name(IR_node)))

    def emit_Relu(self, IR_node):
        in_place = True
        self.add_body(
            1, "n.{:<15} = L.ReLU(n.{}, in_place={}, ntop=1)".format(
                IR_node.variable_name, self.parent_variable_name(IR_node),
                in_place))

    def emit_LeakyRelu(self, IR_node):
        in_place = True
        self.add_body(
            1,
            "n.{:<15} = L.ReLU(n.{}, in_place={}, negative_slope={}, ntop=1)".
            format(IR_node.variable_name, self.parent_variable_name(IR_node),
                   in_place, IR_node.IR_layer.attr['alpha'].f))

    def emit_PRelu(self, IR_node):
        in_place = True
        self.add_body(
            1, "n.{:<15} = L.PReLU(n.{}, in_place={}, ntop=1)".format(
                IR_node.variable_name, self.parent_variable_name(IR_node),
                in_place))

    def emit_Tanh(self, IR_node):
        self.add_body(
            1, "n.{:<15} = L.TanH(n.{}, ntop=1)".format(
                IR_node.variable_name, self.parent_variable_name(IR_node)))

    def emit_Softmax(self, IR_node):
        self.add_body(
            1, "n.{:<15} = L.Softmax(n.{}, ntop=1)".format(
                IR_node.variable_name, self.parent_variable_name(IR_node)))

    def emit_Pad(self, IR_node):
        IR_node.real_name = self.IR_graph.get_parent(IR_node.name,
                                                     [0]).real_name

    def reduction(self, IR_node, op, axes):
        # Convert NHWC (IR) to NCHW (Caffe): [0,1,2,3]->[0,3,1,2]
        if len(axes) == 1:
            assert (axes[0] == 2)
        elif len(axes) == 2:
            assert ((axes[0] == 1) and (axes[1] == 2))

        self.add_body(
            1, "n.{:<15} = L.Reduction(n.{}, operation={} , axis={} ,ntop=1)".
            format(IR_node.variable_name, self.parent_variable_name(IR_node),
                   op, len(axes)))

        if IR_node.get_attr('keepdims'):
            shape = IR_node.get_attr("_output_shapes")[0]
            shape = shape_to_list(shape)
            shape = [1] + [shape[-1]] + shape[1:-1]
            dim_str = "'dim': {}".format(shape)
            dim_str = "{'shape': { " + dim_str + '} }'
            self.add_body(
                1, "n.{:<15} = L.Reshape(n.{}, reshape_param={}) ".format(
                    IR_node.variable_name + "_reshape",
                    IR_node.real_variable_name, dim_str))
            IR_node.real_name = IR_node.real_name + '_reshape'

    def emit_ReduceMean(self, IR_node):
        self.reduction(IR_node, 4, IR_node.get_attr('axes'))

    def emit_ReduceSum(self, IR_node):
        self.reduction(IR_node, 1, IR_node.get_attr('axes'))

    def emit_Relu6(self, IR_node):
        self.emit_Relu(IR_node)

    def emit_DepthwiseConv(self, IR_node):
        self.emit_Conv(IR_node)

    def emit_Const(self, IR_node):
        pass

    def emit_Shape(self, IR_node):
        pass

    def emit_Reshape(self, IR_node):
        # currently for the flatten layer
        self.add_body(
            1, "n.{:<15} = L.Flatten(n.{})".format(
                IR_node.variable_name,
                self.parent_variable_name(IR_node),
            ))

    def emit_Slice(self, IR_node):
        pass

    def emit_Pack(self, IR_node):
        pass

    def emit_Abs(self, IR_node):
        self.add_body(
            1, "n.{:<15} = L.AbsVal(n.{}, ntop=1)".format(
                IR_node.variable_name, self.parent_variable_name(IR_node)))

    def emit_Sub(self, IR_node):
        input_layers = ', '.join(
            ('n.' + self.IR_graph.get_node(edge).real_variable_name)
            for edge in IR_node.in_edges)
        self.add_body(
            1, "n.{:<15} = L.Eltwise({}, coeff = [1, -1], ntop=1)".format(
                IR_node.variable_name, input_layers))

    def emit_Mul(self, IR_node):
        if len(IR_node.in_edges) == 2:
            input_layers = ', '.join(
                ('n.' + self.IR_graph.get_node(edge).real_variable_name)
                for edge in IR_node.in_edges)
            self.add_body(
                1, "n.{:<15} = L.Eltwise({}, operation=0, ntop=1)".format(
                    IR_node.variable_name, input_layers))
        elif len(IR_node.in_edges) == 1:
            self.emit_Scale(IR_node)
        else:
            assert False

    def emit_UpSampling2D(self, IR_node):
        scales = IR_node.get_attr('scales')
        scale = tuple(scales)[0]

        shape = IR_node.get_attr('_output_shapes')[0]
        shape = shape_to_list(shape)

        self.add_body(
            1,
            "n.{:<15} = L.Deconvolution(n.{}, convolution_param=dict(kernel_size={}, stride={}, pad={}, num_output={}, group={}, bias_term={}), param=[dict(lr_mult=0)], ntop=1)"
            .format(IR_node.variable_name, IR_node.in_edges[0],
                    2 * scale - scale % 2, scale,
                    int(math.ceil(
                        (scale - 1) / 2)), shape[-1], shape[-1], False))
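
A matching usage sketch for `CaffeEmitter` (hypothetical paths; `run` is inherited from the base `Emitter`, and the override above also saves the converted weights):

emitter = CaffeEmitter(('resnet_IR.pb', 'resnet_IR.npy'))
emitter.run('caffe_resnet.py', 'caffe_weights.npy', phase='test')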