Exemple #1
0
    def _get_output_shape(self):
        """Get output shape"""
        shape = self.inputs[0].shape
        multiples = self.attrs["multiples"]

        shape = list(shape)
        multiples = list(multiples)
        diff_len = len(multiples) - len(shape)
        if diff_len < 0:
            raise GKException(
                "Dimensions of multiples{} < dimensions of input{} in Tile".
                format(multiples, shape))
        if diff_len > 0:
            for _ in range(diff_len):
                shape.insert(0, 1)

        output_shape = []

        for sh, mul in list(zip(shape, multiples)):
            if sh != 1 and mul != 1:
                raise GKException(
                    "Tile op in expander only Support Automatic Broadcast!")
            dim = sh * mul
            output_shape.append(dim)
        return output_shape
Exemple #2
0
 def _check_format(obj):
     """Accept ``obj`` only if its input formats equal one registered format list.

     The registered lists are read from the attribute of ``obj`` whose name
     is ``format_list_name``; that name is not defined in this function and
     comes from the enclosing scope.

     Raises:
         GKException: if a registered list differs in length from the
             inputs, or if no registered list matches element by element.
     """
     actual = [tensor['format'] for tensor in obj.inputs]
     for candidate in getattr(obj, format_list_name):
         if len(candidate) != len(actual):
             raise GKException("length of registered format doesn't match the input of {}".format(obj.name))
         # Lengths agree here, so zip pairs every registered entry with an input.
         if all(expected == got for expected, got in zip(candidate, actual)):
             return
     raise GKException("Unregistered format ({}) for op {}".format(','.join(actual), obj.name))
Exemple #3
0
 def _check_format(obj):
     inp_formats = [inp['format'] for inp in obj.inputs]
     if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]):
         return
     raise GKException(
         "[check_all_formats_same] unmatched formats ({}) for op {}".
         format(','.join(inp_formats), obj.name))
