예제 #1
0
    def construct_relu_fprop_pattern(self):
        """
        Generate graph op that represents a pattern for Relu operation.
        max(val, 0) + slope * min (0, val)

        Note that there could be multiple patterns:
        Pattern 1 - max(x, 0) + slope * min (0, x)
        Pattern 2 - max(0, x) + slope * min (0, x)
        ..
        But we generate only 1 and match_pattern takes care of matching all
        permutations.

        Returns:
            Single pattern that matches Relu fprop op

        """
        zero = ng.constant(0)
        zero_w_broadcast = PatternSkipOp(
            zero, (lambda op: isinstance(op, BroadcastOp)))
        # We want to match x tensor and slope for Relu.
        self.relu_fwd_slope_label = "S"
        self.relu_fwd_x_label = "X"
        # We bind op to X unconditionally.
        x = PatternLabelOp(self.relu_fwd_x_label)
        max_op = Maximum(x, zero_w_broadcast)
        # We bind slope op to S only if it is scalar.
        slope_label_op = PatternLabelOp(self.relu_fwd_slope_label,
                                        (lambda op: op.is_scalar))
        slope = PatternSkipOp(slope_label_op,
                              (lambda op: isinstance(op, BroadcastOp)))
        min_op = Minimum(zero_w_broadcast, x)
        mul_op = Multiply(slope, min_op)
        add_op = Add(max_op, mul_op)
        return add_op
예제 #2
0
    def construct_conv_and_bias_pattern(self):
        """
        Pattern - Add(Convolution, Bias).
        Returns:
            Single pattern that matches Add(Convolution, Bias) pattern.
        """

        self.conv_bias_label = "B"
        self.conv_op_label = "C"

        bias_label_op = PatternLabelOp(
            self.conv_bias_label, (lambda op: op.is_scalar)
            and (lambda op: isinstance(op, TensorValueOp)))
        bias = PatternSkipOp(bias_label_op,
                             (lambda op: isinstance(op, BroadcastOp))
                             or (lambda op: isinstance(op, ContiguousOp)))

        conv_label_op = PatternLabelOp(
            self.conv_op_label, (lambda op: isinstance(op, ConvolutionOp)))

        conv_op = PatternSkipOp(conv_label_op,
                                (lambda op: isinstance(op, MapRolesOp)))

        add_op = Add(conv_op, bias)
        return add_op
예제 #3
0
    def construct_conv_and_bias_pattern(self):
        """
        Pattern - Add(Convolution, Bias).
        Returns:
            Single pattern that matches Add(Convolution, Bias) pattern.
        """

        self.conv_bias_label = "B"
        self.map_roles_label = "M"

        bias_label_op = PatternLabelOp(self.conv_bias_label,
                                       (lambda op: op.is_scalar) and
                                       (lambda op: isinstance(op, TensorValueOp)))
        bias = PatternSkipOp(bias_label_op,
                             (lambda op: isinstance(op, BroadcastOp)) or
                             (lambda op: isinstance(op, ContiguousOp)))
        reorder_bias = PatternSkipOp(bias, lambda op: isinstance(op, ReorderAxes))
        map_roles = PatternLabelOp(self.map_roles_label,
                                   (lambda op: isinstance(op, MapRolesOp)))
        tslice = PatternSkipOp(map_roles, lambda op: isinstance(op, TensorSliceOp))
        reorder_conv = PatternSkipOp(tslice, lambda op: isinstance(op, ReorderAxes))
        add_op = Add(reorder_conv, reorder_bias)
        return add_op
예제 #4
0
    def construct_innerproduct_and_bias_pattern(self):
        """
        Pattern - Add(DotLowDimension, Bias).
        Returns:
            Single pattern that matches Add(DotLowDimension, Bias) pattern.
        """

        self.dot_x_label = "X"
        self.dot_y_label = "Y"
        self.bias_label = "B"
        x = PatternLabelOp(self.dot_x_label)
        y = PatternLabelOp(self.dot_y_label)

        bias_label_op = PatternLabelOp(
            self.bias_label, (lambda op: op.is_scalar)
            and (lambda op: isinstance(op, TensorValueOp)))
        bias = PatternSkipOp(bias_label_op,
                             (lambda op: isinstance(op, BroadcastOp))
                             or (lambda op: isinstance(op, ContiguousOp)))

        dot_op = PatternSkipOp(DotOp(x, y, None),
                               (lambda op: isinstance(op, MapRolesOp)))
        add_op = Add(dot_op, bias)
        return add_op
