def inverse(self, out_bij):
    """Map a bijective representation back to the network's input space.

    Undoes ``forward``: the representation is split into its two halves,
    every block in ``self.stack`` is inverted from the last block to the
    first, the halves are re-merged, and the initial ``init_psi`` transform
    is undone when ``init_ds`` is non-zero.

    Args:
        out_bij: the bijective output produced by ``forward``.

    Returns:
        The reconstructed input.
    """
    pair = split(out_bij)
    # Blocks must be undone in the opposite order they were applied.
    for block in reversed(self.stack):
        pair = block.inverse(pair)
    merged = merge(pair[0], pair[1])
    if self.init_ds == 0:
        # No initial transform was applied in forward, nothing to undo.
        return merged
    return self.init_psi.inverse(merged)
def forward(self, x, *, return_logits=False):
    """i-RevNet forward pass.

    Optionally applies the initial ``init_psi`` transform, splits the result
    into two halves along dim 1, runs them through every block in
    ``self.stack``, and merges the halves into the bijective representation.
    The classifier head (``bn1`` -> relu -> avg-pool -> flatten -> ``linear``)
    is then evaluated on that representation.

    Args:
        x: input batch. It is sliced as ``x[:, :n, :, :]``, so presumably it
            is a 4-D (batch, channels, height, width) tensor — TODO confirm.
        return_logits: keyword-only. When True, also return the classifier
            head output. Defaults to False so existing callers still receive
            only the bijective representation.

    Returns:
        ``out_bij`` by default, or ``(logits, out_bij)`` when
        ``return_logits`` is True.
    """
    n = self.in_ch // 2
    if self.init_ds != 0:
        x = self.init_psi.forward(x)
    # Split along dim 1 into the two halves consumed by the reversible blocks.
    out = (x[:, :n, :, :], x[:, n:, :, :])
    for block in self.stack:
        out = block.forward(out)
    out_bij = merge(out[0], out[1])

    # Classifier head. It is evaluated on every call, even when the logits
    # are not requested. NOTE(review): bn1 is presumably a BatchNorm layer,
    # whose running statistics update in training mode, so skipping the
    # head would change model state — confirm before making this lazy.
    logits = F.relu(self.bn1(out_bij))
    logits = F.avg_pool2d(logits, self.ds)
    logits = logits.view(logits.size(0), -1)
    logits = self.linear(logits)
    if return_logits:
        return logits, out_bij
    return out_bij
def forward(self, x):
    """Apply the bijective (or injective) reversible block to a pair.

    When ``pad`` is non-zero and ``stride`` is 1, the two halves are merged,
    passed through ``inj_pad`` and split again before the block runs. The
    bottleneck is applied to the second half *before* any downsampling.
    When ``stride`` is 2, both halves are then passed through ``psi``.

    Args:
        x: a pair ``(x1, x2)``, indexed as ``x[0]`` and ``x[1]``.

    Returns:
        The tuple ``(x2, bottleneck(x2) + x1)``, where ``x1``/``x2`` have
        been transformed by ``psi`` when ``stride`` is 2.
    """
    if self.pad != 0 and self.stride == 1:
        # Injective case: pad the merged halves, then re-split them.
        x1, x2 = split(self.inj_pad.forward(merge(x[0], x[1])))
    else:
        x1, x2 = x[0], x[1]
    residual = self.bottleneck_block(x2)
    if self.stride == 2:
        x1, x2 = self.psi.forward(x1), self.psi.forward(x2)
    return (x2, residual + x1)
def inverse(self, x):
    """Invert the bijective (or injective) reversible block.

    Given the block output ``(x2, y1)``, the first half is recovered as
    ``y1 - bottleneck(x2)``, with the ``psi`` downsampling undone first when
    ``stride`` is 2. When ``pad`` is non-zero and ``stride`` is 1, the
    injective padding is removed by merging, applying ``inj_pad.inverse``
    and re-splitting.

    Args:
        x: the pair ``(x2, y1)`` produced by ``forward``.

    Returns:
        The reconstructed input pair ``(x1, x2)`` as a tuple.
    """
    x2, y1 = x[0], x[1]
    downsampled = self.stride == 2
    if downsampled:
        x2 = self.psi.inverse(x2)
    # Undo the additive coupling: y1 = bottleneck(x2) + x1.
    x1 = y1 - self.bottleneck_block(x2)
    if downsampled:
        x1 = self.psi.inverse(x1)
    if self.pad != 0 and self.stride == 1:
        # Injective case: strip the padding added in forward.
        x1, x2 = split(self.inj_pad.inverse(merge(x1, x2)))
    return (x1, x2)