Example #1
    def visit_call(self, call):
        """ Visit the children. """
        # First visit the children.
        oshape = _get_tensor_shape(call)
        odtype = _get_tensor_type(call)
        input_types = [arg.checked_type for arg in call.args]
        args = [self.visit(arg) for arg in call.args]

        # Start and stop cases.
        if call.op == self.bitpack_start:
            assert not self.start_pack
            self.start_pack = True
            return _pack_batch_channel(args[0], oshape, self.bfactor,
                                       self.cfactor)
        if call.op == self.bitpack_end:
            if self.start_pack:
                self.start_pack = False
                data = args[0]
                data_shape = _get_tensor_shape(call.args[0])
                return _unpack_batch_channel(data, data_shape)
        if self.start_pack:
            # Operator cases
            if call.op == self.conv2d and odtype == 'int32':
                self.number_of_conv2d += 1
                assert 8 % self.weight_bits == 0
                w_lanes = 8 // self.weight_bits
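                # With e.g. bfactor=1 and cfactor=16 these layout strings become
                # "NCHW1n16c" and "OIHW16o16i": batch and channel dims blocked into inner tiles.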
                data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
                kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor)
                data, weight = args
                data_shape = _to_shape(input_types[0].shape)
                kernel_shape = _to_shape(input_types[1].shape)
                channels = call.attrs.channels
                weight, kernel_shape, channels = _weight_shape_match(
                    weight, kernel_shape, channels, self.cfactor)
                kernel = _pack_weight(weight, kernel_shape, self.cfactor)
                # insert bit packing when necessary
                if w_lanes != 1:
                    assert 8 % w_lanes == 0
                    kernel = op.bitpack(kernel, lanes=w_lanes)

                conv2d = op.nn.conv2d(data,
                                      kernel,
                                      strides=call.attrs.strides,
                                      padding=call.attrs.padding,
                                      dilation=call.attrs.dilation,
                                      groups=call.attrs.groups,
                                      channels=channels,
                                      kernel_size=call.attrs.kernel_size,
                                      data_layout=data_layout,
                                      kernel_layout=kernel_layout,
                                      out_dtype=call.attrs.out_dtype)
                return conv2d

            if call.op == self.conv2d_transpose and odtype == 'int32':
                self.number_of_conv2d += 1
                assert 8 % self.weight_bits == 0
                w_lanes = 8 // self.weight_bits
                data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
                kernel_layout = "IOHW%di%do" % (self.cfactor, self.cfactor)
                data, weight = args
                data_shape = _to_shape(input_types[0].shape)
                kernel_shape = _to_shape(input_types[1].shape)
                channels = call.attrs.channels
                weight, kernel_shape, channels = _weight_shape_match_transpose(
                    weight, kernel_shape, channels, self.cfactor)
                kernel = _pack_weight_conv2d_transpose(
                    weight, kernel_shape, self.cfactor)
                conv2d = op.nn.conv2d_transpose(
                    data,
                    kernel,
                    strides=call.attrs.strides,
                    padding=call.attrs.padding,
                    dilation=call.attrs.dilation,
                    groups=call.attrs.groups,
                    channels=call.attrs.channels,
                    kernel_size=call.attrs.kernel_size,
                    data_layout=data_layout,
                    kernel_layout=kernel_layout,
                    output_padding=call.attrs.output_padding,
                    out_dtype=call.attrs.out_dtype)
                return conv2d
            if call.op == self.add and \
                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
                pass
            elif call.op == self.add and len(input_types[1].shape) == 3:
                data, const = args
                const, input_shape = _const_shape_match(
                    const, input_types[1].shape, self.cfactor)
                const = _pack_const(const, _to_shape(input_shape),
                                    input_types[1].dtype, self.bfactor,
                                    self.cfactor)
                return relay.Call(self.add, [data, const])
            elif call.op == self.multiply and \
                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
                pass
            elif call.op == self.multiply and len(input_types[1].shape) == 3:
                data, const = args
                const = _pack_const(const, _to_shape(input_types[1].shape),
                                    input_types[1].dtype, self.bfactor,
                                    self.cfactor)
                return relay.Call(self.multiply, [data, const])
            elif call.op == self.bias_add:
                data, bias = args
                bias = _pack_const(bias, _to_shape(input_types[1].shape),
                                   input_types[1].dtype, self.bfactor,
                                   self.cfactor)
                return relay.Call(self.add, [data, bias])
            elif call.op == op.op.get('cast') and \
                    input_types[0].dtype == 'int32':
                cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
                return relay.Call(op.op.get('copy'), [cast])
            elif call.op == self.pad:
                pad_width = call.attrs.pad_width
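                # Packed data is 6-D (NCHW<n>n<c>c), so a 4-D pad spec gets two extra
                # [0, 0] pairs for the inner n/c axes; a 6-D spec already matches.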
                if len(pad_width) == 6:
                    pass
                elif len(pad_width) == 4:
                    data, = args
                    new_pad_width = []
                    new_pad_width.extend(pad_width)
                    for _ in range(2):
                        new_pad_width.append([0, 0])
                    return op.nn.pad(data,
                                     pad_value=call.attrs.pad_value,
                                     pad_width=new_pad_width)
            elif call.op == self.upsampling:
                data, = args
                scale_h = call.attrs.scale_h
                scale_w = call.attrs.scale_w
                data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
                method = call.attrs.method
                align_corners = call.attrs.align_corners
                return op.nn.upsampling(data, scale_h, scale_w, data_layout,
                                        method, align_corners)
            elif call.op == self.reshape and len(input_types[0].shape) == 4:
                data, _ = args
                data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
                return op.reshape(data, [int(x) for x in input_types[0].shape])

        return relay.Call(self.visit(call.op), args, call.attrs)
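Neither _pack_batch_channel nor _unpack_batch_channel is shown above. As a rough, standalone NumPy sketch of what the NCHW -> NCHW<n>n<c>c packing amounts to (the function names and the exact reshape/transpose order here are assumptions, not the pass's actual helpers), the round trip could look like this; note that the unpack step uses the same axes=(0, 4, 1, 5, 2, 3) transpose that the reshape branch above applies to undo the blocking:

import numpy as np

def pack_nchw(x, bfactor, cfactor):
    # Split N into (N//bfactor, bfactor) and C into (C//cfactor, cfactor),
    # then move the small n/c tiles to the two innermost axes.
    n, c, h, w = x.shape
    assert n % bfactor == 0 and c % cfactor == 0
    x = x.reshape(n // bfactor, bfactor, c // cfactor, cfactor, h, w)
    return x.transpose(0, 2, 4, 5, 1, 3)          # N', C', H, W, n, c

def unpack_nchw(x):
    # Inverse: move the n/c tiles back next to N'/C' and merge them again.
    n1, c1, h, w, bfactor, cfactor = x.shape
    x = x.transpose(0, 4, 1, 5, 2, 3)             # N', n, C', c, H, W
    return x.reshape(n1 * bfactor, c1 * cfactor, h, w)

x = np.arange(1 * 32 * 4 * 4).reshape(1, 32, 4, 4)
packed = pack_nchw(x, bfactor=1, cfactor=16)      # shape (1, 2, 4, 4, 1, 16)
assert (unpack_nchw(packed) == x).all()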
Example #2
    def visit_call(self, call):
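        """Rewrite call nodes inside a bitpack region to the packed layout."""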
        # First visit the children.
        oshape = _get_shape(call)
        odtype = call.checked_type.dtype
        input_types = [arg.checked_type for arg in call.args]
        args = [self.visit(arg) for arg in call.args]

        # Start and stop cases.
        if call.op == self.bitpack_start:
            assert not self.start_pack
            self.start_pack = True
            return _pack_batch_channel(args[0], oshape, self.bfactor, self.cfactor)
        elif call.op == self.bitpack_end:
            if self.start_pack:
                self.start_pack = False
                data = args[0]
                data_shape = _get_shape(call.args[0])
                return _unpack_batch_channel(data, data_shape)
            else:
                pass
        if self.start_pack:
            # Operator cases
            if call.op == self.conv2d and odtype == 'int32':
                self.number_of_conv2d += 1
                assert 8 % self.weight_bits == 0
                w_lanes = 8 // self.weight_bits
                data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
                kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor)
                data, weight = args
                data_shape = _to_shape(input_types[0].shape)
                kernel_shape = _to_shape(input_types[1].shape)
                kernel = _pack_weight(weight, kernel_shape, self.cfactor)
                # insert bit packing when necessary
                if w_lanes != 1:
                    assert 8 % w_lanes == 0
                    kernel = op.bitpack(kernel, lanes=w_lanes)
                conv2d = op.nn.conv2d(
                    data,
                    kernel,
                    strides=call.attrs.strides,
                    padding=call.attrs.padding,
                    dilation=call.attrs.dilation,
                    groups=call.attrs.groups,
                    channels=call.attrs.channels,
                    kernel_size=call.attrs.kernel_size,
                    data_layout=data_layout,
                    kernel_layout=kernel_layout,
                    out_dtype=call.attrs.out_dtype)
                return conv2d
            elif call.op == self.conv2d_transpose and odtype == 'int32':
                self.number_of_conv2d += 1
                assert 8 % self.weight_bits == 0
                w_lanes = 8 // self.weight_bits
                data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
                kernel_layout = "IOHW%di%do" % (self.cfactor, self.cfactor)
                data, weight = args
                data_shape = _to_shape(input_types[0].shape)
                kernel_shape = _to_shape(input_types[1].shape)
                kernel = _pack_weight_conv2d_transpose(weight, kernel_shape, self.cfactor)
                conv2d = op.nn.conv2d_transpose(
                    data,
                    kernel,
                    strides=call.attrs.strides,
                    padding=call.attrs.padding,
                    dilation=call.attrs.dilation,
                    groups=call.attrs.groups,
                    channels=call.attrs.channels,
                    kernel_size=call.attrs.kernel_size,
                    data_layout=data_layout,
                    kernel_layout=kernel_layout,
                    output_padding=call.attrs.output_padding,
                    out_dtype=call.attrs.out_dtype)
                return conv2d
            elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape):
                pass
            elif call.op == self.add and len(input_types[1].shape) == 3:
                data, bias = args
                bias = _pack_bias(bias,
                                  _to_shape(input_types[1].shape),
                                  input_types[1].dtype,
                                  self.bfactor,
                                  self.cfactor)
                return relay.Call(self.add, [data, bias])
            elif call.op == self.bias_add:
                data, bias = args
                bias = _pack_bias(bias,
                                  _to_shape(input_types[1].shape),
                                  input_types[1].dtype,
                                  self.bfactor,
                                  self.cfactor)
                return relay.Call(self.add, [data, bias])
            elif call.op == op.op.get('cast') and \
                    input_types[0].dtype == 'int32':
                cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
                return relay.Call(op.op.get('copy'), [cast])

        return relay.Call(
            self.visit(call.op),
            args,
            call.attrs)
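Both listings also re-block 3-D add/bias constants (via _pack_const in Example #1 and _pack_bias in Example #2), and neither helper is shown. The following self-contained NumPy sketch illustrates why that is needed; the (C, 1, 1) constant shape and the target layout are assumptions for illustration, not the helpers' real code:

import numpy as np

cfactor = 16
b = np.arange(32).reshape(32, 1, 1)               # per-channel constant, shape (C, 1, 1)

# Against packed NCHW<n>n<c>c data the constant must be split the same way the
# channel axis was: C//cfactor outer blocks with the small c tile innermost, so
# broadcasting still pairs b[c] with channel c = c_outer * cfactor + c_inner.
bp = b.reshape(32 // cfactor, 1, 1, 1, cfactor)   # (C', H, W, n, c), broadcastable

assert bp[1, 0, 0, 0, 3] == b[1 * cfactor + 3, 0, 0]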