示例#1
0
 def init_mkldnn_reorder(self, op):
     """Create and register an MKL-DNN reorder kernel for ``op``.

     Converts data described by ``op.in_layout`` (an (mkl_layout, mkl_axes)
     pair) into the layout implied by ``op``'s own axes, and stores the
     resulting kernel in ``self.mkldnn.kernels[op.name]``.
     """
     (mkl_layout, mkl_axes) = op.in_layout
     # If op carries a FlattenedAxis where the corresponding MKL axis is
     # not flattened, order and sizes must be derived from the unflattened
     # view of op.  The original computed this decision twice; hoist the
     # chosen axes source once and reuse it.
     needs_unflatten = any(
         isinstance(axis, FlattenedAxis) and not mkl_axes[axis_indx].is_flattened
         for axis_indx, axis in enumerate(op.axes))
     src_axes = unflatten(op).axes if needs_unflatten else op.axes
     mkl_axes_order = get_order_from_axes(src_axes, mkl_axes)
     (out_layout, _) = get_mkl_layout(self.mkldnn, op, mkl_axes_order, True)
     ndims = len(mkl_axes)
     dims = get_size_mkl_order(src_axes, mkl_axes_order)
     dims_arg = ((ct.c_int) * ndims)(*dims)
     op_id = len(self.mkldnn.kernels)
     self.mkldnn.kernels[op.name] = self.mkldnn.create_empty_kernel(op_id)
     self.mkldnn.reorder_kernel(self.mkldnn.mkldnn_engine, ndims, dims_arg,
                                self.mkldnn.datatype[op.dtype.type],
                                self.mkldnn.memory_format['blocked'],
                                mkl_layout, out_layout,
                                self.mkldnn.kernels[op.name])
     dbg_print_kernel(self.mkldnn, op, op_id)
示例#2
0
    def visit(self, op, inputs, gamma, bias, epsilon, mean, variance):
        """Register an MKL-DNN batchnorm forward kernel for ``op``.

        Skips silently (leaving op on the default execution path) when op is
        not float32 or when flattened inputs do not expand to 5 axes.
        """
        # unflatten the inputs and extract C H W N params.
        # Bound unconditionally: the original only assigned this inside the
        # isinstance branch, which raised NameError at the get_mkl_layout
        # call below whenever `inputs` was not a Flatten.
        unflatten_inputs = unflatten(inputs)
        if isinstance(inputs, Flatten):
            # Sanity check tensor shapes
            if len(unflatten_inputs.axes.lengths) != 5:
                return

        # Only single precision float supported for now
        if op.dtype != np.float32:
            return

        data_type = self.mkldnn.datatype[op.dtype.type]
        inputs_shape = get_size_mkl_order(unflatten_inputs.axes, [4, 0, 2, 3])
        mean_size = mean.axes.lengths[0]
        mean_dims = 1
        gamma_shape = gamma.axes.lengths[0]
        bias_shape = bias.axes.lengths[0]
        variance_size = variance.axes.lengths[0]
        variance_dims = 1
        outputs_shape = op.axes.lengths

        # weights is 2 dimensional, 1-st dimension contains gamma parameter,
        # 2-nd dimension contains beta parameter.
        weights_shape = [gamma_shape, bias_shape]
        weights_shape_arg = ((ct.c_int) * len(weights_shape))(*weights_shape)
        input_shape_arg = ((ct.c_int) * len(inputs_shape))(*inputs_shape)
        outputs_shape_arg = ((ct.c_int) * len(outputs_shape))(*outputs_shape)

        (inputs_layout, mkl_axes) = get_mkl_layout(
            self.mkldnn, unflatten_inputs, [4, 0, 2, 3], True)
        mean_layout = None
        variance_layout = None

        op_id = len(self.mkldnn.kernels)
        self.mkldnn.kernels[op.name] = self.mkldnn.create_empty_kernel(op_id)

        self.mkldnn.batchnorm_fprop_kernel(
            self.mkldnn.mkldnn_engine,
            len(inputs_shape),
            len(outputs_shape),
            len(weights_shape),
            mean_dims,
            variance_dims,
            mean_size,
            variance_size,
            input_shape_arg,
            weights_shape_arg,
            outputs_shape_arg,
            op.eps,
            inputs_layout,
            None,
            mean_layout,
            variance_layout,
            data_type,
            self.mkldnn.kernels[op.name])

        self.set_mkl_layout_data(op, mkl_axes)
        dbg_print_kernel(self.mkldnn, op, op_id)
