def __init__(self, n_features=32):
    """Build the four directional propagation layers of the block.

    Args:
        n_features: Stored on the instance as ``n_features``; this
            constructor does not use it for anything else.
    """
    super(SpatialPropagationBlock, self).__init__()
    self.n_features = n_features

    # Named values for the two GateRecurrent2dnoind(horizontal, reverse)
    # flags. Per the SPN usage notes: (True, False) is left->right,
    # (True, True) right->left, (False, False) top->bottom and
    # (False, True) bottom->top.
    horizontal, vertical = True, False
    forward, backward = False, True

    # propagation layers (x = horizontal scans, y = vertical scans)
    self.Propagator_x1 = GateRecurrent2dnoind(horizontal, forward)
    self.Propagator_x2 = GateRecurrent2dnoind(horizontal, backward)
    self.Propagator_y1 = GateRecurrent2dnoind(vertical, forward)
    self.Propagator_y2 = GateRecurrent2dnoind(vertical, backward)
right->left: Propagator = GateRecurrent2dnoind(True,True) top->bottom: Propagator = GateRecurrent2dnoind(False,False) bottom->top: Propagator = GateRecurrent2dnoind(False,True) X: any signal/feature map to be filtered G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom) Note: 1. G1~G3 constitute the affinity; they can be a bunch of output maps coming from any CNN, with the input of any useful known information (e.g., RGB images) 2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficient condition for model stability (see paper) """ import torch from torch.autograd import Variable from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind Propagator = GateRecurrent2dnoind(True, False) X = Variable(torch.randn(1, 3, 10, 10)) G1 = Variable(torch.randn(1, 3, 10, 10)) G2 = Variable(torch.randn(1, 3, 10, 10)) G3 = Variable(torch.randn(1, 3, 10, 10)) sum_abs = G1.abs() + G2.abs() + G3.abs() mask_need_norm = sum_abs.ge(1) mask_need_norm = mask_need_norm.float() G1_norm = torch.div(G1, sum_abs) G2_norm = torch.div(G2, sum_abs) G3_norm = torch.div(G3, sum_abs) G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
def spnNet(self, out, y):
    """Propagate ``y`` along four scan directions and merge by element-wise max.

    ``out`` is read along dim 1 as twelve consecutive 32-channel gate maps
    (channels 0..383): three gate maps for each of the four propagation
    passes. Each gate triple is rescaled so that |G1| + |G2| + |G3| <= 1
    at every position before it is handed to the propagation layer.

    Args:
        out: Tensor holding the gate maps; sliced as ``out[:, a:b]``.
        y: Signal to propagate. It is moved to the GPU with ``.cuda()``.

    Returns:
        The element-wise maximum of the four propagated maps.
    """
    gate_width = 32  # channels per gate map; three gate maps per pass

    def _stable_gates(first_channel):
        """Slice one gate triple and rescale it to satisfy the stability bound."""
        gates = [
            out[:, first_channel + k * gate_width:
                first_channel + (k + 1) * gate_width]
            for k in range(3)
        ]
        sum_abs = gates[0].abs() + gates[1].abs() + gates[2].abs()
        # Gates are divided by sum_abs only where sum_abs >= 1. Clamping the
        # divisor to a minimum of 1 leaves smaller gates untouched and never
        # evaluates 0/0; the previous mask-and-blend form produced NaN gates
        # (0 * NaN) wherever all three gates were exactly zero.
        divisor = sum_abs.clamp(min=1)
        return [(gate / divisor).cuda() for gate in gates]

    def _propagate(horizontal, reverse, first_channel):
        """Run one directional pass using the gates at ``first_channel``."""
        g1, g2, g3 = _stable_gates(first_channel)
        propagator = GateRecurrent2dnoind(horizontal, reverse)
        return propagator.forward(y.cuda(), g1, g2, g3)

    # NOTE(review): these passes were previously labelled lr, rl, du, ud, but
    # per the SPN usage notes (False, True) is bottom->top, (False, False) is
    # top->bottom and (True, True) is right->left. All four directions are
    # still covered; the channel-to-flag pairing is kept exactly as it was so
    # that weights trained against this layout stay valid.
    pass_a = _propagate(True, False, 0 * gate_width)    # channels   0..95
    pass_b = _propagate(False, True, 3 * gate_width)    # channels  96..191
    pass_c = _propagate(False, False, 6 * gate_width)   # channels 192..287
    pass_d = _propagate(True, True, 9 * gate_width)     # channels 288..383

    # Same merge order as before: (a, b), then d, then c.
    yout = torch.max(pass_a, pass_b)
    yout = torch.max(yout, pass_d)
    yout = torch.max(yout, pass_c)
    return yout
right->left: Propagator = GateRecurrent2dnoind(True,True) top->bottom: Propagator = GateRecurrent2dnoind(False,False) bottom->top: Propagator = GateRecurrent2dnoind(False,True) X: any signal/feature map to be filtered G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom) Note: 1. G1~G3 constitute the affinity; they can be a bunch of output maps coming from any CNN, with the input of any useful known information (e.g., RGB images) 2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficient condition for model stability (see paper) """ import torch from torch.autograd import Variable from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind Propagator = GateRecurrent2dnoind(False, True) X = Variable(torch.randn(1, 3, 10, 10)) G1 = Variable(torch.randn(1, 3, 10, 10)) G2 = Variable(torch.randn(1, 3, 10, 10)) G3 = Variable(torch.randn(1, 3, 10, 10)) sum_abs = G1.abs() + G2.abs() + G3.abs() mask_need_norm = sum_abs.ge(1) mask_need_norm = mask_need_norm.float() G1_norm = torch.div(G1, sum_abs) G2_norm = torch.div(G2, sum_abs) G3_norm = torch.div(G3, sum_abs) G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
def __init__(self):
    """Initialise the module and create its propagation layer."""
    super(Refine, self).__init__()

    # Both GateRecurrent2dnoind flags are set; per the SPN usage notes the
    # (True, True) combination is the right->left scan.
    horizontal, reverse = True, True
    self.Propagator = GateRecurrent2dnoind(horizontal, reverse)
