Ejemplo n.º 1
0
    def forward(self, x, gts=None):
        """Run backbone, FCN head and skip decoder; return logits or loss.

        Args:
            x: input batch; ``x.size()[2:]`` is used as the output spatial
                size.
            gts: ground truth, only consumed by ``self.criterion`` when
                ``self.training`` is true.

        Returns:
            ``self.criterion(main_out, gts)`` in training mode, otherwise the
            decoder prediction upsampled to the input resolution.
        """
        input_hw = x.size()[2:]

        # Backbone; stage-1 and stage-2 features are kept for the skip path.
        feat1 = self.layer1(self.layer0(x))
        feat2 = self.layer2(feat1)
        deep = self.layer4(self.layer3(feat2))

        coarse = self.fcn_head(deep)

        # 'm1' takes the skip feature from layer1, anything else from layer2;
        # the coarse head output is resized to that feature's resolution.
        skip_feat = feat1 if self.skip == 'm1' else feat2
        fine = self.bot_fine(skip_feat)
        coarse = Upsample(coarse, skip_feat.size()[2:])

        logits = self.final(torch.cat([fine, coarse], 1))
        main_out = Upsample(logits, input_hw)

        if not self.training:
            return main_out
        return self.criterion(main_out, gts)
Ejemplo n.º 2
0
    def forward(self, inp, gts=None):
        """Encode with mod1..mod7 + ASPP, decode with the mod2 skip feature.

        Args:
            inp: input batch; ``inp.size()[2:]`` is the output spatial size.
            gts: ground truth, only consumed by ``self.criterion`` in
                training mode.

        Returns:
            ``self.criterion(out, gts)`` in training mode, otherwise the
            decoder prediction upsampled to the input resolution.
        """
        out_hw = inp.size()[2:]

        stem = self.mod1(inp)
        low_level = self.mod2(self.pool2(stem))

        feats = self.mod3(self.pool3(low_level))
        for stage in (self.mod4, self.mod5, self.mod6, self.mod7):
            feats = stage(feats)

        context = self.bot_aspp(self.aspp(feats))

        # Bring the context features to the skip feature's resolution and
        # fuse them channel-wise.
        fine = self.bot_fine(low_level)
        context = Upsample(context, low_level.size()[2:])
        logits = self.final(torch.cat([fine, context], 1))
        out = Upsample(logits, out_hw)

        if self.training:
            return self.criterion(out, gts)
        return out
 def forward(self, x, gts=None):
     """Run backbone and head; return logits, or the loss when training.

     ``self.head`` receives the four backbone stage outputs. Its result is
     indexed as ``[0]`` (main prediction) and ``[1]`` (an iterable of edge
     predictions).

     Args:
         x: input batch; ``x.size()[2:]`` is the output spatial size.
         gts: ground truth passed to ``self.criterion`` in training mode.

     Returns:
         In training mode, the value of ``self.criterion(...)``; otherwise
         the main prediction upsampled to the input resolution.
     """
     x_size = x.size()
     x0 = self.layer0(x)
     x1 = self.layer1(x0)
     x2 = self.layer2(x1)
     x3 = self.layer3(x2)
     x4 = self.layer4(x3)
     x = self.head([x1, x2, x3, x4])
     main_out = Upsample(x[0], x_size[2:])
     if not self.training:
         # Edge predictions only feed the loss; don't upsample them at
         # inference time.
         return main_out
     if not self.fpn_dsn:
         edge_preds = [Upsample(edge_pred, x_size[2:]) for edge_pred in x[1]]
         return self.criterion([main_out, edge_preds], gts)
     # fpn_dsn: the criterion receives the raw (non-upsampled) head outputs.
     return self.criterion(x, gts)
