def get_edge_mask(op: Union[Conv2dOp, PoolOp], compact: bool) -> np.ndarray:
    """Return the traced edges of a convolution or pooling op as a 0/1 mask.

    Args:
        op: Op whose ``attrs`` hold ``TraceKey.EDGE`` and
            ``TraceKey.EDGE_SHAPE``. Its first input and first output node
            must carry ``TraceKey.POINT_SHAPE``.
        compact: Forwarded to ``TraceKey.to_mask`` / ``TraceKey.to_array``
            and selects which of the two code paths below runs.

    Returns:
        When ``compact`` is false, whatever ``TraceKey.to_mask`` builds for
        the 2-D shape ``(prod(edge_shape[:3]), prod(edge_shape[3:]))``.
        When ``compact`` is true, an ``int8`` array of shape
        ``(prod(input_shape), prod(output_shape))`` holding 1 at every
        (input point, output point) pair joined by a traced edge.
    """
    edge_shape = op.attrs[TraceKey.EDGE_SHAPE]
    if not compact:
        # Each product must be taken separately: the second positional
        # argument of np.prod is ``axis``, so the previous single call
        # ``np.prod(edge_shape[:3], edge_shape[3:])`` could never yield a shape.
        # NOTE(review): the split at index 3 matches the 6-dim edge shape
        # unpacked for Conv2dOp below; the pooling edge shape is unpacked as
        # 5 dims - confirm the intended split for PoolOp.
        mask_shape = (np.prod(edge_shape[:3]), np.prod(edge_shape[3:]))
        return TraceKey.to_mask(op.attrs[TraceKey.EDGE], mask_shape, compact)

    input_tensor: Tensor = op.input_nodes[0]
    output_tensor: Tensor = op.output_nodes[0]
    edge = TraceKey.to_array(op.attrs[TraceKey.EDGE], compact)
    input_shape = input_tensor.attrs[TraceKey.POINT_SHAPE]
    output_shape = output_tensor.attrs[TraceKey.POINT_SHAPE]
    if op.data_format == "NHWC":
        # Move the last axis to the front so both shapes are indexed as
        # (channel, height, width) from here on.
        input_shape = (input_shape[2], input_shape[0], input_shape[1])
        output_shape = (output_shape[2], output_shape[0], output_shape[1])

    is_conv = isinstance(op, Conv2dOp)
    if is_conv:
        (in_channel, kernel_height, kernel_width,
         out_channel, out_height, out_width) = np.unravel_index(edge, edge_shape)
    else:
        (kernel_height, kernel_width,
         out_channel, out_height, out_width) = np.unravel_index(edge, edge_shape)
        # The pooling edge shape has no input-channel axis; the output
        # channel index is reused on the input side.
        in_channel = out_channel

    stride = np.array(op.strides)
    if is_conv:
        # presumably WEIGHT_SHAPE[2:] is (kernel_height, kernel_width) - TODO confirm
        kernel_size = np.array(op.attrs[TraceKey.WEIGHT_SHAPE])[2:]
    else:
        kernel_size = np.array(op.filter_shape)
    padding = calc_padding(
        np.array(input_shape)[1:], np.array(output_shape)[1:], stride, kernel_size
    )

    # Map (kernel offset, output position) to the input position.
    # NOTE(review): padding[1][0] / padding[2][0] look like the leading
    # height/width padding - confirm against calc_padding.
    in_height = kernel_height + (out_height * stride[0]) - padding[1][0]
    in_width = kernel_width + (out_width * stride[1]) - padding[2][0]

    edge_output_index = np.ravel_multi_index(
        (out_channel, out_height, out_width), output_shape
    )
    edge_input_index = np.ravel_multi_index(
        (in_channel, in_height, in_width), input_shape
    )

    mask = np.zeros((np.prod(input_shape), np.prod(output_shape)), dtype=np.int8)
    mask[(edge_input_index, edge_output_index)] = 1
    return mask
def linear_layer_trace(op: DenseOp, compact: bool, *args, **kwargs):
    """Turn a dense op's traced edges into a mask and pass it to ``set_input_path``.

    Args:
        op: Op whose ``attrs`` hold ``TraceKey.EDGE`` and
            ``TraceKey.EDGE_SHAPE``.
        compact: Forwarded unchanged to ``TraceKey.to_mask``.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        None. The mask is only handed to ``set_input_path(op, mask)``.
    """
    traced_edges = op.attrs[TraceKey.EDGE]
    mask_shape = op.attrs[TraceKey.EDGE_SHAPE]
    dense_mask = TraceKey.to_mask(traced_edges, mask_shape, compact)
    set_input_path(op, dense_mask)