def init_mkldnn_reorder(self, op):
    (mkl_layout, mkl_axes) = op.in_layout

    # If any output axis is flattened while the corresponding MKL axis is
    # not, compute the axis order against the unflattened view of the op.
    check_flatten = False
    for axis_indx, each_axis in enumerate(op.axes):
        if isinstance(each_axis, FlattenedAxis) and not (mkl_axes[axis_indx].is_flattened):
            check_flatten = True

    if check_flatten:
        mkl_axes_order = get_order_from_axes(unflatten(op).axes, mkl_axes)
    else:
        mkl_axes_order = get_order_from_axes(op.axes, mkl_axes)

    (out_layout, _) = get_mkl_layout(self.mkldnn, op, mkl_axes_order, True)

    ndims = len(mkl_axes)
    if check_flatten:
        dims = get_size_mkl_order(unflatten(op).axes, mkl_axes_order)
    else:
        dims = get_size_mkl_order(op.axes, mkl_axes_order)
    dims_arg = ((ct.c_int) * ndims)(*dims)

    op_id = len(self.mkldnn.kernels)
    self.mkldnn.kernels[op.name] = self.mkldnn.create_empty_kernel(op_id)
    self.mkldnn.reorder_kernel(
        self.mkldnn.mkldnn_engine,
        ndims, dims_arg,
        self.mkldnn.datatype[op.dtype.type],
        self.mkldnn.memory_format['blocked'],
        mkl_layout, out_layout,
        self.mkldnn.kernels[op.name])
    dbg_print_kernel(self.mkldnn, op, op_id)
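# The ((ct.c_int) * ndims)(*dims) idiom above builds a fixed-length C int
# array so the dimension list can be handed to the MKL-DNN C API. Below is a
# minimal, self-contained sketch of that ctypes pattern; the sizes are
# illustrative assumptions, not values taken from the pass.
import ctypes as ct

def _demo_ctypes_dims():
    dims = [8, 3, 224, 224]            # hypothetical N, C, H, W sizes
    # (ct.c_int * len(dims)) creates an array *type*; calling it with *dims
    # instantiates a contiguous int[len(dims)] initialized from the list.
    dims_arg = (ct.c_int * len(dims))(*dims)
    assert list(dims_arg) == dims      # values round-trip unchanged
    # dims_arg can now be passed where a C function expects an int* argument.
    return dims_arg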
def visit(self, op, inputs, gamma, bias, epsilon, mean, variance):
    # unflatten the inputs and extract C H W N params
    if isinstance(inputs, Flatten):
        unflatten_inputs = unflatten(inputs)
        # Sanity check tensor shapes
        if (len(unflatten_inputs.axes.lengths) != 5):
            return
    # Only single precision float supported for now
    if op.dtype != np.float32:
        return

    data_type = self.mkldnn.datatype[op.dtype.type]
    inputs_shape = get_size_mkl_order(unflatten(inputs).axes, [4, 0, 2, 3])
    mean_size = mean.axes.lengths[0]
    mean_dims = 1
    gamma_shape = gamma.axes.lengths[0]
    bias_shape = bias.axes.lengths[0]
    variance_size = variance.axes.lengths[0]
    variance_dims = 1
    outputs_shape = op.axes.lengths

    # weights is 2-dimensional: the first dimension holds the gamma
    # parameter, the second holds the beta parameter.
    weights_shape = [gamma_shape, bias_shape]

    weights_shape_arg = ((ct.c_int) * len(weights_shape))(*weights_shape)
    input_shape_arg = ((ct.c_int) * len(inputs_shape))(*inputs_shape)
    outputs_shape_arg = ((ct.c_int) * len(outputs_shape))(*outputs_shape)

    (inputs_layout, mkl_axes) = get_mkl_layout(
        self.mkldnn, unflatten_inputs, [4, 0, 2, 3], True)
    mean_layout = None
    variance_layout = None

    op_id = len(self.mkldnn.kernels)
    self.mkldnn.kernels[op.name] = self.mkldnn.create_empty_kernel(op_id)
    self.mkldnn.batchnorm_fprop_kernel(
        self.mkldnn.mkldnn_engine, len(inputs_shape), len(outputs_shape),
        len(weights_shape), mean_dims, variance_dims, mean_size,
        variance_size, input_shape_arg, weights_shape_arg,
        outputs_shape_arg, op.eps, inputs_layout, None, mean_layout,
        variance_layout, data_type, self.mkldnn.kernels[op.name])

    self.set_mkl_layout_data(op, mkl_axes)
    dbg_print_kernel(self.mkldnn, op, op_id)
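# The hard-coded [4, 0, 2, 3] order used above re-indexes a 5-d axis tuple
# into the N, C, H, W order MKL-DNN expects. A minimal sketch of that
# index-based reordering on plain lists, assuming ngraph's C, D, H, W, N
# axis layout; the lengths are made-up illustration values.
def _demo_size_mkl_order():
    lengths = [3, 1, 32, 32, 16]       # hypothetical C, D, H, W, N lengths
    order = [4, 0, 2, 3]               # select N, C, H, W
    nchw = [lengths[i] for i in order]
    assert nchw == [16, 3, 32, 32]
    return nchw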
def visit(self, op, arg):
    p_axis = arg.axes.find_by_name(op.metadata['parallel'].name)
    assert len(p_axis) > 0, "Invalid to scatter a scalar"
    if arg.axes.index(p_axis[0]) > 0:
        # move the parallel axis to the front and make the data
        # contiguous in that order
        arg = axes_with_order(arg, p_axis + (arg.axes - p_axis))
        arg = flatten_at(arg, 0)
        arg = unflatten(arg)
    # replace the op
    new_op = op.copy_with_new_args([arg])
    self.replace_op(op, new_op)
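# The reorder/flatten/unflatten sequence above rearranges the tensor so the
# parallel axis comes first and the data is laid out contiguously in that
# order, ready to be scattered along its leading dimension. A rough numpy
# analogue of the same transformation; the shapes and the parallel-axis
# position are illustrative assumptions.
import numpy as np

def _demo_parallel_axis_to_front():
    t = np.arange(4 * 8 * 2).reshape(4, 8, 2)    # parallel axis at position 1
    t = np.ascontiguousarray(np.moveaxis(t, 1, 0))
    assert t.shape == (8, 4, 2)                  # parallel axis now leads
    return t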
def visit(self, op):
    x, y = op.args
    x_reduction_axes = op.x_reduction_axes
    y_reduction_axes = op.y_reduction_axes
    out_axes = op.axes
    if len(x_reduction_axes) == 0:
        d = make_axis(1)
        x_reduction_axes = make_axes((d,))
        y_reduction_axes = x_reduction_axes
        x = broadcast(x, x.axes + x_reduction_axes)
        y = broadcast(y, y_reduction_axes + y.axes)

    if x.is_scalar:
        x, y = y, x

    if y.is_scalar:
        if x.is_scalar:
            out = x.scalar_op * y.scalar_op
            if len(x_reduction_axes) > 0:
                out = out * x_reduction_axes.size
            out = broadcast(out, op.axes)
        else:
            out = Sum(x, x_reduction_axes) * y.scalar_op
            out = broadcast(out, op.axes)
    else:
        # move reduction axes to the end of x and the front of y
        x_rem_axes = x.axes - x_reduction_axes
        x = axes_with_order(x, x_rem_axes + x_reduction_axes)
        y_rem_axes = y.axes - y_reduction_axes
        y = axes_with_order(y, y_reduction_axes + y_rem_axes)

        # flatten non-reduction axes together and reduction axes together
        x = flatten_at(x, len(x.axes) - len(x_reduction_axes))
        y = flatten_at(y, len(y_reduction_axes))

        if len(out_axes) == 0:
            out = DotOneDimensional(x, y, axes=())
        elif len(x.axes) == 1:
            y = Transpose(y)
            out = DotTwoByOne(y, x, axes=y.axes[0])
        elif len(y.axes) == 1:
            out = DotTwoByOne(x, y, axes=x.axes[0])
        else:
            out = DotTwoDimensional(
                x, y,
                axes=([op.x_out_axes.flatten(), op.y_out_axes.flatten()]))

        out = unflatten(out)
        out = ReorderAxes(out, out_axes)

    self.replace_op(op, out)
def visit(self, op):
    x, y = op.args
    reduction_axes = op.reduction_axes
    out_axes = op.axes
    if len(reduction_axes) == 0:
        # TODO: this is a weird case, should we really support it?
        d = make_axis(1)
        reduction_axes = make_axes((d,))
        x = broadcast(x, x.axes + reduction_axes)
        y = broadcast(y, reduction_axes + y.axes)

    if x.is_scalar:
        x, y = y, x

    if y.is_scalar:
        if x.is_scalar:
            out = x.scalar_op * y.scalar_op
            if len(reduction_axes) > 0:
                out = out * reduction_axes.size
            out = broadcast(out, op.axes)
        else:
            out = Sum(x, reduction_axes) * y.scalar_op
            out = broadcast(out, op.axes)
    else:
        # move reduction_axes to end
        x = axes_with_order(x, (x.axes - reduction_axes) + reduction_axes)
        # move reduction axes to front
        y = axes_with_order(y, reduction_axes + (y.axes - reduction_axes))

        # flatten non-reduction axes together and reduction axes together
        x = flatten_at(x, len(x.axes) - len(reduction_axes))
        y = flatten_at(y, len(reduction_axes))

        if len(out_axes) == 0:
            out = DotLowDimension(x, y, axes=())
        elif len(x.axes) == 1:
            y = Transpose(y)
            out = DotLowDimension(y, x, axes=y.axes[0])
        elif len(y.axes) == 1:
            out = DotLowDimension(x, y, axes=x.axes[0])
        else:
            out = DotLowDimension(
                x, y,
                axes=([op.x_out_axes.flatten(True),
                       op.y_out_axes.flatten(True)]))

        out = unflatten(out)
        out = ReorderAxes(out, out_axes)

    self.replace_op(op, out)
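# The else branch above reduces a general tensor contraction to one 2-d
# matrix multiply: reduction axes move to the end of x and the front of y,
# each operand is flattened to 2-d, and the result is unflattened and
# reordered. A self-contained numpy sketch of the same decomposition; the
# shapes are illustrative assumptions.
import numpy as np

def _demo_dot_decomposition():
    x = np.random.rand(2, 3, 4, 5)     # out axes (2, 3), reduction axes (4, 5)
    y = np.random.rand(4, 5, 6)        # reduction axes (4, 5), out axis (6,)
    # flatten non-reduction axes together and reduction axes together
    x2 = x.reshape(2 * 3, 4 * 5)
    y2 = y.reshape(4 * 5, 6)
    out = (x2 @ y2).reshape(2, 3, 6)   # 2-d matmul, then unflatten
    assert np.allclose(out, np.tensordot(x, y, axes=([2, 3], [0, 1])))
    return out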
def visit(self, op, arg):
    p_axis = arg.axes.find_by_name(
        op.send_node().metadata['parallel'].name)
    if len(p_axis) > 0 and arg.axes.index(p_axis[0]) > 0:
        # move the parallel axis to the front
        arg = axes_with_order(arg, p_axis + (arg.axes - p_axis))
        arg = flatten_at(arg, 0)
        arg = unflatten(arg)
    # replace the op
    new_op = op.copy_with_new_args([arg])
    self.replace_op(op, new_op)
def visit(self, op, arg):
    if 'parallel' not in arg.metadata:
        return
    p_axis = op.axes.find_by_name(arg.metadata['parallel'].name)
    if len(p_axis) == 0:
        self.replace_op(op, arg)
    elif op.axes.index(p_axis[0]) > 0:
        arg = axes_with_order(arg, op.axes)
        arg = flatten_at(arg, 0)
        arg = unflatten(arg)
        # replace the op
        self.replace_op(op, arg)
def get_axes_mkl_order(axes, order):
    axes_list = []
    flattened_axis_flag = False
    for axis in axes:
        if (axis.is_flattened and len(order) > 2):
            # expand the flattened axis into its component axes
            unflattened_axes = unflatten(axis).axes
            for indx in range(len(unflattened_axes)):
                axes_list.append(unflattened_axes[indx])
            flattened_axis_flag = True
        else:
            axes_list.append(axis)
    if flattened_axis_flag:
        return [axes_list[index] for index in order]
    else:
        return [axes[index] for index in order]
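# A quick illustration of the two return paths above, using plain strings in
# place of Axis objects; the axis names are hypothetical.
def _demo_axes_mkl_order():
    # no flattened axes: the order indexes the original tuple directly
    axes = ['C', 'D', 'H', 'W', 'N']
    assert [axes[i] for i in [4, 0, 2, 3]] == ['N', 'C', 'H', 'W']
    # with a flattened axis such as ('C', ('D', 'H', 'W'), 'N'), the function
    # first expands it to ['C', 'D', 'H', 'W', 'N'] and then applies the same
    # index order to the expanded list.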
def visit(self, op, delta, fprop_src, gamma, bias, mean, variance):
    axis_len_5d = False
    # Only single precision float supported for now
    if op.dtype != np.float32:
        return
    # Sanity check tensor shapes
    if (len(op.axes.lengths) == 2):
        if isinstance(op.axes[1], FlattenedAxis):
            C, Flatten_axis = op.axes
            if (len(unflatten(Flatten_axis).axes.lengths) != 4):
                return
            else:
                outputs_shape = get_size_mkl_order(
                    unflatten(op).axes, [4, 0, 2, 3])
                outputs_shape_arg = ((ct.c_int) * len(outputs_shape))(*outputs_shape)
                axis_len_5d = True
    else:
        return

    data_type = self.mkldnn.datatype[op.dtype.type]
    mean_dims = 1
    variance_dims = 1
    mean_size = mean.axes.lengths[0]
    variance_size = variance.axes.lengths[0]

    delta_shape = get_size_mkl_order(unflatten(delta).axes, [4, 0, 2, 3])
    delta_shape_arg = ((ct.c_int) * len(delta_shape))(*delta_shape)

    # weights is 2-dimensional: the first dimension holds the gamma
    # parameter, the second holds the beta parameter.
    gamma_shape = gamma.axes.lengths[0]
    bias_shape = bias.axes.lengths[0]
    weights_shape = [gamma_shape, bias_shape]
    weights_shape_arg = ((ct.c_int) * len(weights_shape))(*weights_shape)

    if (axis_len_5d):
        (delta_layout, mkl_axes) = get_mkl_layout(
            self.mkldnn, unflatten(delta), [4, 0, 2, 3], True)
        (fprop_src_layout, _) = get_mkl_layout(
            self.mkldnn, unflatten(fprop_src), [4, 0, 2, 3], True)
    fprop_src_layout = self.mkldnn.op_layouts.get(fprop_src.name)

    mean_layout = None
    if mean.name in self.mkldnn.kernels:
        mean_layout = self.mkldnn.op_layouts[mean.name]

    variance_layout = None
    if variance.name in self.mkldnn.kernels:
        variance_layout = self.mkldnn.op_layouts[variance.name]

    op_id = len(self.mkldnn.kernels)
    self.mkldnn.kernels[op.name] = self.mkldnn.create_empty_kernel(op_id)
    self.mkldnn.batchnorm_bprop_kernel(
        self.mkldnn.mkldnn_engine, len(delta_shape), len(outputs_shape),
        len(weights_shape), mean_dims, variance_dims, delta_shape_arg,
        outputs_shape_arg, weights_shape_arg, mean_size, variance_size,
        op.fprop.eps, fprop_src_layout, None, mean_layout, variance_layout,
        delta_layout, data_type,
        self.mkldnn.kernels[op.fprop.forwarded.name],
        self.mkldnn.kernels[op.name])

    mkl_order = get_order_from_axes(unflatten(delta).axes, mkl_axes)
    out_axes = get_axes_mkl_order(unflatten(op).axes, mkl_order)
    self.set_mkl_layout_data(op, out_axes)
    dbg_print_kernel(self.mkldnn, op, op_id)