Ejemplo n.º 4
0
    def forward(self, x, gts=None):
        """Run backbone with main and auxiliary heads; return logits or loss.

        The auxiliary head (on ``layer3`` features) only contributes to the
        training loss, so it is evaluated only in training mode.

        Args:
            x: input batch; ``x.size()[2:]`` is the output spatial size.
            gts: ground truth passed to ``self.criterion`` in training mode.

        Returns:
            ``self.criterion([main_out, aux_out], gts)`` in training mode,
            otherwise the main prediction upsampled to the input resolution.
        """
        x_size = x.size()[2:]
        x0 = self.layer0(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        if not self.training:
            # Inference: the aux output would be discarded, so skip it.
            return Upsample(self.head(x4), size=x_size)

        # Training: same module call order as before (aux_layer, then head).
        aux_out = self.aux_layer(x3)
        main_out = self.head(x4)

        aux_out = Upsample(aux_out, size=x_size)
        main_out = Upsample(main_out, size=x_size)

        return self.criterion([main_out, aux_out], gts)
Ejemplo n.º 5
0
 def forward(self, x, gts=None):
     """Run backbone and head; return logits, or the loss when training.

     Args:
         x: input batch; ``x.size()[2:]`` is the output spatial size.
         gts: ground truth passed to ``self.criterion`` in training mode.

     Returns:
         In training mode, ``self.criterion`` applied to the head outputs
         (when ``self.fpn_dsn``) or to the upsampled main prediction;
         otherwise the main prediction upsampled to the input resolution.
     """
     out_hw = x.size()[2:]
     c2, c3, c4 = self.backbone(x)
     head_outs = self.head([c2, c3, c4])
     main_out = Upsample(head_outs[0], out_hw)
     if not self.training:
         return main_out
     if self.fpn_dsn:
         return self.criterion(head_outs, gts)
     return self.criterion(main_out, gts)
Ejemplo n.º 6
0
    def forward(self, x):
        """Concatenate image-level pooled features with every branch output.

        Args:
            x: input feature map; ``x.size()[2:]`` is the resolution the
                pooled features are upsampled back to.

        Returns:
            Channel-wise (dim 1) concatenation of the upsampled image-level
            features followed by ``f(x)`` for each ``f`` in
            ``self.features``, in order.
        """
        x_size = x.size()

        img_features = self.img_pooling(x)
        img_features = self.img_conv(img_features)
        img_features = Upsample(img_features, x_size[2:])

        # Gather all branches and concatenate once; calling torch.cat inside
        # the loop re-copies the growing tensor on every iteration.
        outs = [img_features]
        outs.extend(f(x) for f in self.features)
        return torch.cat(outs, 1)
Ejemplo n.º 7
0
 def forward(self, x, gts=None):
     """Run backbone stages 1-4 through the head; return logits or loss.

     Args:
         x: input batch; ``x.size()[2:]`` is the output spatial size.
         gts: ground truth passed to ``self.criterion`` in training mode.

     Returns:
         In training mode, ``self.criterion`` applied to the head outputs
         (when ``self.fpn_dsn``) or to the upsampled main prediction;
         otherwise the main prediction upsampled to the input resolution.
     """
     out_hw = x.size()[2:]
     stages = []
     feat = self.layer0(x)
     for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
         feat = layer(feat)
         stages.append(feat)
     head_outs = self.head(stages)
     main_out = Upsample(head_outs[0], out_hw)
     if not self.training:
         return main_out
     if self.fpn_dsn:
         return self.criterion(head_outs, gts)
     return self.criterion(main_out, gts)
Ejemplo n.º 8
0
    def forward(self, inp, gts=None):
        """Body/edge decoupled segmentation forward pass (mod1..mod7 backbone).

        ``self.squeeze_body_edge`` splits the ASPP context into a body and an
        edge part. The edge part is fused with the projected ``mod2`` feature,
        and context + body + edge are combined by ``self.final_seg``. The
        body and edge side outputs only feed the loss, so they are produced
        in training mode only.

        Args:
            inp: input batch; ``inp.size()[2:]`` is the output spatial size.
            gts: ground truth passed to ``self.criterion`` in training mode.

        Returns:
            ``self.criterion((seg_final_out, seg_body_out, seg_edge_out),
            gts)`` in training mode, otherwise the final segmentation output
            upsampled to the input resolution.
        """
        x_size = inp.size()
        x = self.mod1(inp)
        m2 = self.mod2(self.pool2(x))
        fine_size = m2.size()
        x = self.mod3(self.pool3(m2))
        x = self.mod4(x)
        x = self.mod5(x)
        x = self.mod6(x)
        x = self.mod7(x)
        x = self.aspp(x)
        aspp = self.bot_aspp(x)

        seg_body, seg_edge = self.squeeze_body_edge(aspp)

        # Fuse the edge part (at the low-level resolution) with the projected
        # low-level feature.
        dec0_fine = self.bot_fine(m2)
        seg_edge = self.edge_fusion(
            torch.cat([Upsample(seg_edge, fine_size[2:]), dec0_fine], dim=1))
        # Edge side output is loss-only. It is computed before final_seg so
        # the training-time module call order (and any RNG use) is unchanged.
        seg_edge_out = self.edge_out(seg_edge) if self.training else None

        seg_out = seg_edge + Upsample(seg_body, fine_size[2:])
        aspp = Upsample(aspp, fine_size[2:])

        seg_out = torch.cat([aspp, seg_out], dim=1)
        seg_final = self.final_seg(seg_out)
        seg_final_out = Upsample(seg_final, x_size[2:])

        if not self.training:
            return seg_final_out

        seg_edge_out = Upsample(seg_edge_out, x_size[2:])
        seg_edge_out = self.sigmoid_edge(seg_edge_out)

        seg_body_out = Upsample(self.dsn_seg_body(seg_body), x_size[2:])

        return self.criterion((seg_final_out, seg_body_out, seg_edge_out),
                              gts)
Ejemplo n.º 9
0
    def forward(self, x, gts=None):
        """Body/edge decoupled segmentation forward pass (layer0..4 backbone).

        ``self.squeeze_body_edge`` splits the ASPP context into a body and an
        edge part. The edge part is fused with the projected low-level
        feature (layer1 when ``self.skip == 'm1'``, else layer2), and context
        + body + edge are combined by ``self.final_seg``. The body and edge
        side outputs only feed the loss, so they are produced in training
        mode only.

        Args:
            x: input batch; ``x.size()[2:]`` is the output spatial size.
            gts: ground truth passed to ``self.criterion`` in training mode.

        Returns:
            ``self.criterion((seg_final_out, seg_body_out, seg_edge_out),
            gts)`` in training mode, otherwise the final segmentation output
            upsampled to the input resolution.
        """
        x_size = x.size()
        x0 = self.layer0(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        xp = self.aspp(x4)

        aspp = self.bot_aspp(xp)

        seg_body, seg_edge = self.squeeze_body_edge(aspp)

        # Pick the low-level skip feature. The decoder runs at THIS feature's
        # resolution: fine_size must come from the same tensor, otherwise the
        # channel concat below mismatches spatially when skip != 'm1'.
        low_level = x1 if self.skip == 'm1' else x2
        fine_size = low_level.size()
        dec0_fine = self.bot_fine(low_level)

        seg_edge = self.edge_fusion(
            torch.cat([Upsample(seg_edge, fine_size[2:]), dec0_fine], dim=1))
        # Edge side output is loss-only. It is computed before final_seg so
        # the training-time module call order (and any RNG use) is unchanged.
        seg_edge_out = self.edge_out(seg_edge) if self.training else None

        seg_out = seg_edge + Upsample(seg_body, fine_size[2:])
        aspp = Upsample(aspp, fine_size[2:])

        seg_out = torch.cat([aspp, seg_out], dim=1)
        seg_final = self.final_seg(seg_out)
        seg_final_out = Upsample(seg_final, x_size[2:])

        if not self.training:
            return seg_final_out

        seg_edge_out = Upsample(seg_edge_out, x_size[2:])
        seg_edge_out = self.sigmoid_edge(seg_edge_out)

        seg_body_out = Upsample(self.dsn_seg_body(seg_body), x_size[2:])

        return self.criterion((seg_final_out, seg_body_out, seg_edge_out),
                              gts)