def visit_call(self, call):
    """Rewrite one Relay call for bit-packed (VTA-style) execution.

    Between a `bitpack_start` / `bitpack_end` pair this mutator repacks
    tensors into the hardware layout (NCHW%dn%dc with batch factor
    ``self.bfactor`` and channel factor ``self.cfactor``) and rewrites the
    operators it knows about (conv2d, conv2d_transpose, add, multiply,
    bias_add, cast, pad, upsampling, reshape).  Any other call is rebuilt
    unchanged with its visited children.

    Parameters
    ----------
    call : relay.Call
        The call node being visited.

    Returns
    -------
    relay.Expr
        The rewritten expression.
    """
    # Visit the children first; also capture the pre-rewrite types, since
    # the rewritten args no longer carry the original checked_type.
    oshape = _get_tensor_shape(call)
    odtype = _get_tensor_type(call)
    input_types = [arg.checked_type for arg in call.args]
    args = [self.visit(arg) for arg in call.args]

    # Start/stop markers toggle packing state for the region in between.
    if call.op == self.bitpack_start:
        assert not self.start_pack
        self.start_pack = True
        return _pack_batch_channel(args[0], oshape, self.bfactor, self.cfactor)
    if call.op == self.bitpack_end:
        if self.start_pack:
            self.start_pack = False
            data = args[0]
            data_shape = _get_tensor_shape(call.args[0])
            return _unpack_batch_channel(data, data_shape)
        # A bitpack_end seen while not packing falls through and is
        # rebuilt as an ordinary call below.

    if self.start_pack:
        # Operator cases: only rewritten inside a packed region.
        if call.op == self.conv2d and odtype == 'int32':
            self.number_of_conv2d += 1
            assert 8 % self.weight_bits == 0
            w_lanes = 8 // self.weight_bits
            data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
            kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor)
            data, weight = args
            kernel_shape = _to_shape(input_types[1].shape)
            channels = call.attrs.channels
            # Pad kernel/channels up to a multiple of cfactor if needed;
            # `channels` comes back adjusted to match the padded kernel.
            weight, kernel_shape, channels = _weight_shape_match(
                weight, kernel_shape, channels, self.cfactor)
            kernel = _pack_weight(weight, kernel_shape, self.cfactor)
            # Insert bit packing when weights are narrower than 8 bits.
            if w_lanes != 1:
                assert 8 % w_lanes == 0
                kernel = op.bitpack(kernel, lanes=w_lanes)
            return op.nn.conv2d(
                data,
                kernel,
                strides=call.attrs.strides,
                padding=call.attrs.padding,
                dilation=call.attrs.dilation,
                groups=call.attrs.groups,
                channels=channels,
                kernel_size=call.attrs.kernel_size,
                data_layout=data_layout,
                kernel_layout=kernel_layout,
                out_dtype=call.attrs.out_dtype)
        if call.op == self.conv2d_transpose and odtype == 'int32':
            self.number_of_conv2d += 1
            assert 8 % self.weight_bits == 0
            w_lanes = 8 // self.weight_bits
            # NOTE: the original code re-tested `self.start_pack` here,
            # which is always true inside this guard; the redundant check
            # has been removed (it would have raised NameError on the
            # impossible false path).
            data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
            kernel_layout = "IOHW%di%do" % (self.cfactor, self.cfactor)
            data, weight = args
            kernel_shape = _to_shape(input_types[1].shape)
            channels = call.attrs.channels
            weight, kernel_shape, channels = _weight_shape_match_transpose(
                weight, kernel_shape, channels, self.cfactor)
            kernel = _pack_weight_conv2d_transpose(
                weight, kernel_shape, self.cfactor)
            return op.nn.conv2d_transpose(
                data,
                kernel,
                strides=call.attrs.strides,
                padding=call.attrs.padding,
                dilation=call.attrs.dilation,
                groups=call.attrs.groups,
                # FIX: use the matched `channels` (adjusted alongside the
                # padded kernel by _weight_shape_match_transpose), not the
                # original call.attrs.channels — mirrors the conv2d branch.
                channels=channels,
                kernel_size=call.attrs.kernel_size,
                data_layout=data_layout,
                kernel_layout=kernel_layout,
                output_padding=call.attrs.output_padding,
                out_dtype=call.attrs.out_dtype)
        if call.op == self.add and \
                tuple(input_types[0].shape) == tuple(input_types[1].shape):
            # Elementwise add of two equally-shaped (already packed)
            # tensors: no constant repacking needed; rebuilt below.
            pass
        elif call.op == self.add and len(input_types[1].shape) == 3:
            # Add of a (C, 1, 1)-style constant: repack the constant to the
            # hardware layout before re-emitting the add.
            data, const = args
            const, input_shape = _const_shape_match(
                const, input_types[1].shape, self.cfactor)
            const = _pack_const(const,
                                _to_shape(input_shape),
                                input_types[1].dtype,
                                self.bfactor,
                                self.cfactor)
            return relay.Call(self.add, [data, const])
        elif call.op == self.multiply and \
                tuple(input_types[0].shape) == tuple(input_types[1].shape):
            pass
        elif call.op == self.multiply and len(input_types[1].shape) == 3:
            data, const = args
            const = _pack_const(const,
                                _to_shape(input_types[1].shape),
                                input_types[1].dtype,
                                self.bfactor,
                                self.cfactor)
            return relay.Call(self.multiply, [data, const])
        elif self.start_pack and call.op == self.bias_add:
            # bias_add is lowered to a packed-constant add.
            data, bias = args
            bias = _pack_const(bias,
                               _to_shape(input_types[1].shape),
                               input_types[1].dtype,
                               self.bfactor,
                               self.cfactor)
            return relay.Call(self.add, [data, bias])
        elif self.start_pack and call.op == op.op.get('cast') and \
                input_types[0].dtype == 'int32':
            # Wrap the int32 cast in a copy; presumably this pins the
            # result for the VTA runtime — TODO(review): confirm intent.
            cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
            return relay.Call(op.op.get('copy'), [cast])
        elif call.op == self.pad:
            pad_width = call.attrs.pad_width
            if len(pad_width) == 6:
                # Already padded for the 6-D packed layout; rebuilt below.
                pass
            elif len(pad_width) == 4:
                # Extend a 4-D NCHW pad spec with zero padding for the two
                # inner packed axes.
                data, = args
                new_pad_width = list(pad_width) + [[0, 0], [0, 0]]
                return op.nn.pad(data,
                                 pad_value=call.attrs.pad_value,
                                 pad_width=new_pad_width)
        elif call.op == self.upsampling:
            data, = args
            data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
            return op.nn.upsampling(data,
                                    call.attrs.scale_h,
                                    call.attrs.scale_w,
                                    data_layout,
                                    call.attrs.method,
                                    call.attrs.align_corners)
        elif call.op == self.reshape and len(input_types[0].shape) == 4:
            # Unpack back to the original 4-D shape before reshaping.
            # NOTE(review): `data, _ = args` assumes the reshape call
            # carries two args here — confirm against this Relay version.
            data, _ = args
            data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
            return op.reshape(data, [int(x) for x in input_types[0].shape])

    # Default: rebuild the call unchanged with visited children.
    return relay.Call(self.visit(call.op), args, call.attrs)