Exemple #4
0
    def _check(self):
        """Validate the Conv2d inputs/attrs and cache the padded problem sizes.

        Checks run in order and the first failure raises, so attributes
        assigned before the failing check remain set on ``self``.

        Side effects (attributes written on ``self``):
            has_pad: result of ``conv_had_pad(pad_list, pad_mode)``.
            m, n, k: ``n0 * h0 * w0``, ``n1`` and ``c1`` after alignment/padding.
            can_optimize_to_matmul: result of ``self._optimize_to_matmul()``.
            shape_0_pad, shape_1_pad: the aligned/padded 4-element shapes.

        Returns:
            None.

        Raises:
            GKException: if an input is not float16, ``groups``/``group`` is
                not 1, ``dilation`` is not ``[1, 1, 1, 1]``, a leading or
                trailing shape dimension is below its alignment constant, the
                trailing dimensions of the two inputs differ, or the output
                size is not a multiple of ``OUT_NHW_ALIGN`` when the matmul
                optimization does not apply.
        """
        # Both inputs must be float16.
        type_0 = self.inputs[0]['data_type']
        type_1 = self.inputs[1]['data_type']
        if type_0 != "float16" or type_1 != "float16":
            raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1))

        # Two separately named attrs ('groups' and 'group') are read; both must be 1.
        groups = self.attrs['groups']
        group = self.attrs['group']
        if groups != 1 or group != 1:
            raise GKException("groups and group should be both 1, but got {} and {}.".format(groups, group))

        # check_nd is defined elsewhere; presumably it verifies the value has
        # 4 elements -- TODO confirm.
        dilation = self.attrs['dilation']
        check_nd(dilation, 4)
        # NOTE(review): compared against a list literal, so a tuple such as
        # (1, 1, 1, 1) would not compare equal and would raise here.
        if dilation != [1, 1, 1, 1]:
            raise GKException("dilation should be all 1, but got {}".format(dilation))

        pad_list = self.attrs['pad_list']
        pad_mode = self.attrs['pad_mode']
        check_nd(pad_list, 4)
        # conv_had_pad is defined elsewhere; its result gates the h0/w0 padding below.
        self.has_pad = conv_had_pad(pad_list, pad_mode)

        shape_0 = self.inputs[0]['shape']
        shape_1 = self.inputs[1]['shape']
        stride = self.attrs['stride']
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)
        check_nd(stride, 4)
        # The variable names suggest an (N, H, W, C) ordering for both shapes
        # -- TODO confirm against the callers.
        n0, h0, w0, c0 = shape_0
        n1, h1, w1, c1 = shape_1
        # Minimum-size checks use the raw (unaligned) dimensions.
        if n0 < N0_CHANNEL_ALIGN:
            raise GKException("N({}) channel of first input should >= {}".format(n0, N0_CHANNEL_ALIGN))
        if n1 < N1_CHANNEL_ALIGN:
            raise GKException("N({}) channel of second input should >= {}".format(n1, N1_CHANNEL_ALIGN))
        if c0 != c1 or c0 < C_CHANNEL_ALIGN:
            raise GKException("C channel of inputs({}, {}) should be same and >= {}".format(c0, c1, C_CHANNEL_ALIGN))
        # Round n0 up to the next multiple of N0_CHANNEL_ALIGN.
        n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
        # Grow h0 by pad_list[0] + pad_list[1] and w0 by pad_list[2] + pad_list[3]
        # when padding is in effect.
        if self.has_pad:
            h0 = h0 + pad_list[0] + pad_list[1]
            w0 = w0 + pad_list[2] + pad_list[3]
        # Round c0 up to the next multiple of C_CHANNEL_ALIGN; c1 follows c0
        # (they were checked equal above).
        c0 = ((c0 + C_CHANNEL_ALIGN - 1) // C_CHANNEL_ALIGN) * C_CHANNEL_ALIGN
        c1 = c0
        # Round n1 up to the next multiple of N1_CHANNEL_ALIGN.
        n1 = ((n1 + N1_CHANNEL_ALIGN - 1) // N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN

        # check if can optimize to matmul
        # m/n/k are stored before the call below, which may read them
        # -- _optimize_to_matmul is not visible here.
        self.m, self.n, self.k = n0 * h0 * w0, n1, c1
        self.can_optimize_to_matmul = self._optimize_to_matmul()

        # Output extent per spatial axis: (padded input - kernel) // stride + 1,
        # using the last two stride entries.
        out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
        # The output-size alignment is only enforced when the matmul
        # optimization does not apply.
        if not self.can_optimize_to_matmul and n0 * out_h * out_w % OUT_NHW_ALIGN != 0:
            raise GKException("N({}) * H({}) * W({}) of Conv2d output should be multiplies of {}"
                              .format(n0, out_h, out_w, OUT_NHW_ALIGN))
        # Cache the aligned/padded shapes; h1 and w1 are stored unchanged.
        self.shape_0_pad = [n0, h0, w0, c0]
        self.shape_1_pad = [n1, h1, w1, c1]
Exemple #5
0
 def _check_attr(obj):
     """Raise for the first name in ``args`` that is absent from ``obj.attrs``.

     ``args`` is not defined in this function; it comes from the enclosing
     scope.

     Raises:
         GKException: naming the first required attr that is missing.
     """
     for required in args:
         if required in obj.attrs:
             continue
         raise GKException("attr '{}' does not exist.".format(required))
Exemple #6
0
 def _check(self):
     """Reject the op when both gradients are disabled, then run the parent check.

     ``grad_x`` and ``grad_y`` each default to True when absent from
     ``self.attrs``.

     Returns:
         Whatever the parent class's ``_check`` returns.

     Raises:
         GKException: if both ``grad_x`` and ``grad_y`` are falsy.
     """
     # At least one of the two gradients must be requested.
     if not (self.attrs.get('grad_x', True)
             or self.attrs.get('grad_y', True)):
         raise GKException("both grad_x and grad_y are False.")
     return super()._check()