def forward(self, input):
    """Run the residual block on an ``(x, meta)`` pair.

    Dense mode (``self.sparse`` false) is the plain two-convolution residual
    path: conv1 -> bn1 -> relu -> conv2 -> bn2, then the shortcut is added.
    Sparse mode asks ``self.masker`` for spatial masks and routes both
    convolutions through the ``dynconv`` helpers instead.  In both modes the
    ReLU is applied after the shortcut addition.

    Args:
        input: ``(x, meta)`` tuple.  ``x`` is the feature input; ``meta`` is
            handed to ``self.masker`` in sparse mode and returned as-is by
            this method (whether the masker mutates it is not visible here).

    Returns:
        ``(out, meta)``.

    Raises:
        ValueError: in sparse mode when ``meta`` is None.
    """
    x, meta = input

    # Shortcut branch, projected by ``downsample`` when the block has one.
    identity = x if self.downsample is None else self.downsample(x)

    if not self.sparse:
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
    else:
        # Explicit check rather than ``assert`` so it still fires under
        # ``python -O``; the masker needs ``meta``.
        if meta is None:
            raise ValueError("sparse forward requires meta, got None")
        m = self.masker(x, meta)
        # NOTE(review): 'dilate' is presumably a dilated version of the 'std'
        # mask so conv2's 3x3 window has computed inputs — confirm in masker.
        mask_dilate, mask = m['dilate'], m['std']

        # conv1: no input mask (None), output restricted to mask_dilate.
        x = dynconv.conv3x3(self.conv1, x, None, mask_dilate)
        x = dynconv.bn_relu(self.bn1, self.relu, x, mask_dilate)
        # conv2: reads the mask_dilate positions, produces the mask positions.
        x = dynconv.conv3x3(self.conv2, x, mask_dilate, mask)
        # bn2 without activation (None): the ReLU comes after the residual add.
        x = dynconv.bn_relu(self.bn2, None, x, mask)
        out = identity + dynconv.apply_mask(x, mask)

    out = self.relu(out)
    return out, meta
def forward(self, v):
    """Run the inverted-residual block on an ``(x, meta)`` pair.

    Dense mode (``self.sparse`` false) applies ``self.conv`` and adds ``x``
    when ``self.identity`` is set.  Sparse mode evaluates ``self.conv[0]``
    through ``self.conv[7]`` via the ``dynconv`` helpers, using masks from
    ``self.masker``.  When not training (``not self.training``) it takes the
    gather/scatter "fast" path; otherwise the result is masked with
    ``apply_mask`` and added to the residual.

    Args:
        v: ``(x, meta)`` tuple.  In sparse mode ``meta`` is passed to
            ``self.masker``, and the ``meta`` it returns is the one returned.

    Returns:
        ``(out, meta)``.

    Raises:
        NotImplementedError: in sparse mode when the block has no identity
            shortcut or ``self.expand_ratio == 1``.
    """
    x, meta = v

    if not self.sparse:
        out = self.conv(x)
        if self.identity:
            out += x
        return out, meta

    # The sparse path always adds ``x`` as a residual and indexes
    # self.conv[0]..self.conv[7]. NOTE(review): that layout presumably exists
    # only when the expansion layer is present (expand_ratio != 1) — confirm.
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not (self.identity and self.expand_ratio != 1):
        raise NotImplementedError(
            "sparse forward requires an identity shortcut and expand_ratio != 1"
        )

    m, meta = self.masker(x, meta)
    mask, mask_dilate = m['std'], m['dilate']
    fast = not self.training

    # Residual base.  The original author noted the clone "should not be
    # needed, but otherwise seems to be bugged".  NOTE(review): a likely
    # reason is that scatter(..., sum_out=True) below accumulates into
    # ``out`` in place, which without this copy would also modify the
    # caller's ``x`` — confirm before removing.
    out = x.clone()

    if fast:
        # Presumably packs only the positions selected by mask_dilate.
        x = dynconv.gather(x, mask_dilate)
    x = dynconv.conv1x1(self.conv[0], x, mask_dilate, fast=fast)
    x = dynconv.bn_relu(self.conv[1], self.conv[2], x, mask_dilate, fast=fast)
    x = dynconv.conv3x3_dw(self.conv[3], x, mask_dilate, mask, fast=fast)
    x = dynconv.bn_relu(self.conv[4], self.conv[5], x, mask, fast=fast)
    x = dynconv.conv1x1(self.conv[6], x, mask, fast=fast)
    # Last bn without activation (None), as in the dense residual add.
    x = dynconv.bn_relu(self.conv[7], None, x, mask, fast=fast)

    if fast:
        out = dynconv.scatter(x, out, mask, sum_out=True)
    else:
        out = out + dynconv.apply_mask(x, mask)
    return out, meta