def __init__(self, anchor_mode='on'):
    """Build a small conv/linear network whose conv/linear parameters are 0.1.

    Args:
        anchor_mode: ``'on'`` creates two ``scoped_anchor`` objects with
            keyword parameters; ``'no_param'`` creates two parameterless
            ``scoped_anchor`` objects; any other value uses ``suppress()``
            placeholders instead of anchors.
    """
    super(Net, self).__init__()
    # Store the mode so methods that read ``self.anchor_mode`` (e.g. a
    # ``set_anchor`` that rebuilds the anchors) do not hit AttributeError.
    self.anchor_mode = anchor_mode
    self.conv = nn.Conv2d(6, 9, 3)
    self.conv2 = nn.Conv2d(9, 12, 3)
    self.linear = nn.Linear(28, 20)
    self.linear2 = nn.Linear(20, 15)
    self.gn = nn.GroupNorm(3, 12)  # to check multiple nodes
    self.linear3 = nn.Linear(15, 10)  # to check output values (not reduce node number)
    # Fill every conv/linear weight and bias with the same constant so the
    # output values are predictable. ``self.gn`` is intentionally left at
    # its default initialization, as in the original code.
    for module in (self.conv, self.conv2, self.linear, self.linear2, self.linear3):
        nn.init.constant_(module.weight, 0.1)
        nn.init.constant_(module.bias, 0.1)
    if anchor_mode == 'on':
        self.anchor1 = scoped_anchor(aaa='a', bbb=['b', 'c'])
        self.anchor2 = scoped_anchor(ccc=[1, 2])
    elif anchor_mode == 'no_param':
        self.anchor1 = scoped_anchor()
        self.anchor2 = scoped_anchor()
    else:
        # ``suppress()`` with no exception types acts as a no-op context manager.
        self.anchor1 = suppress()
        self.anchor2 = suppress()
def set_anchor(self):
    """(Re)create ``self.anchor1`` and ``self.anchor2`` from ``self.anchor_mode``.

    Original note: this is required to set up in the forwarding phase.

    ``'on'`` yields two parameterized ``scoped_anchor`` objects,
    ``'no_param'`` yields two parameterless ones, and any other mode yields
    ``suppress()`` placeholders.
    """
    mode = self.anchor_mode
    if mode == 'on':
        anchors = (scoped_anchor(aaa='a', bbb=['b', 'c']), scoped_anchor(ccc=[1, 2]))
    elif mode == 'no_param':
        anchors = (scoped_anchor(), scoped_anchor())
    else:
        anchors = (suppress(), suppress())
    self.anchor1, self.anchor2 = anchors
def forward(self, *xs):
    """Concatenate, transpose and split the inputs, then re-split after
    appending the first two chunks again.

    Every stage is routed through ``self.id`` inside a parameterless
    ``scoped_anchor`` scope. NOTE(review): ``self.id`` is presumably an
    identity-like module; confirm against its definition.
    """
    with scoped_anchor():
        inputs = self.id(xs)
        transposed = torch.cat(inputs, 1).t()
        chunks = self.id(transposed.split(1))
        # Original note: "to check internal dummy anchor". Kept as an
        # augmented assignment so the result of ``self.id`` is extended
        # exactly as before.
        chunks += (chunks[0], chunks[1])
        stacked = torch.cat(chunks, 0)
        return self.id(stacked.split(1))
def forward(self, x):
    """Run ``self.conv`` and then ``self.linear``.

    Only the linear layer is executed inside a ``scoped_anchor(aaa='a')``
    scope; the convolution runs outside it.
    """
    out = self.conv(x)
    with scoped_anchor(aaa='a'):
        out = self.linear(out)
    return out