def spnNet(self, features, mask):
    '''
    Refine the segmentation mask with four directional SPN passes.

    features: [N, 180, 112, 112] -- twelve 15-channel gate maps, three per
        direction, in the order left->right, right->left, top->bottom,
        bottom->top
    mask: [N, 15, 112, 112] -- the mask to propagate

    Returns the element-wise maximum of the four propagated masks.
    '''
    group = 15  # channels per gate map

    def _directional_pass(horizontal, reverse, first_channel):
        """Normalise the gate triple at ``first_channel`` and propagate ``mask``."""
        propagator = GateRecurrent2dnoind(horizontal, reverse)
        triple = [
            features[:, first_channel + k * group:
                     first_channel + (k + 1) * group, :, :]
            for k in range(3)
        ]
        sum_abs = triple[0].abs() + triple[1].abs() + triple[2].abs()

        # 1.0 where the gates must be rescaled to keep |G1|+|G2|+|G3| <= 1,
        # 0.0 where they are already within the bound.
        need_norm = sum_abs.ge(1).float()
        keep_as_is = torch.add(-need_norm, 1)

        gates = []
        for raw in triple:
            gate = keep_as_is * raw + need_norm * torch.div(raw, sum_abs)
            # The division is 0/0 = NaN where all three gates are zero, and
            # the blend keeps that NaN; overwrite every NaN with 0.001.
            # NOTE(review): this patches the forward value only -- confirm
            # the backward pass through the division is finite at those
            # positions if this runs under autograd.
            gate[gate != gate] = 0.001
            gates.append(gate)

        return propagator.forward(mask, *gates)

    # left->right and right->left
    mask_l2r = _directional_pass(True, False, 0)
    mask_r2l = _directional_pass(True, True, 45)
    # top->bottom and bottom->top
    mask_t2b = _directional_pass(False, False, 90)
    mask_b2t = _directional_pass(False, True, 135)

    # max over the horizontal pair, the vertical pair, then both
    best_horizontal = torch.max(mask_l2r, mask_r2l)
    best_vertical = torch.max(mask_t2b, mask_b2t)
    return torch.max(best_horizontal, best_vertical)