예제 #5
0
    def construct_batchnorm_fprop_pattern(self):
        """
        Generate graph op that represents a pattern for batchnorm fprop operation.
        self.gamma * ((in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps))) + self.beta
        Returns:
               Single pattern that matches batchnorm fprop op
        """
        self.batchnorm_fprop_input_tensor_label = "in_obj"
        self.batchnorm_fprop_gamma_label = "gamma"
        self.batchnorm_fprop_beta_label = "beta"
        self.batchnorm_fprop_variance_label = "variance"
        self.batchnorm_fprop_epsilon_label = "epsilon"
        self.batchnorm_fprop_mean_label = "mean"

        # bind the label to the op's which needed to be updated in the dict
        in_obj = PatternLabelOp(self.batchnorm_fprop_input_tensor_label,
                                (lambda op: isinstance(op, ContiguousOp)))
        flatten_tensor = PatternSkipOp(in_obj,
                                       (lambda op: isinstance(op, Flatten)))
        gamma = PatternLabelOp(self.batchnorm_fprop_gamma_label,
                               (lambda op: isinstance(op, BroadcastOp)))
        beta = PatternLabelOp(self.batchnorm_fprop_beta_label,
                              (lambda op: isinstance(op, BroadcastOp)))
        variance = PatternLabelOp(self.batchnorm_fprop_variance_label,
                                  (lambda op: isinstance(op, Divide)))
        epsilon = PatternLabelOp(self.batchnorm_fprop_epsilon_label,
                                 (lambda op: isinstance(op, BroadcastOp)))
        mean = PatternLabelOp(self.batchnorm_fprop_mean_label,
                              (lambda op: isinstance(op, Divide)))

        # construct the fprop batchnorm pattern matching the computation graph
        # ng.sqrt(xvar + self.eps)
        SqrtofVarianceAndEps = ng.sqrt(ng.add(variance, epsilon))
        # ng.reciprocal(ng.sqrt(xvar + self.eps))
        reciprocal_op = ng.reciprocal(SqrtofVarianceAndEps)
        reciprocal_op_w_braodcast = ng.PatternSkipOp(reciprocal_op,
                                                     lambda op: isinstance(op, BroadcastOp))

        mean_bcast = ng.PatternSkipOp(mean, lambda op: isinstance(op, BroadcastOp))
        # (in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps))
        mul_op_1 = ng.multiply(ng.subtract(flatten_tensor, mean_bcast), reciprocal_op_w_braodcast)
        # "self.gamma * ((in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps)))
        MultiplyGamma = ng.multiply(mul_op_1, gamma)
        # self.gamma * ((in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps))) + self.beta
        AddBeta = ng.Unflatten(ng.Add(MultiplyGamma, beta))
        return AddBeta
