コード例 #1
0
        def __init__(self, anchor_mode='on'):
            """Create the layers, give them constant parameters and set up anchors.

            Args:
                anchor_mode: ``'on'`` stores two ``scoped_anchor`` objects built
                    with keyword parameters, ``'no_param'`` stores two built
                    without parameters, and any other value stores two
                    ``suppress()`` objects instead.
            """
            super(Net, self).__init__()

            self.conv = nn.Conv2d(6, 9, 3)
            self.conv2 = nn.Conv2d(9, 12, 3)
            self.linear = nn.Linear(28, 20)
            self.linear2 = nn.Linear(20, 15)
            self.gn = nn.GroupNorm(3, 12)  # to check multiple nodes
            self.linear3 = nn.Linear(15, 10)

            # to check output values (not reduce node number):
            # every conv/linear weight and bias is filled with the same constant.
            # The GroupNorm layer is intentionally left with its default values.
            constant_layers = (
                self.conv,
                self.conv2,
                self.linear,
                self.linear2,
                self.linear3,
            )
            for layer in constant_layers:
                nn.init.constant_(layer.weight, 0.1)
                nn.init.constant_(layer.bias, 0.1)

            # Each attribute gets its own object; the two are never shared.
            if anchor_mode == 'on':
                self.anchor1 = scoped_anchor(aaa='a', bbb=['b', 'c'])
                self.anchor2 = scoped_anchor(ccc=[1, 2])
            elif anchor_mode == 'no_param':
                self.anchor1 = scoped_anchor()
                self.anchor2 = scoped_anchor()
            else:
                # Any other mode falls back to suppress() placeholders.
                self.anchor1 = suppress()
                self.anchor2 = suppress()
コード例 #2
0
 def set_anchor(self):
     """(Re)create ``self.anchor1`` and ``self.anchor2`` from ``self.anchor_mode``.

     ``'on'`` builds two ``scoped_anchor`` objects with keyword parameters,
     ``'no_param'`` builds them without parameters, and every other mode
     stores two separate ``suppress()`` objects.
     """
     # required to setup in forwarding phase
     mode = self.anchor_mode

     if mode == 'on':
         self.anchor1 = scoped_anchor(aaa='a', bbb=['b', 'c'])
         self.anchor2 = scoped_anchor(ccc=[1, 2])
         return

     if mode == 'no_param':
         self.anchor1 = scoped_anchor()
         self.anchor2 = scoped_anchor()
         return

     # Fallback for any other mode: suppress() placeholders.
     self.anchor1 = suppress()
     self.anchor2 = suppress()
コード例 #3
0
 def forward(self, *xs):
     """Concatenate the inputs, transpose, and return the rows via ``self.id``.

     The whole computation runs inside a parameterless ``scoped_anchor()``
     context. ``self.id`` is defined outside this block and is applied to
     the inputs, to the intermediate row sequence and to the final result.

     Args:
         *xs: Tensors that ``torch.cat`` can join along dimension 1.

     Returns:
         Whatever ``self.id`` returns for the final sequence of size-1 chunks.
     """
     with scoped_anchor():
         inputs = self.id(xs)
         joined = torch.cat(inputs, 1)
         transposed = joined.t()
         rows = transposed.split(1)
         rows = self.id(rows)  # to check internal dummy anchor
         # Append the first two rows again so they occur twice in the sequence.
         rows += (rows[0], rows[1])
         stacked = torch.cat(rows, 0)
         chunks = stacked.split(1)
         return self.id(chunks)
コード例 #4
0
 def forward(self, x):
     """Apply ``self.conv`` and then ``self.linear`` to ``x``.

     Only the linear step runs inside the ``scoped_anchor(aaa='a')`` context;
     the convolution is deliberately executed outside of it.

     Args:
         x: Input passed to ``self.conv``.

     Returns:
         The output of ``self.linear`` applied to the convolution result.
     """
     conv_out = self.conv(x)
     with scoped_anchor(aaa='a'):
         linear_out = self.linear(conv_out)
     return linear_out