示例#3
0
    def visit(self, op, arg):
        """Normalize a scatter op so its parallel axis comes first.

        When the parallel axis is not already leading, reorder/flatten/
        unflatten the argument and replace ``op`` with a copy over the
        normalized argument.
        """
        parallel_axes = arg.axes.find_by_name(op.metadata['parallel'].name)
        assert len(parallel_axes) > 0, "Invalid to scatter a scalar"
        if arg.axes.index(parallel_axes[0]) == 0:
            # Parallel axis already leads; nothing to do.
            return
        reordered = axes_with_order(
            arg, parallel_axes + (arg.axes - parallel_axes))
        reordered = unflatten(flatten_at(reordered, 0))

        # replace the ops
        self.replace_op(op, op.copy_with_new_args([reordered]))
示例#4
0
    def visit(self, op):
        """Lower a Dot op into 1-D/2-D dot primitives.

        Scalar operands are folded arithmetically; otherwise both operands
        are reordered and flattened so reduction axes are contiguous, the
        matching Dot* primitive is emitted, and ``op`` is replaced.
        """
        x, y = op.args
        x_reduction_axes = op.x_reduction_axes
        y_reduction_axes = op.y_reduction_axes
        out_axes = op.axes
        if len(x_reduction_axes) == 0:
            # No reduction axes: insert a dummy length-1 axis so the
            # generic path below still applies.
            d = make_axis(1)
            x_reduction_axes = make_axes((d, ))
            y_reduction_axes = x_reduction_axes
            x = broadcast(x, x.axes + x_reduction_axes)
            y = broadcast(y, y_reduction_axes + y.axes)

        if x.is_scalar:
            # Commute so any scalar operand ends up in y.
            x, y = y, x

        if y.is_scalar:
            if x.is_scalar:
                out = x.scalar_op * y.scalar_op
                # x_reduction_axes is non-empty here (see guard above), so
                # scale by the reduced size.  The original broadcast the
                # result here AND again below -- the duplicate is removed.
                if len(x_reduction_axes) > 0:
                    out = out * x_reduction_axes.size
            else:
                out = Sum(x, x_reduction_axes) * y.scalar_op
            out = broadcast(out, op.axes)
        else:
            # Move reduction axes to the end of x and front of y, then
            # flatten each side so the dot primitives see 1-D/2-D operands.
            x_rem_axes = x.axes - x_reduction_axes
            x = axes_with_order(x, x_rem_axes + x_reduction_axes)

            y_rem_axes = y.axes - y_reduction_axes
            y = axes_with_order(y, y_reduction_axes + y_rem_axes)

            x = flatten_at(x, len(x.axes) - len(x_reduction_axes))
            y = flatten_at(y, len(y_reduction_axes))

            if len(out_axes) == 0:
                out = DotOneDimensional(x, y, axes=())
            elif len(x.axes) == 1:
                # x is a vector: transpose y so the two-by-one form applies.
                y = Transpose(y)
                out = DotTwoByOne(y, x, axes=y.axes[0])
            elif len(y.axes) == 1:
                out = DotTwoByOne(x, y, axes=x.axes[0])
            else:
                out = DotTwoDimensional(
                    x,
                    y,
                    axes=([op.x_out_axes.flatten(),
                           op.y_out_axes.flatten()]))

            # Restore the original (unflattened, ordered) output axes.
            out = unflatten(out)
            out = ReorderAxes(out, out_axes)

        self.replace_op(op, out)
