Esempio n. 1
0
    def forward(self, input):
        """Run the residual basic block, either densely or with spatial masks.

        Dense mode applies ``conv1 -> bn1 -> relu -> conv2 -> bn2``, adds the
        shortcut in place, and finishes with ``relu``.  Sparse mode computes
        masks with ``self.masker`` and routes the same layers through the
        ``dynconv`` helpers.  The sparse path treats both convolutions as 3x3
        via ``dynconv.conv3x3``.

        Args:
            input: a ``(x, meta)`` pair.  ``meta`` must not be ``None`` when
                ``self.sparse`` is set, because it is handed to ``self.masker``.

        Returns:
            ``(out, meta)``, where ``meta`` is the same object that was passed in.
        """
        feats, meta = input
        # Shortcut branch: optionally projected by self.downsample.
        shortcut = feats if self.downsample is None else self.downsample(feats)

        if self.sparse:
            assert meta is not None
            masks = self.masker(feats, meta)
            # The masker yields two masks.  'dilate' governs the first conv's
            # output and 'std' governs the second conv's output and the
            # residual add.
            mask_dilate = masks['dilate']
            mask = masks['std']

            h = dynconv.conv3x3(self.conv1, feats, None, mask_dilate)
            h = dynconv.bn_relu(self.bn1, self.relu, h, mask_dilate)
            h = dynconv.conv3x3(self.conv2, h, mask_dilate, mask)
            h = dynconv.bn_relu(self.bn2, None, h, mask)
            out = shortcut + dynconv.apply_mask(h, mask)
        else:
            out = feats
            for layer in (self.conv1, self.bn1, self.relu, self.conv2, self.bn2):
                out = layer(out)
            # In-place add, kept as in the original (no extra allocation).
            out += shortcut

        return self.relu(out), meta
Esempio n. 2
0
    def forward(self, v):
        """Run the inverted-residual block, either densely or with spatial masks.

        Dense mode calls ``self.conv`` and adds the input in place when
        ``self.identity`` is set.  Sparse mode requires the residual
        connection and an expansion ratio other than 1.  It then pushes the
        input through the entries of ``self.conv`` in this order:

        - index 0: 1x1 conv
        - indices 1 and 2: norm plus activation, via ``dynconv.bn_relu``
        - index 3: depthwise 3x3 conv
        - indices 4 and 5: norm plus activation
        - index 6: 1x1 conv
        - index 7: norm only

        When the module is not in training mode, the sparse path uses the
        ``fast=True`` variants.  It gathers the input with ``mask_dilate``
        first and scatters the result back into the residual afterwards.

        Args:
            v: a ``(x, meta)`` pair.

        Returns:
            ``(out, meta)``.  In dense mode ``meta`` is returned unchanged.  In
            sparse mode it is the ``meta`` returned by ``self.masker``.
        """
        inp, meta = v

        if not self.sparse:
            result = self.conv(inp)
            if self.identity:
                result += inp
            return result, meta

        assert self.identity and self.expand_ratio != 1
        masks, meta = self.masker(inp, meta)
        mask, mask_dilate = masks['std'], masks['dilate']

        fast = not self.training
        layers = self.conv

        # Residual copy; clone should not be needed, but otherwise seems to be
        # bugged.  NOTE(review): possibly because the fast-path scatter writes
        # into this tensor; confirm against dynconv.scatter.
        residual = inp.clone()
        h = dynconv.gather(inp, mask_dilate) if fast else inp

        h = dynconv.conv1x1(layers[0], h, mask_dilate, fast=fast)
        h = dynconv.bn_relu(layers[1], layers[2], h, mask_dilate, fast=fast)
        h = dynconv.conv3x3_dw(layers[3], h, mask_dilate, mask, fast=fast)
        h = dynconv.bn_relu(layers[4], layers[5], h, mask, fast=fast)
        h = dynconv.conv1x1(layers[6], h, mask, fast=fast)
        h = dynconv.bn_relu(layers[7], None, h, mask, fast=fast)

        if fast:
            return dynconv.scatter(h, residual, mask, sum_out=True), meta
        return residual + dynconv.apply_mask(h, mask), meta