Example 1
    def visit(self, op):
        x, y = op.args
        reduction_axes = op.reduction_axes
        out_axes = op.axes
        if len(reduction_axes) == 0:
            # TODO: this is a weird case, should we really support it?
            d = make_axis(1)
            reduction_axes = make_axes((d, ))
            x = broadcast(x, x.axes + reduction_axes)
            y = broadcast(y, reduction_axes + y.axes)

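        # canonicalize so that a scalar operand, if any, ends up in y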
        if x.is_scalar:
            x, y = y, x

        if y.is_scalar:
            if x.is_scalar:
                out = x.scalar_op * y.scalar_op
                if len(reduction_axes) > 0:
                    out = out * reduction_axes.size
            else:
                out = Sum(x, reduction_axes) * y.scalar_op
            out = broadcast(out, op.axes)
        else:
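            # general case: reduce the contraction to a 2-D matrix multiply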
            # move the reduction axes to the end of x
            x = axes_with_order(x, (x.axes - reduction_axes) + reduction_axes)
            # and to the front of y
            y = axes_with_order(y, reduction_axes + (y.axes - reduction_axes))

            # flatten each operand to 2-D: non-reduction axes together,
            # reduction axes together
            x = flatten_at(x, len(x.axes) - len(reduction_axes))
            y = flatten_at(y, len(reduction_axes))

            if len(out_axes) == 0:
                out = DotLowDimension(x, y, axes=())
            elif len(x.axes) == 1:
                y = Transpose(y)
                out = DotLowDimension(y, x, axes=y.axes[0])
            elif len(y.axes) == 1:
                out = DotLowDimension(x, y, axes=x.axes[0])
            else:
                out = DotLowDimension(
                    x, y,
                    axes=[op.x_out_axes.flatten(True),
                          op.y_out_axes.flatten(True)])

            out = unflatten(out)
            out = ReorderAxes(out, out_axes)

        self.replace_op(op, out)
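
The pass above rewrites an arbitrary tensor contraction as a 2-D matrix
multiply: the reduction axes are moved to the end of x and the front of y,
each operand is flattened down to two axes, and a single low-dimensional dot
is emitted. A minimal NumPy sketch of the same idea (the helper dot_as_2d and
its axis-list arguments are illustrative, not part of the library above):

import numpy as np

def dot_as_2d(x, y, x_red, y_red):
    # move the reduction axes to the end of x and the front of y
    x_keep = [a for a in range(x.ndim) if a not in x_red]
    y_keep = [a for a in range(y.ndim) if a not in y_red]
    x2 = np.transpose(x, x_keep + list(x_red))
    y2 = np.transpose(y, list(y_red) + y_keep)
    # flatten non-reduction axes together and reduction axes together
    red_size = int(np.prod([x.shape[a] for a in x_red]))
    x2 = x2.reshape(-1, red_size)
    y2 = y2.reshape(red_size, -1)
    # one low-dimensional dot, then restore the output shape
    out_shape = [x.shape[a] for a in x_keep] + [y.shape[a] for a in y_keep]
    return (x2 @ y2).reshape(out_shape)

# agrees with np.tensordot on a random example
x = np.random.rand(2, 3, 4)
y = np.random.rand(4, 5)
assert np.allclose(dot_as_2d(x, y, [2], [0]),
                   np.tensordot(x, y, axes=([2], [0])))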
Example 2
    def construct_batchnorm_bprop_pattern(self):
        """
        Generate graph op that represents a pattern for batchnorm backprop operation.
            dgamma = np.sum(delta * xhat)
            dbeta = np.sum(delta)
            dx = gamma_scale * (delta - (xhat * dgamma + dbeta) / m)
            In this pattern we are only generating the pattern for  dx.
        Returns:
               Single pattern that matches batchnorm bprop op
        """
        self.batchnorm_bprop_input_tensor = "input_tensor"
        self.batchnorm_bprop_delta = "delta"
        self.batchnorm_bprop_gamma_label = "gamma"
        self.batchnorm_bprop_var_label = "var"
        self.batchnorm_bprop_ivar_label = "ivar"
        self.batchnorm_bprop_xmu1_label = "xmu1"
        self.batchnorm_bprop_xmu2_label = "xmu2"
        self.batchnorm_bprop_negative_inverse_sqrtvar = "negative_inverse_sqrtvar"
        self.batchnorm_bprop_inverse_sqrtvar = "inverse_sqrtvar"
        self.batchnorm_bprop_sqrtvar_label = "sqrtvar"
        self.batchnorm_bprop_sqrsum = "sqrsum"
        self.batchnorm_bprop_mean_1 = "mean_1"
        self.batchnorm_bprop_mean_2 = "mean_2"
        self.batchnorm_bprop_input_sum = "input_sum"

        # bind each op type to its pattern label
        input_tensor = PatternLabelOp(self.batchnorm_bprop_input_tensor,
                                      (lambda op: isinstance(op, Flatten)))
        var = PatternLabelOp(self.batchnorm_bprop_var_label,
                             (lambda op: isinstance(op, Divide)))
        gamma = PatternLabelOp(self.batchnorm_bprop_gamma_label,
                               (lambda op: isinstance(op, BroadcastOp)))
        delta = PatternLabelOp(self.batchnorm_bprop_delta,
                               (lambda op: isinstance(op, Flatten)))
        xmu1 = PatternLabelOp(self.batchnorm_bprop_xmu1_label,
                              (lambda op: isinstance(op, Subtract)))
        xmu2 = PatternLabelOp(self.batchnorm_bprop_xmu2_label,
                              (lambda op: isinstance(op, Subtract)))
        ivar = PatternLabelOp(self.batchnorm_bprop_ivar_label,
                              (lambda op: isinstance(op, BroadcastOp)))
        negative_inverse_sqrtvar = PatternLabelOp(
            self.batchnorm_bprop_negative_inverse_sqrtvar,
            (lambda op: isinstance(op, NegativeOp)))
        inverse_sqrtvar = PatternLabelOp(
            self.batchnorm_bprop_inverse_sqrtvar,
            (lambda op: isinstance(op, ReciprocalOp)))
        sqrtvar = PatternLabelOp(self.batchnorm_bprop_sqrtvar_label,
                                 (lambda op: isinstance(op, SqrtOp)))
        sqrsum = PatternLabelOp(self.batchnorm_bprop_sqrsum,
                                (lambda op: isinstance(op, Sum)))
        mean_1 = PatternLabelOp(self.batchnorm_bprop_mean_1,
                                (lambda op: isinstance(op, Divide)))
        mean_2 = PatternLabelOp(self.batchnorm_bprop_mean_2,
                                (lambda op: isinstance(op, Divide)))
        input_sum = PatternLabelOp(self.batchnorm_bprop_input_sum,
                                   (lambda op: isinstance(op, Sum)))

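        # the 0.5 and 2 constants may show up wrapped in a BroadcastOp, so
        # PatternSkipOp lets the matcher skip that optional broadcast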
        constant_point_5 = ng.constant(0.5)
        constant_point_5_w_broadcast = ng.PatternSkipOp(
            constant_point_5, lambda op: isinstance(op, BroadcastOp))
        constant_two = ng.constant(2)
        constant_two_w_broadcast = ng.PatternSkipOp(
            constant_two, lambda op: isinstance(op, BroadcastOp))
        # construct the pattern
        dxhat = Multiply(gamma, delta)
        # divar = np.sum(dxhat*xmu, axis=0)
        divar = Sum(Multiply(dxhat, xmu1))
        # dxmu1 = dxhat * ivar
        dxmu1 = Multiply(dxhat, ivar)
        # dsqrtvar = -1. /(sqrtvar**2) * divar
        dsqrtvar = Multiply(
            Multiply(inverse_sqrtvar, negative_inverse_sqrtvar), divar)
        # dvar = 0.5 * 1. /np.sqrt(var+eps) * dsqrtvar
        dvar = Divide(Multiply(dsqrtvar, constant_point_5_w_broadcast),
                      sqrtvar)
        # dsq = 1. / N * np.ones((N, D)) * dvar
        dsq = Divide(Multiply(dvar, var), sqrsum)
        dsq_w_broadcast = ng.PatternSkipOp(
            dsq, (lambda op: isinstance(op, BroadcastOp)))
        # dxmu2 = 2 * xmu * dsq
        dxmu2 = Multiply(xmu2,
                         Multiply(constant_two_w_broadcast, dsq_w_broadcast))

        # dx1 = (dxmu1 + dxmu2)
        # dmu = -1 * np.sum(dxmu1 + dxmu2, axis=0)
        # dx2 = 1. /N * np.ones((N,D)) * dmu
        # dx = dx1 + dx2
        dxmu2_mul = Multiply(Sum(ng.negative(dxmu2)), mean_2)
        dxmu2_div = Divide(dxmu2_mul, input_sum)
        dxmu2_div_w_broadcast = ng.PatternSkipOp(
            dxmu2_div, (lambda op: isinstance(op, BroadcastOp)))
        dxmu2_div_plus_dxmu2 = Add(dxmu2_div_w_broadcast, dxmu2)

        dx1 = Add(dxmu1, dxmu2_div_plus_dxmu2)
        dxmu1_mul = Multiply(Sum(ng.negative(dxmu1)), mean_1)
        dxmu1_div = Divide(dxmu1_mul, Sum(input_tensor))
        dxmu1_div_w_broadcast = ng.PatternSkipOp(
            dxmu1_div, (lambda op: isinstance(op, BroadcastOp)))
        dx = Add(dxmu1_div_w_broadcast, dx1)
        return dx
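
For reference, the dx formula that this pattern is built to match can be
written directly in NumPy. A minimal sketch, assuming an (N, D) input with
per-feature statistics over axis 0 (the helper name and the eps default are
illustrative, not part of the library above):

import numpy as np

def batchnorm_bprop_dx(delta, x, gamma, eps=1e-5):
    # forward-pass statistics, per feature (axis 0 is the batch axis)
    m = x.shape[0]
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    xhat = (x - mean) / np.sqrt(var + eps)
    # the quantities from the docstring
    dgamma = np.sum(delta * xhat, axis=0)
    dbeta = np.sum(delta, axis=0)
    gamma_scale = gamma / np.sqrt(var + eps)
    # dx = gamma_scale * (delta - (xhat * dgamma + dbeta) / m)
    return gamma_scale * (delta - (xhat * dgamma + dbeta) / m)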