示例#5
0
    def visit(self, op):
        """Lower a Dot op into DotLowDimension primitives.

        Scalar operands are folded arithmetically; otherwise both operands
        are reordered/flattened so reduction axes are contiguous and a
        DotLowDimension replaces ``op``.
        """
        x, y = op.args
        reduction_axes = op.reduction_axes
        out_axes = op.axes
        if len(reduction_axes) == 0:
            # TODO: this is a weird case, should we really support it?
            d = make_axis(1)
            reduction_axes = make_axes((d, ))
            x = broadcast(x, x.axes + reduction_axes)
            y = broadcast(y, reduction_axes + y.axes)

        if x.is_scalar:
            # Commute so any scalar operand ends up in y.
            x, y = y, x

        if y.is_scalar:
            if x.is_scalar:
                out = x.scalar_op * y.scalar_op
                # reduction_axes is non-empty here (see guard above), so
                # scale by the reduced size.  The original broadcast the
                # result here AND again below -- the duplicate is removed.
                if len(reduction_axes) > 0:
                    out = out * reduction_axes.size
            else:
                out = Sum(x, reduction_axes) * y.scalar_op
            out = broadcast(out, op.axes)
        else:
            # move reduction_axes to end
            x = axes_with_order(x, (x.axes - reduction_axes) + reduction_axes)
            # move reduction axes to front
            y = axes_with_order(y, reduction_axes + (y.axes - reduction_axes))

            # flatten non-reduction axes together and reduction axes together
            x = flatten_at(x, len(x.axes) - len(reduction_axes))
            # flatten non-reduction axes together and reduction axes together
            y = flatten_at(y, len(reduction_axes))

            if len(out_axes) == 0:
                out = DotLowDimension(x, y, axes=())
            elif len(x.axes) == 1:
                # x is a vector: transpose y so it contracts correctly.
                y = Transpose(y)
                out = DotLowDimension(y, x, axes=y.axes[0])
            elif len(y.axes) == 1:
                out = DotLowDimension(x, y, axes=x.axes[0])
            else:
                out = DotLowDimension(x,
                                      y,
                                      axes=([
                                          op.x_out_axes.flatten(True),
                                          op.y_out_axes.flatten(True)
                                      ]))

            # Restore the original (unflattened, ordered) output axes.
            out = unflatten(out)
            out = ReorderAxes(out, out_axes)

        self.replace_op(op, out)
示例#6
0
    def visit(self, op, arg):
        """Normalize a send-related op so its parallel axis comes first.

        If the argument has a parallel axis that is not already leading,
        reorder/flatten/unflatten the argument and replace ``op`` with a
        copy over the normalized argument; otherwise leave ``op`` alone.
        """
        parallel_name = op.send_node().metadata['parallel'].name
        parallel_axes = arg.axes.find_by_name(parallel_name)
        if len(parallel_axes) > 0 and arg.axes.index(parallel_axes[0]) > 0:
            reordered = axes_with_order(
                arg, parallel_axes + (arg.axes - parallel_axes))
            reordered = unflatten(flatten_at(reordered, 0))

            # replace the ops
            self.replace_op(op, op.copy_with_new_args([reordered]))
示例#7
0
    def visit(self, op, arg):
        """Simplify or normalize ``op`` based on the arg's parallel axis.

        If the arg has no 'parallel' metadata, do nothing.  If op's axes do
        not contain the parallel axis, op is redundant and is replaced by
        arg directly.  If the parallel axis exists but is not leading,
        replace op with a reordered/flatten-normalized view of arg.
        """
        if 'parallel' not in arg.metadata:
            return

        parallel_axes = op.axes.find_by_name(arg.metadata['parallel'].name)
        if len(parallel_axes) == 0:
            # op carries no parallel axis: it is a no-op wrapper.
            self.replace_op(op, arg)
            return

        if op.axes.index(parallel_axes[0]) > 0:
            normalized = axes_with_order(arg, op.axes)
            normalized = unflatten(flatten_at(normalized, 0))

            # replace the ops
            self.replace_op(op, normalized)
示例#8
0
def get_axes_mkl_order(axes, order):
    """Return the axes of ``axes`` permuted according to ``order``.

    When ``order`` requests more than two positions, any FlattenedAxis in
    ``axes`` is expanded into its component axes first, and ``order``
    indexes into that expanded list; otherwise ``order`` indexes ``axes``
    directly.
    """
    expanded = []
    saw_flattened = False
    for axis in axes:
        if axis.is_flattened and len(order) > 2:
            for component in unflatten(axis).axes:
                expanded.append(component)
                saw_flattened = True
        else:
            expanded.append(axis)
    source = expanded if saw_flattened else axes
    return [source[index] for index in order]
