def inverse(self, out_bij):
    """Invert the iRevNet forward pass, reconstructing the network input.

    Args:
        out_bij: the merged bijective representation, i.e. the second
            element of the pair returned by ``forward``.

    Returns:
        The reconstructed input ``x``.
    """
    out = split(out_bij)
    # Undo the stacked blocks in the reverse order forward() applied them.
    for block in reversed(self.stack):
        out = block.inverse(out)
    out = merge(out[0], out[1])
    # forward() only applies init_psi when init_ds != 0, so only undo it then.
    if self.init_ds != 0:
        return self.init_psi.inverse(out)
    return out
def forward(self, x):
    """Run the iRevNet forward pass.

    If ``init_ds`` is non-zero, the input is first transformed by
    ``self.init_psi``. It is then split along dim 1 into two halves
    (``self.in_ch // 2`` channels in the first) and passed through every
    block in ``self.stack``.

    Args:
        x: input tensor; indexed with four slices, so assumed 4-D
           (presumably (N, C, H, W) — TODO confirm).

    Returns:
        A pair ``(out, out_bij)``. ``out`` is the result of
        bn1 -> relu -> avg_pool2d(self.ds) -> flatten -> self.linear.
        ``out_bij`` is the merged output of the block stack, which
        ``inverse`` accepts.
    """
    half = self.in_ch // 2
    if self.init_ds != 0:
        x = self.init_psi.forward(x)
    pair = (x[:, :half, :, :], x[:, half:, :, :])
    for blk in self.stack:
        pair = blk.forward(pair)
    out_bij = merge(pair[0], pair[1])

    # Classification head on top of the bijective representation.
    activated = F.relu(self.bn1(out_bij))
    pooled = F.avg_pool2d(activated, self.ds)
    flat = pooled.view(pooled.size(0), -1)
    return self.linear(flat), out_bij
def inverse(self, x):
    """Invert a bijective or injective block.

    Args:
        x: a pair ``(x2, y1)`` — presumably the output of this block's
           forward pass (TODO confirm against forward).

    Returns:
        The reconstructed pair ``(x1, x2)``.
        - When ``stride == 2``, both halves are passed through
          ``self.psi.inverse``.
        - When ``pad != 0`` and ``stride == 1``, the halves are merged,
          passed through ``self.inj_pad.inverse`` and split again.
    """
    x2, y1 = x[0], x[1]
    downsampled = self.stride == 2
    if downsampled:
        x2 = self.psi.inverse(x2)
    # Remove the residual branch: x1 = y1 - bottleneck_block(x2).
    x1 = -self.bottleneck_block(x2) + y1
    if downsampled:
        x1 = self.psi.inverse(x1)
    if self.pad == 0 or self.stride != 1:
        return (x1, x2)
    # Padded, non-strided case: undo inj_pad on the merged tensor, then re-split.
    first, second = split(self.inj_pad.inverse(merge(x1, x2)))
    return (first, second)
def forward(self, x):
    """Run the input through every block in ``self.block_list``.

    Args:
        x: input tensor, split along dim 1 into two halves
           (``self.inchannel // 2`` channels in the first). It is indexed
           with four slices, so assumed 4-D (presumably (N, C, H, W) —
           TODO confirm).

    Returns:
        The merged output of the last block, ``merge(out[0], out[1])``.
        The previous version computed this value but fell through and
        returned ``None``.
    """
    n = self.inchannel // 2
    out = (x[:, :n, :, :], x[:, n:, :, :])
    for block in self.block_list:
        out = block.forward(out)
    return merge(out[0], out[1])
def inverse(self, x):
    """Invert ``forward``, recovering its input from the merged output ``x``.

    Args:
        x: merged tensor, as returned by ``forward``.

    Returns:
        The merged tensor obtained by undoing every block in
        ``self.block_list``, last block first.
    """
    out = split(x)
    # Walk the same list forward() iterates (self.block_list), in reverse.
    for block in reversed(self.block_list):
        out = block.inverse(out)
    return merge(out[0], out[1])