Code example #1
File: cpufusion.py  Project: rsumner31/ngraph
    def fuse_conv_and_bias_callback(self, op, label_map_op_list):
        """
        Callback function that handles fusion for Conv + bias  pattern
        """
        for (label_map, op) in label_map_op_list:
            map_roles = label_map[self.map_roles_label]
            conv_op = self.op_arg(map_roles, 0)
            bias = label_map[self.conv_bias_label]
            new_op_map = {}
            if isinstance(conv_op, ConvolutionOp):
                # Rebuild the convolution with the bias as an extra argument.
                conv_new_op = ConvolutionOp(conv_op.conv_params, self.op_arg(conv_op, 0),
                                            self.op_arg(conv_op, 1), bias, axes=conv_op.axes)
                new_op_map[conv_op] = conv_new_op
                # Collect the ops that are downstream of the convolution but
                # still upstream of the add 'op'.
                prev_op = self.op_arg(op, 0)
                upstream_ops = []
                while prev_op != conv_op:
                    upstream_ops.append(prev_op)
                    prev_op = self.op_arg(prev_op, 0)
                # Rebuild each of them, oldest first, on top of the fused convolution.
                for old_op in reversed(upstream_ops):
                    new_arg = new_op_map[self.op_arg(old_op, 0)]
                    if isinstance(old_op, MapRolesOp):
                        new_op_map[old_op] = MapRolesOp(new_arg, old_op.axes_map)
                    elif isinstance(old_op, TensorSliceOp):
                        new_op_map[old_op] = TensorSliceOp(new_arg, old_op.slices, old_op.axes)
                    elif isinstance(old_op, ReorderAxes):
                        new_op_map[old_op] = ReorderAxes(new_arg, old_op.axes)

                self.replace_op(op, new_op_map[self.op_arg(op, 0)])
                self.replace_op(conv_op, conv_new_op)
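
Why the rewrite is valid: the add that follows the convolution is algebraically absorbed into it, since Conv(x, w) + b computes the same values as a convolution that adds the bias inside its own kernel. A minimal numpy sketch of that identity (the 1-D conv1d helpers and the shapes are illustrative assumptions, not ngraph API):

import numpy as np

def conv1d(x, w):
    """Valid 1-D convolution (correlation) of x with kernel w."""
    n = len(x) - len(w) + 1
    return np.array([np.dot(x[i:i + len(w)], w) for i in range(n)])

def conv1d_with_bias(x, w, b):
    """Fused form: the bias is added inside the same kernel loop."""
    n = len(x) - len(w) + 1
    return np.array([np.dot(x[i:i + len(w)], w) + b for i in range(n)])

x = np.random.rand(8)
w = np.random.rand(3)
b = 0.5

# The pass replaces Add(Conv(x, w), b) with one fused convolution op.
assert np.allclose(conv1d(x, w) + b, conv1d_with_bias(x, w, b))
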
Code example #2
    def visit(self, op):
        """
        Lower a general Dot op onto 1-D and 2-D dot kernels by reordering
        and flattening its axes; scalar operands short-circuit to a Sum
        and a scalar multiply.
        """
        x, y = op.args
        x_reduction_axes = op.x_reduction_axes
        y_reduction_axes = op.y_reduction_axes
        out_axes = op.axes
        if len(x_reduction_axes) == 0:
            # No reduction axes: give both operands a dummy length-1 axis
            # so the dot has something to reduce over.
            d = make_axis(1)
            x_reduction_axes = make_axes((d, ))
            y_reduction_axes = x_reduction_axes
            x = broadcast(x, x.axes + x_reduction_axes)
            y = broadcast(y, y_reduction_axes + y.axes)

        # Make sure the scalar operand, if any, ends up in y.
        if x.is_scalar:
            x, y = y, x
        if y.is_scalar:
            if x.is_scalar:
                # Scalar x scalar: the dot sums size-many identical products,
                # so scale the product by the reduction size.
                out = x.scalar_op * y.scalar_op
                if len(x_reduction_axes) > 0:
                    out = out * x_reduction_axes.size
            else:
                out = Sum(x, x_reduction_axes) * y.scalar_op
            out = broadcast(out, op.axes)
        else:
            # Move reduction axes to the end of x and the front of y.
            x_rem_axes = x.axes - x_reduction_axes
            x = axes_with_order(x, x_rem_axes + x_reduction_axes)

            y_rem_axes = y.axes - y_reduction_axes
            y = axes_with_order(y, y_reduction_axes + y_rem_axes)

            # Flatten each operand to 2-D: remaining axes together,
            # reduction axes together.
            x = flatten_at(x, len(x.axes) - len(x_reduction_axes))
            y = flatten_at(y, len(y_reduction_axes))

            # Pick a dot kernel based on the flattened dimensionality.
            if len(out_axes) == 0:
                out = DotOneDimensional(x, y, axes=())
            elif len(x.axes) == 1:
                y = Transpose(y)
                out = DotTwoByOne(y, x, axes=y.axes[0])
            elif len(y.axes) == 1:
                out = DotTwoByOne(x, y, axes=x.axes[0])
            else:
                out = DotTwoDimensional(
                    x,
                    y,
                    axes=[op.x_out_axes.flatten(),
                          op.y_out_axes.flatten()])

            out = unflatten(out)
            out = ReorderAxes(out, out_axes)

        self.replace_op(op, out)
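
The non-scalar branch is the standard reduction of an arbitrary tensor contraction to a single 2-D matrix multiply: reorder so the reduction axes trail x and lead y, flatten each operand to 2-D, multiply once, then unflatten. A minimal numpy sketch of the same steps (shapes chosen for illustration; moveaxis/reshape stand in for axes_with_order/flatten_at):

import numpy as np

x = np.random.rand(2, 4, 3)   # contract over the middle, length-4 axis
y = np.random.rand(5, 4)      # contract over its trailing, length-4 axis

# axes_with_order: reduction axes to the end of x and the front of y.
xt = np.moveaxis(x, 1, -1)    # (2, 3, 4)
yt = np.moveaxis(y, 1, 0)     # (4, 5)

# flatten_at: collapse each operand to 2-D.
x2d = xt.reshape(2 * 3, 4)
y2d = yt.reshape(4, 5)

# One 2-D matmul, then unflatten back to the output axes.
out = (x2d @ y2d).reshape(2, 3, 5)

assert np.allclose(out, np.tensordot(x, y, axes=([1], [1])))
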
Code example #3
    def visit(self, op):
        """
        Lower a general Dot op onto DotLowDimension kernels by reordering
        and flattening its axes; scalar operands short-circuit to a Sum
        and a scalar multiply.
        """
        x, y = op.args
        reduction_axes = op.reduction_axes
        out_axes = op.axes
        if len(reduction_axes) == 0:
            # TODO: this is a weird case, should we really support it?
            d = make_axis(1)
            reduction_axes = make_axes((d, ))
            x = broadcast(x, x.axes + reduction_axes)
            y = broadcast(y, reduction_axes + y.axes)

        # Make sure the scalar operand, if any, ends up in y.
        if x.is_scalar:
            x, y = y, x

        if y.is_scalar:
            if x.is_scalar:
                # Scalar x scalar: the dot sums size-many identical products,
                # so scale the product by the reduction size.
                out = x.scalar_op * y.scalar_op
                if len(reduction_axes) > 0:
                    out = out * reduction_axes.size
            else:
                out = Sum(x, reduction_axes) * y.scalar_op
            out = broadcast(out, op.axes)
        else:
            # move reduction_axes to end
            x = axes_with_order(x, (x.axes - reduction_axes) + reduction_axes)
            # move reduction axes to front
            y = axes_with_order(y, reduction_axes + (y.axes - reduction_axes))

            # Flatten each operand to 2-D: non-reduction axes together,
            # reduction axes together.
            x = flatten_at(x, len(x.axes) - len(reduction_axes))
            y = flatten_at(y, len(reduction_axes))

            # Pick the kernel shape based on the flattened dimensionality.
            if len(out_axes) == 0:
                out = DotLowDimension(x, y, axes=())
            elif len(x.axes) == 1:
                y = Transpose(y)
                out = DotLowDimension(y, x, axes=y.axes[0])
            elif len(y.axes) == 1:
                out = DotLowDimension(x, y, axes=x.axes[0])
            else:
                out = DotLowDimension(x,
                                      y,
                                      axes=[op.x_out_axes.flatten(True),
                                            op.y_out_axes.flatten(True)])

            out = unflatten(out)
            out = ReorderAxes(out, out_axes)

        self.replace_op(op, out)
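
The scalar fast path rests on a simple identity: a dot against a scalar broadcast over the reduction axes is just a Sum followed by one scalar multiply, which is what the y.is_scalar branch emits. A small numpy check of that identity (illustrative shapes, not ngraph code):

import numpy as np

x = np.random.rand(3, 4)
s = 2.5

# Dot with a scalar broadcast along the reduction axis...
full = x @ np.full(4, s)
# ...equals reducing x first and scaling once, as in Sum(x, axes) * scalar.
shortcut = x.sum(axis=1) * s

assert np.allclose(full, shortcut)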