示例#9
0
    def visit(self, op, delta, fprop_src, gamma, bias, mean, variance):
        """Register an MKL-DNN batchnorm backprop kernel for ``op``.

        Requires op to be float32 with axes (C, FlattenedAxis) where the
        flattened axis expands to 4 axes (5-D overall).  Any other dtype or
        shape is left for the default execution path.
        """
        # Only single precision float supported for now
        if op.dtype != np.float32:
            return
        # Sanity check tensor shapes: require exactly (C, flattened(4)).
        # The original fell through when op did not have 2 axes, crashing
        # later on undefined outputs_shape/delta_layout; bail out instead.
        if len(op.axes.lengths) != 2:
            return
        if not isinstance(op.axes[1], FlattenedAxis):
            return
        if len(unflatten(op.axes[1]).axes.lengths) != 4:
            return

        outputs_shape = get_size_mkl_order(unflatten(op).axes, [4, 0, 2, 3])
        outputs_shape_arg = ((ct.c_int) * len(outputs_shape))(*outputs_shape)

        data_type = self.mkldnn.datatype[op.dtype.type]
        mean_dims = 1
        variance_dims = 1
        mean_size = mean.axes.lengths[0]
        variance_size = variance.axes.lengths[0]

        delta_shape = get_size_mkl_order(unflatten(delta).axes, [4, 0, 2, 3])
        delta_shape_arg = ((ct.c_int) * len(delta_shape))(*delta_shape)

        # weights is 2 dimensional, 1-st dimension contains gamma parameter,
        # 2-nd dimension contains beta parameter.
        gamma_shape = gamma.axes.lengths[0]
        bias_shape = bias.axes.lengths[0]
        weights_shape = [gamma_shape, bias_shape]
        weights_shape_arg = ((ct.c_int) * len(weights_shape))(*weights_shape)

        # The shape guards above guarantee the 5-D layout, so the two
        # separate `if axis_len_5d` checks collapse to straight-line code.
        (delta_layout, mkl_axes) = get_mkl_layout(
            self.mkldnn, unflatten(delta), [4, 0, 2, 3], True)
        # NOTE(review): the layout computed here was always discarded by the
        # original (immediately overwritten from op_layouts below); the call
        # is retained in case get_mkl_layout has registration side effects
        # -- confirm and remove if it is pure.
        (fprop_src_layout, _) = get_mkl_layout(
            self.mkldnn, unflatten(fprop_src), [4, 0, 2, 3], True)
        fprop_src_layout = self.mkldnn.op_layouts.get(fprop_src.name)

        # NOTE(review): membership is tested against self.mkldnn.kernels but
        # the value is read from self.mkldnn.op_layouts -- confirm the two
        # dicts share keys for these ops, else this can raise KeyError.
        mean_layout = None
        if mean.name in self.mkldnn.kernels:
            mean_layout = self.mkldnn.op_layouts[mean.name]

        variance_layout = None
        if variance.name in self.mkldnn.kernels:
            variance_layout = self.mkldnn.op_layouts[variance.name]

        op_id = len(self.mkldnn.kernels)
        self.mkldnn.kernels[op.name] = self.mkldnn.create_empty_kernel(op_id)

        self.mkldnn.batchnorm_bprop_kernel(
            self.mkldnn.mkldnn_engine,
            len(delta_shape),
            len(outputs_shape),
            len(weights_shape),
            mean_dims,
            variance_dims,
            delta_shape_arg,
            outputs_shape_arg,
            weights_shape_arg,
            mean_size,
            variance_size,
            op.fprop.eps,
            fprop_src_layout,
            None,
            mean_layout,
            variance_layout,
            delta_layout,
            data_type,
            self.mkldnn.kernels[op.fprop.forwarded.name],
            self.mkldnn.kernels[op.name])

        # Propagate the MKL-ordered output axes to op's layout metadata.
        mkl_order = get_order_from_axes(unflatten(delta).axes, mkl_axes)
        out_axes = get_axes_mkl_order(unflatten(op).axes, mkl_order)
        self.set_mkl_layout_data(op, out_axes)
        dbg_print_kernel(self.mkldnn, op, op_id)