예제 #1
0
 def __init__(self, n_features=32):
     """Build the four directional propagation units of the block.

     Args:
         n_features: number of feature channels the block operates on.
     """
     super(SpatialPropagationBlock, self).__init__()
     self.n_features = n_features
     # propagation layers: one GateRecurrent2dnoind per sweep direction;
     # the two flags select horizontal-vs-vertical and forward-vs-reverse
     # traversal (per the direction legend used elsewhere in this file)
     self.Propagator_x1 = GateRecurrent2dnoind(True, False)   # left -> right
     self.Propagator_x2 = GateRecurrent2dnoind(True, True)    # right -> left
     self.Propagator_y1 = GateRecurrent2dnoind(False, False)  # top -> bottom
     self.Propagator_y2 = GateRecurrent2dnoind(False, True)   # bottom -> top
예제 #2
0
right->left: Propagator = GateRecurrent2dnoind(True,True)
top->bottom: Propagator = GateRecurrent2dnoind(False,False)
bottom->top: Propagator = GateRecurrent2dnoind(False,True)

X: any signal/feature map to be filtered
G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom)

Note:
1. G1~G3 constitute the affinity; they can be a bunch of output maps coming from any CNN, taking as input any useful known information (e.g., RGB images)
2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficient condition for model stability (see paper)
"""
import torch
from torch.autograd import Variable
from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind

# Left->right propagation demo (flags per the direction legend above).
Propagator = GateRecurrent2dnoind(True, False)

# X: signal to be filtered; G1~G3: affinity/coefficient maps (random demo data)
X = Variable(torch.randn(1, 3, 10, 10))
G1 = Variable(torch.randn(1, 3, 10, 10))
G2 = Variable(torch.randn(1, 3, 10, 10))
G3 = Variable(torch.randn(1, 3, 10, 10))

# Enforce the stability condition |G1| + |G2| + |G3| <= 1 per pixel by
# rescaling only the pixels that violate it.
sum_abs = G1.abs() + G2.abs() + G3.abs()
mask_need_norm = sum_abs.ge(1)
mask_need_norm = mask_need_norm.float()
# Clamp the denominator: where sum_abs < 1 the quotient is discarded
# (mask_need_norm == 0), so this is a no-op for the values actually kept,
# but it prevents 0/0 -> NaN from poisoning the blend via 0 * NaN.
safe_sum = sum_abs.clamp(min=1)
G1_norm = torch.div(G1, safe_sum)
G2_norm = torch.div(G2, safe_sum)
G3_norm = torch.div(G3, safe_sum)

G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
# Fix: G3_norm was computed but never applied back to G3.
G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm
예제 #3
0
    def spnNet(self, out, y):
        """Filter ``y`` with four directional SPN sweeps gated by ``out``.

        Args:
            out: gate tensor; channels [0:384) hold 4 directions x 3 gate
                 maps x 32 channels each.
            y:   feature map to be propagated (moved to GPU internally).

        Returns:
            Element-wise max over the four directionally propagated results.
        """

        def _sweep(g1, g2, g3, horizontal, reverse):
            # Rescale pixels where |g1| + |g2| + |g3| >= 1 so the affinity
            # satisfies the SPN stability condition (see paper).
            sum_abs = g1.abs() + g2.abs() + g3.abs()
            need_norm = sum_abs.ge(1).float()
            keep = torch.add(-need_norm, 1)
            # Clamp the denominator: where sum_abs < 1 the quotient is
            # discarded (need_norm == 0), so this only prevents 0/0 -> NaN
            # from leaking into the result via 0 * NaN.
            safe_sum = sum_abs.clamp(min=1)
            g1 = keep * g1 + need_norm * torch.div(g1, safe_sum)
            g2 = keep * g2 + need_norm * torch.div(g2, safe_sum)
            g3 = keep * g3 + need_norm * torch.div(g3, safe_sum)
            propagator = GateRecurrent2dnoind(horizontal, reverse)
            return propagator.forward(y.cuda(), g1.cuda(), g2.cuda(), g3.cuda())

        # NOTE(review): per the flag convention documented elsewhere in this
        # file, (True, False) is left->right but (False, True) is bottom->top
        # and (True, True) is right->left — the lr/rl/du/ud names below may
        # not match the actual sweep directions; confirm against pytorch_spn.
        ylr = _sweep(out[:, 0:32], out[:, 32:64], out[:, 64:96], True, False)
        yrl = _sweep(out[:, 96:128], out[:, 128:160], out[:, 160:192], False, True)
        ydu = _sweep(out[:, 192:224], out[:, 224:256], out[:, 256:288], False, False)
        yud = _sweep(out[:, 288:320], out[:, 320:352], out[:, 352:384], True, True)

        # Fuse the four directional results with an element-wise max.
        yout = torch.max(ylr, yrl)
        yout = torch.max(yout, yud)
        yout = torch.max(yout, ydu)
        return yout
예제 #4
0
right->left: Propagator = GateRecurrent2dnoind(True,True)
top->bottom: Propagator = GateRecurrent2dnoind(False,False)
bottom->top: Propagator = GateRecurrent2dnoind(False,True)

X: any signal/feature map to be filtered
G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom)

Note:
1. G1~G3 constitute the affinity; they can be a bunch of output maps coming from any CNN, taking as input any useful known information (e.g., RGB images)
2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficient condition for model stability (see paper)
"""
import torch
from torch.autograd import Variable
from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind

# Bottom->top propagation demo (flags per the direction legend above).
Propagator = GateRecurrent2dnoind(False, True)

# X: signal to be filtered; G1~G3: affinity/coefficient maps (random demo data)
X = Variable(torch.randn(1, 3, 10, 10))
G1 = Variable(torch.randn(1, 3, 10, 10))
G2 = Variable(torch.randn(1, 3, 10, 10))
G3 = Variable(torch.randn(1, 3, 10, 10))

# Enforce the stability condition |G1| + |G2| + |G3| <= 1 per pixel by
# rescaling only the pixels that violate it.
sum_abs = G1.abs() + G2.abs() + G3.abs()
mask_need_norm = sum_abs.ge(1)
mask_need_norm = mask_need_norm.float()
# Clamp the denominator: where sum_abs < 1 the quotient is discarded
# (mask_need_norm == 0), so this is a no-op for the values actually kept,
# but it prevents 0/0 -> NaN from poisoning the blend via 0 * NaN.
safe_sum = sum_abs.clamp(min=1)
G1_norm = torch.div(G1, safe_sum)
G2_norm = torch.div(G2, safe_sum)
G3_norm = torch.div(G3, safe_sum)

G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
# Fix: G3_norm was computed but never applied back to G3.
G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm
예제 #5
0
 def __init__(self):
     """Set up the single propagation unit used by the refinement module."""
     super(Refine, self).__init__()
     # (True, True) selects the right->left sweep per the direction legend
     # used elsewhere in this file
     self.Propagator = GateRecurrent2dnoind(True, True)
예제 #6
0
    def spnNet(self, features, mask):
        '''
        spn refine the segmentation mask
        features:[N, 180, 112, 112]
        mask:[N, 15, 112, 112]
        '''

        def _sweep(g1, g2, g3, horizontal, reverse):
            # Purpose: normalize one direction's three gate maps and run a
            # single propagation pass over ``mask``.
            # Rescale pixels where |g1| + |g2| + |g3| >= 1 so the affinity
            # satisfies the SPN stability condition (see paper).
            sum_abs = g1.abs() + g2.abs() + g3.abs()
            need_norm = sum_abs.ge(1).float()
            keep = torch.add(-need_norm, 1)
            g1 = keep * g1 + need_norm * torch.div(g1, sum_abs)
            g2 = keep * g2 + need_norm * torch.div(g2, sum_abs)
            g3 = keep * g3 + need_norm * torch.div(g3, sum_abs)
            # The division above yields 0/0 -> NaN where sum_abs == 0;
            # patch those entries with a small constant (original behaviour).
            g1[g1 != g1] = 0.001
            g2[g2 != g2] = 0.001
            g3[g3 != g3] = 0.001
            propagator = GateRecurrent2dnoind(horizontal, reverse)
            return propagator.forward(mask, g1, g2, g3)

        # Each direction consumes three gate maps of 15 channels each,
        # packed consecutively along the channel axis of ``features``.
        # left->right:
        mask_l2r = _sweep(features[:, 0:15, :, :], features[:, 15:30, :, :],
                          features[:, 30:45, :, :], True, False)
        # right->left:
        mask_r2l = _sweep(features[:, 45:60, :, :], features[:, 60:75, :, :],
                          features[:, 75:90, :, :], True, True)
        # top->bottom:
        mask_t2b = _sweep(features[:, 90:105, :, :], features[:, 105:120, :, :],
                          features[:, 120:135, :, :], False, False)
        # bottom->top:
        mask_b2t = _sweep(features[:, 135:150, :, :], features[:, 150:165, :, :],
                          features[:, 165:180, :, :], False, True)

        # Fuse the four directional sweeps with an element-wise max.
        mask1 = torch.max(mask_l2r, mask_r2l)
        mask2 = torch.max(mask_t2b, mask_b2t)
        result = torch.max(mask1, mask2)

        return result