Exemplo n.º 1
0
 def inverse(self, out_bij):
     """Invert the i-RevNet feature extractor.

     Args:
         out_bij: the merged output of the invertible stack (the value
             ``forward`` returns as ``out_bij``).

     Returns:
         The reconstructed network input.

     ``init_psi`` is undone only when ``self.init_ds != 0``. This matches
     ``forward``, which applies ``init_psi`` only in that case. Calling it
     unconditionally would invert a transform that was never applied.
     """
     out = split(out_bij)
     # Undo the blocks in the reverse of the order forward applied them.
     for block in reversed(self.stack):
         out = block.inverse(out)
     x = merge(out[0], out[1])
     if self.init_ds != 0:
         x = self.init_psi.inverse(x)
     return x
Exemplo n.º 2
0
 def forward(self, x):
     """Run the i-RevNet classifier forward pass.

     If ``self.init_ds`` is non-zero, ``x`` first goes through
     ``init_psi``. It is then split along dim 1 at ``self.in_ch // 2``,
     pushed through every block of ``self.stack`` and merged back into
     ``out_bij``. A bn1 -> ReLU -> avg-pool -> linear head produces the
     logits.

     Returns:
         ``(logits, out_bij)``.
     """
     half = self.in_ch // 2
     if self.init_ds != 0:
         x = self.init_psi.forward(x)
     pair = (x[:, :half, :, :], x[:, half:, :, :])
     for layer in self.stack:
         pair = layer.forward(pair)
     bijective = merge(pair[0], pair[1])
     pooled = F.avg_pool2d(F.relu(self.bn1(bijective)), self.ds)
     flat = pooled.view(pooled.size(0), -1)
     logits = self.linear(flat)
     return logits, bijective
Exemplo n.º 3
0
 def inverse(self, x):
     """Invert one bijective or injective block.

     Args:
         x: a pair ``(x2, y1)``.

     Returns:
         The reconstructed pair ``(x1, x2)``. Steps:
         - When ``stride == 2``, ``psi.inverse`` is applied to ``x2``.
         - ``x1`` is recovered as ``-bottleneck_block(x2) + y1``.
         - When ``stride == 2``, ``psi.inverse`` is also applied to ``x1``.
         - When ``pad != 0`` and ``stride == 1``, the pair is merged, passed
           through ``inj_pad.inverse`` and split again.
     """
     second_half, coupled = x[0], x[1]
     downsampled = self.stride == 2
     if downsampled:
         second_half = self.psi.inverse(second_half)
     first_half = -self.bottleneck_block(second_half) + coupled
     if downsampled:
         first_half = self.psi.inverse(first_half)
     if self.pad == 0 or self.stride != 1:
         return (first_half, second_half)
     # Injective case: undo the padding step on the merged pair.
     first_half, second_half = split(self.inj_pad.inverse(merge(first_half, second_half)))
     return (first_half, second_half)
Exemplo n.º 4
0
    def forward(self, x):
        """Run ``x`` through every block in ``self.block_list``.

        ``x`` is indexed on four axes, so it must be at least 4-D. It is
        split along dim 1 into the first ``self.inchannel // 2`` channels and
        the remainder. Each block maps that pair to a new pair, and the final
        pair is joined with ``merge``.

        Returns:
            The merged output of the last block.
        """
        n = self.inchannel // 2
        out = (x[:, :n, :, :], x[:, n:, :, :])
        for block in self.block_list:
            out = block.forward(out)
        return merge(out[0], out[1])
Exemplo n.º 5
0
 def inverse(self, x):
     """Invert the pass over ``self.block_list``.

     Steps:
     - Split ``x`` into a pair with ``split``.
     - Apply each block's ``inverse`` from last to first.
     - Join the resulting pair with ``merge``.

     Returns:
         The merged reconstruction.
     """
     out = split(x)
     # The original took the count from block_list but indexed self.stack.
     # That raises, or inverts the wrong blocks, when the two lists differ.
     # Take both the order and the blocks from block_list.
     for block in reversed(self.block_list):
         out = block.inverse(out)
     return merge(out[0], out[1])