def visit_call(self, call):
    """Rewrite one Relay call for bit-packed execution (older variant).

    Inside a `bitpack_start`/`bitpack_end` region, repacks tensors into
    the NCHW%dn%dc hardware layout (batch factor ``self.bfactor``,
    channel factor ``self.cfactor``) and rewrites conv2d,
    conv2d_transpose, add, bias_add, and int32 cast.  Everything else is
    rebuilt unchanged with visited children.
    """
    # First visit the children.
    # Capture pre-rewrite shape/types: the visited args no longer carry
    # the original checked_type.
    oshape = _get_shape(call)
    odtype = call.checked_type.dtype
    input_types = [arg.checked_type for arg in call.args]
    args = [self.visit(arg) for arg in call.args]
    # Start and stop cases.
    if call.op == self.bitpack_start:
        # Nested packed regions are not supported.
        assert not self.start_pack
        self.start_pack = True
        return _pack_batch_channel(args[0], oshape, self.bfactor, self.cfactor)
    elif call.op == self.bitpack_end:
        if self.start_pack:
            self.start_pack = False
            data = args[0]
            data_shape = _get_shape(call.args[0])
            return _unpack_batch_channel(data, data_shape)
        else:
            # bitpack_end outside a packed region: fall through and
            # rebuild it as an ordinary call below.
            pass
    if self.start_pack:
        # Operator cases (only rewritten inside a packed region).
        if call.op == self.conv2d and odtype == 'int32':
            self.number_of_conv2d += 1
            # Weight bits must evenly divide a byte for lane packing.
            assert 8 % self.weight_bits == 0
            w_lanes = 8 // self.weight_bits
            data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
            kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor)
            data, weight = args
            data_shape = _to_shape(input_types[0].shape)
            kernel_shape = _to_shape(input_types[1].shape)
            kernel = _pack_weight(weight, kernel_shape, self.cfactor)
            # insert bit packing when necessary
            if w_lanes != 1:
                assert 8 % w_lanes == 0
                kernel = op.bitpack(kernel, lanes=w_lanes)
            conv2d = op.nn.conv2d(
                data,
                kernel,
                strides=call.attrs.strides,
                padding=call.attrs.padding,
                dilation=call.attrs.dilation,
                groups=call.attrs.groups,
                channels=call.attrs.channels,
                kernel_size=call.attrs.kernel_size,
                data_layout=data_layout,
                kernel_layout=kernel_layout,
                out_dtype=call.attrs.out_dtype)
            return conv2d
        elif call.op == self.conv2d_transpose and odtype == 'int32':
            self.number_of_conv2d += 1
            assert 8 % self.weight_bits == 0
            w_lanes = 8 // self.weight_bits
            # NOTE(review): this inner start_pack check is always true
            # here (we are already inside `if self.start_pack:`); if it
            # could ever be false, `conv2d` below would be unbound.
            if self.start_pack:
                data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
                kernel_layout = "IOHW%di%do" % (self.cfactor, self.cfactor)
                data, weight = args
                data_shape = _to_shape(input_types[0].shape)
                kernel_shape = _to_shape(input_types[1].shape)
                kernel = _pack_weight_conv2d_transpose(weight, kernel_shape,
                                                       self.cfactor)
                conv2d = op.nn.conv2d_transpose(
                    data,
                    kernel,
                    strides=call.attrs.strides,
                    padding=call.attrs.padding,
                    dilation=call.attrs.dilation,
                    groups=call.attrs.groups,
                    channels=call.attrs.channels,
                    kernel_size=call.attrs.kernel_size,
                    data_layout=data_layout,
                    kernel_layout=kernel_layout,
                    output_padding=call.attrs.output_padding,
                    out_dtype=call.attrs.out_dtype)
            return conv2d
        elif call.op == self.add and \
                tuple(input_types[0].shape) == tuple(input_types[1].shape):
            # Both operands already share the packed shape: nothing to
            # repack; the call is rebuilt unchanged below.
            pass
        elif call.op == self.add and len(input_types[1].shape) == 3:
            # Add of a rank-3 constant (presumably a per-channel bias —
            # verify against callers): repack it to the hardware layout.
            data, bias = args
            bias = _pack_bias(bias,
                              _to_shape(input_types[1].shape),
                              input_types[1].dtype,
                              self.bfactor,
                              self.cfactor)
            return relay.Call(self.add, [data, bias])
        elif self.start_pack and call.op == self.bias_add:
            # bias_add is lowered to an add with a packed bias constant.
            data, bias = args
            bias = _pack_bias(bias,
                              _to_shape(input_types[1].shape),
                              input_types[1].dtype,
                              self.bfactor,
                              self.cfactor)
            return relay.Call(self.add, [data, bias])
        elif self.start_pack and call.op == op.op.get('cast') and \
                input_types[0].dtype == 'int32':
            # Wrap the int32 cast in a copy; presumably this pins the
            # result for the runtime — TODO(review): confirm intent.
            cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
            return relay.Call(op.op.get('copy'), [cast])

    # Default: rebuild the call unchanged with visited children.
    return relay.Call(
        self.visit(call.op),
        args,
        call.attrs)