コード例 #1
0
    def forward(self, c3, c4, c5):
        """Run the three-level head and return flattened, split predictions.

        ``c5`` is processed first. Its intermediate feature goes through
        ``lat5`` and is merged with ``c4`` by ``upsample_concat``; the same
        pattern is then repeated from the ``c4`` level down to ``c3``.
        (Presumably c3/c4/c5 are backbone feature maps of decreasing
        resolution -- confirm against the caller.)

        Every level's prediction map is permuted with ``(0, 3, 2, 1)`` and
        reshaped to ``(batch, -1, 9 + num_classes)``, where ``batch`` is
        ``c3.size(0)``. The last axis is then sliced as:

        * ``[:4]``  -> ``loc``
        * ``[4:8]`` -> ``log_var``
        * ``[8]``   -> ``obj`` (this index drops the last axis)
        * ``[9:]``  -> ``cls``

        Returns:
            Tuple ``(loc_p, obj_p, cls_p, log_var_p)``; each entry is the
            concatenation along dim 1 of the p3, p4 and p5 slices, in that
            order.
        """
        batch = c3.size(0)

        # c5 level: no lateral input.
        feat5 = self.conv51(c5)
        out5 = self.pred5(self.conv52(feat5))

        # c4 level: the lateral branch taps the feature *before* conv52.
        feat4 = self.conv41(upsample_concat(self.lat5(feat5), c4))
        out4 = self.pred4(self.conv42(feat4))

        # c3 level: the lateral branch taps the feature *before* conv42.
        feat3 = self.conv31(upsample_concat(self.lat4(feat4), c3))
        out3 = self.pred3(self.conv32(feat3))

        # Flatten each map to (batch, positions, 9 + num_classes).
        width = 9 + self.num_classes
        flat = [
            out.permute(0, 3, 2, 1).contiguous().view(batch, -1, width)
            for out in (out3, out4, out5)
        ]

        loc_p = torch.cat([t[..., :4] for t in flat], dim=1)
        obj_p = torch.cat([t[..., 8] for t in flat], dim=1)
        cls_p = torch.cat([t[..., 9:] for t in flat], dim=1)
        log_var_p = torch.cat([t[..., 4:8] for t in flat], dim=1)
        return loc_p, obj_p, cls_p, log_var_p
コード例 #2
0
ファイル: enhance.py プロジェクト: sbl1996/pytorch-hrvvi-ext
 def forward(self, c, p):
     """Merge ``p`` with the lateral projection of ``c`` and convolve.

     When ``self.aggregate == 'cat'`` the merge is ``upsample_concat``;
     for any other value it is ``upsample_add``. Both are called as
     ``merge(p, self.lat(c))`` and the result is passed through
     ``self.conv``.

     NOTE(review): presumably ``p`` is the coarser map that gets upsampled
     to match ``c`` -- confirm against the helpers' definitions.
     """
     merge = upsample_concat if self.aggregate == 'cat' else upsample_add
     lateral = self.lat(c)
     return self.conv(merge(p, lateral))
コード例 #3
0
ファイル: enhance.py プロジェクト: sbl1996/pytorch-hrvvi-ext
 def forward(self, *cs):
     """Build one output per input, walking from the last input to the first.

     The last element of ``cs`` goes through ``convs[0]`` and ``outs[0]``.
     Each remaining input, taken in reverse order, is merged with
     ``lat(previous_feature)`` via ``upsample_concat`` and then passed
     through the next ``conv`` / ``out`` pair. ``zip`` stops at the
     shortest of ``lats``, ``convs[1:]``, ``outs[1:]`` and the remaining
     inputs.

     Returns:
         Tuple of outputs in the reverse of the order they were produced,
         i.e. the output derived from ``cs[-1]`` comes last.
     """
     feature = self.convs[0](cs[-1])
     outputs = [self.outs[0](feature)]

     stages = zip(self.lats, self.convs[1:], self.outs[1:],
                  reversed(cs[:-1]))
     for lateral, refine, head, skip in stages:
         merged = upsample_concat(lateral(feature), skip)
         feature = refine(merged)
         outputs.append(head(feature))

     # Outputs were collected last-input-first; flip them for the caller.
     return tuple(outputs[::-1])
コード例 #4
0
 def forward(self, x1, x2):
     """Project both inputs and merge them with ``upsample_concat``.

     ``x1`` goes through ``self.conv1`` and ``x2`` through ``self.conv2``;
     the projected ``x2`` is passed as the first argument of
     ``upsample_concat`` and the projected ``x1`` as the second.

     NOTE(review): presumably ``x2`` is the lower-resolution input that is
     upsampled to ``x1``'s size -- confirm against ``upsample_concat``.
     """
     first = self.conv1(x1)
     second = self.conv2(x2)
     return upsample_concat(second, first)