Ejemplo n.º 1
0
 def forward(self, x, width_mult=1):
     """Residual block with attention, evaluated at a given width.

     ``self.body[0]`` and ``self.body[1]`` are applied through
     ``common.SlimModule`` (with ``self.act`` in between). The result is then
     passed through ``self.caLayer``, and the block input ``x`` is added back.

     Args:
         x: input feature tensor.
         width_mult: width multiplier forwarded to both slimmed layers and
             to ``self.caLayer``.

     Returns:
         The attention output with ``x`` added in place.
     """
     out = self.act(common.SlimModule(x, self.body[0], width_mult))
     out = self.caLayer(
         common.SlimModule(out, self.body[1], width_mult), width_mult)
     out += x  # residual skip connection (in-place on the attention output)
     return out
Ejemplo n.º 2
0
    def forward(self, x, width_mult=1):
        """Gate ``x`` with an attention map computed from a pooled copy of it.

        ``self.avg_pool`` is applied to ``x``. The result then goes through
        the sub-modules named ``'0'`` and ``'1'`` of ``self.conv_du``, each
        via ``common.SlimModule`` at ``width_mult``. These are followed by
        ``self.relu`` and ``self.sigmoid`` respectively. The resulting gate
        is multiplied into ``x``.

        Args:
            x: input feature tensor.
            width_mult: width multiplier forwarded to both slimmed layers.

        Returns:
            ``x * gate``.
        """
        gate = self.avg_pool(x)
        # Each stage: slimmed layer looked up by its child-module name, then
        # its activation.
        stages = (('0', self.relu), ('1', self.sigmoid))
        for layer_name, activation in stages:
            layer = getattr(self.conv_du, layer_name)
            gate = activation(common.SlimModule(gate, layer, width_mult))
        return x * gate
Ejemplo n.º 3
0
    def forward(self, x, width_mult=1):
        """Run the whole network at a reduced feature width.

        The head conv keeps only its first ``int(n_feats * width_mult)``
        output filters, where ``n_feats`` is the head conv's full number of
        output channels. ``width_mult`` is forwarded to every body module, to
        the body conv (via ``common.SlimModule``) and to the upsampler. The
        tail conv reads only the first ``out_ch`` input channels and emits
        ``self.n_colors`` channels.

        Args:
            x: input batch. Presumably (N, n_colors, H, W); TODO confirm.
            width_mult: fraction of the full feature width to use.

        Returns:
            The tail-conv output after ``self.add_mean``.
        """
        x = self.sub_mean(x)

        n_feats = self.head_conv.weight.shape[0]
        out_ch = int(n_feats * width_mult)
        # Head: n_colors -> out_ch, using a slice of the full-width weights.
        x = self._sliced_conv(self.head_conv, x, out_ch, self.n_colors)

        res = x
        for module in self.body:
            res = module(res, width_mult)
        res = common.SlimModule(res, self.body_conv, width_mult)
        res += x  # global residual connection around the body

        x = self.upsampler(res, width_mult)
        # Tail: out_ch -> n_colors; assumes the upsampler emits out_ch channels.
        x = self._sliced_conv(self.tail_conv, x, self.n_colors, out_ch)
        return self.add_mean(x)

    @staticmethod
    def _sliced_conv(conv, x, out_ch, in_ch):
        """Apply ``conv`` using only its first ``out_ch`` x ``in_ch`` weights.

        The conv's stride, padding and dilation are honoured. A conv built
        without a bias (``conv.bias is None``) is supported. NOTE: the
        functional call always zero-pads, so a non-default ``padding_mode``
        on ``conv`` is not reproduced.
        """
        weight = conv.weight[:out_ch, :in_ch, :, :]
        bias = None if conv.bias is None else conv.bias[:out_ch]
        return nn.functional.conv2d(x,
                                    weight,
                                    bias,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    dilation=conv.dilation)
Ejemplo n.º 4
0
 def forward(self, x, width_mult=1):
     """Apply a group of slimmable blocks plus a closing conv, with a skip.

     Each module in ``self.body`` is called as ``module(res, width_mult)``.
     ``self.conv`` is then applied through ``common.SlimModule``, and the
     group input ``x`` is added back.

     Args:
         x: input feature tensor.
         width_mult: width multiplier forwarded to every sub-module.
             Defaults to 1 so the signature matches the other slimmable
             ``forward`` methods; existing callers that pass it are
             unaffected.

     Returns:
         The group output with ``x`` added in place.
     """
     res = x
     for module in self.body:
         res = module(res, width_mult)
     res = common.SlimModule(res, self.conv, width_mult)
     res += x  # residual skip around the whole group
     return res
Ejemplo n.º 5
0
 def forward(self, x, width_mult=1):
     """Apply ``self.conv`` at the given width, then ``self.act``.

     Args:
         x: input feature tensor.
         width_mult: width multiplier passed to ``common.SlimModule``.
             Defaults to 1 so the signature matches the other slimmable
             ``forward`` methods; existing callers that pass it are
             unaffected.

     Returns:
         The activated output of the slimmed conv.
     """
     out = common.SlimModule(x, self.conv, width_mult)
     return self.act(out)