예제 #6
0
    def construct_batchnorm_bprop_pattern(self):
        """
        Generate graph op that represents a pattern for batchnorm backprop operation.
            dgamma = np.sum(delta * xhat)
            dbeta = np.sum(delta)
            dx = gamma_scale * (delta - (xhat * dgamma + dbeta) / m)
            In this pattern we are only generating the pattern for  dx.
        Returns:
               Single pattern that matches batchnorm bprop op
        """
        self.batchnorm_bprop_input_tensor = "input_tensor"
        self.batchnorm_bprop_delta = "delta"
        self.batchnorm_bprop_gamma_label = "gamma"
        self.batchnorm_bprop_var_label = "var"
        self.batchnorm_bprop_ivar_label = "ivar"
        self.batchnorm_bprop_xmu1_label = "xmu1"
        self.batchnorm_bprop_xmu2_label = "xmu2"
        self.batchnorm_bprop_negative_inverse_sqrtvar = "negative_inverse_sqrtvar"
        self.batchnorm_bprop_inverse_sqrtvar = "inverse_sqrtvar"
        self.batchnorm_bprop_sqrtvar_label = "sqrtvar"
        self.batchnorm_bprop_sqrsum = "sqrsum"
        self.batchnorm_bprop_mean_1 = "mean_1"
        self.batchnorm_bprop_mean_2 = "mean_2"
        self.batchnorm_bprop_input_sum = "input_sum"

        # bind the op's to the label
        input_tensor = PatternLabelOp(self.batchnorm_bprop_input_tensor,
                                      (lambda op: isinstance(op, ContiguousOp)))
        flatten_tensor = PatternSkipOp(input_tensor,
                                       (lambda op: isinstance(op, Flatten)))
        var = PatternLabelOp(self.batchnorm_bprop_var_label,
                             (lambda op: isinstance(op, Divide)))
        gamma = PatternLabelOp(self.batchnorm_bprop_gamma_label,
                               (lambda op: isinstance(op, BroadcastOp)))
        delta = PatternLabelOp(self.batchnorm_bprop_delta,
                               (lambda op: isinstance(op, Flatten)))
        xmu1 = PatternLabelOp(self.batchnorm_bprop_xmu1_label,
                              (lambda op: isinstance(op, Subtract)))
        xmu2 = PatternLabelOp(self.batchnorm_bprop_xmu2_label,
                              (lambda op: isinstance(op, Subtract)))
        ivar = PatternLabelOp(self.batchnorm_bprop_ivar_label,
                              (lambda op: isinstance(op, BroadcastOp)))
        negative_inverse_sqrtvar = PatternLabelOp(self.batchnorm_bprop_negative_inverse_sqrtvar,
                                                  (lambda op: isinstance(op, NegativeOp)))
        inverse_sqrtvar = PatternLabelOp(self.batchnorm_bprop_inverse_sqrtvar,
                                         (lambda op: isinstance(op, ReciprocalOp)))
        sqrtvar = PatternLabelOp(self.batchnorm_bprop_sqrtvar_label,
                                 (lambda op: isinstance(op, SqrtOp)))
        sqrsum = PatternLabelOp(self.batchnorm_bprop_sqrsum,
                                (lambda op: isinstance(op, Sum)))
        mean_1 = PatternLabelOp(self.batchnorm_bprop_mean_1,
                                (lambda op: isinstance(op, Divide)))
        mean_2 = PatternLabelOp(self.batchnorm_bprop_mean_2,
                                (lambda op: isinstance(op, Divide)))
        input_sum = PatternLabelOp(self.batchnorm_bprop_input_sum,
                                   (lambda op: isinstance(op, Sum)))

        constant_point_5 = ng.constant(0.5)
        constant_point_5_w_broadcast = ng.PatternSkipOp(constant_point_5,
                                                        lambda op: isinstance(op, BroadcastOp))
        constant_two = ng.constant(2)
        constant_two_w_broadcast = ng.PatternSkipOp(constant_two,
                                                    lambda op: isinstance(op, BroadcastOp))
        # construct the pattern
        dxhat = Multiply(gamma, delta)
        # divar = np.sum(dxhat*xmu, axis=0)
        divar = Sum(Multiply(dxhat, xmu1))
        # dxmu1 = dxhat * ivar
        dxmu1 = Multiply(dxhat, ivar)
        # dsqrtvar = -1. /(sqrtvar**2) * divar
        dsqrtvar = Multiply(Multiply(inverse_sqrtvar, negative_inverse_sqrtvar), divar)
        # dvar = 0.5 * 1. /np.sqrt(var+eps) * dsqrtvar
        dvar = Divide(Multiply(dsqrtvar, constant_point_5_w_broadcast), sqrtvar)
        # dsq = 1. / N * np.ones((N, D)) * dvar
        dsq = Divide(Multiply(dvar, var), sqrsum)
        dsq_w_broadcast = ng.PatternSkipOp(dsq,
                                           (lambda op: isinstance(op, BroadcastOp)))
        # dxmu2 = 2 * xmu * dsq
        dxmu2 = Multiply(xmu2, Multiply(constant_two_w_broadcast, dsq_w_broadcast))

        # dx1 = (dxmu1 + dxmu2)
        # dmu = -1 * np.sum(dxmu1 + dxmu2, axis=0)
        # dx2 = 1. /N * np.ones((N,D)) * dmu
        # dx = dx1 + dx2
        dxmu2_mul = Multiply(Sum(ng.negative(dxmu2)), mean_2)
        dxmu2_div = Divide(dxmu2_mul, input_sum)
        dxmu2_div_w_broadcast = ng.PatternSkipOp(dxmu2_div,
                                                 (lambda op: isinstance(op, BroadcastOp)))
        dxmu2_div_plus_dxmu2 = Add(dxmu2_div_w_broadcast, dxmu2)

        dx1 = Add(dxmu1, dxmu2_div_plus_dxmu2)
        dxmu1_mul = Multiply(Sum(ng.negative(dxmu1)), mean_1)
        dxmu1_div = Divide(dxmu1_mul, Sum(flatten_tensor))
        dxmu1_div_w_broadcast = ng.PatternSkipOp(dxmu1_div,
                                                 (lambda op: isinstance(op, BroadcastOp)))
        dx = Add(dxmu1_div_w_broadcast, dx1)
        return dx