def inverse(self, out_bij):
    """Invert the reversible block stack and recover the network input.

    ``out_bij`` is split into two streams with ``split``, every block in
    ``self.stack`` is inverted in reverse order, and the two streams are
    merged back together with ``merge``.

    The irevnet ``forward`` only applies ``self.init_psi`` when
    ``self.init_ds != 0``, so the initial step is undone under the same
    condition here; otherwise the merged result is returned as is.

    Args:
        out_bij: merged output of the block stack (presumably the tensor
            obtained right after the last block in ``forward`` -- TODO
            confirm with callers).

    Returns:
        The reconstructed input of the block stack / network.
    """
    out = split(out_bij)
    # Blocks must be undone last-to-first.
    for block in reversed(self.stack):
        out = block.inverse(out)
    out = merge(out[0], out[1])
    # Mirror the guard used in forward: init_psi is skipped when init_ds == 0.
    if self.init_ds != 0:
        return self.init_psi.inverse(out)
    return out
def forward(self, x):
    """Run the irevnet block stack followed by the convolutional head.

    When ``self.init_ds`` is non-zero the input first goes through
    ``self.init_psi``.  The tensor is then cut in two along axis 1 at
    ``self.in_ch // 2``, the pair is passed through every block of
    ``self.stack``, merged again and fed through
    ``conv1 -> bn1 -> relu -> conv2 -> tanh``.

    Args:
        x: 4-D input (indexed as ``x[:, :, :, :]``).

    Returns:
        The output of ``self.tanh`` applied to the second convolution.
    """
    half = self.in_ch // 2
    if self.init_ds != 0:
        x = self.init_psi.forward(x)

    # Two streams: the first `half` slices of axis 1 and the remainder.
    streams = (x[:, :half, :, :], x[:, half:, :, :])
    for layer in self.stack:
        streams = layer.forward(streams)
    merged = merge(streams[0], streams[1])

    hidden = self.conv1(merged)
    hidden = self.bn1(hidden)
    hidden = F.relu(hidden)
    return self.tanh(self.conv2(hidden))
def forward_features(self, x):
    """Run the irevnet block stack and return flattened pooled features.

    The input optionally goes through ``self.init_psi`` (only when
    ``self.init_ds`` is non-zero), is cut in two along axis 1 at
    ``self.in_ch // 2`` and passed through every block of ``self.stack``.
    The merged result is batch-normalised with ``self.bn1``, passed through
    ReLU, average-pooled with kernel size ``self.ds`` and flattened to one
    row per leading-axis entry.

    Args:
        x: 4-D input (indexed as ``x[:, :, :, :]``).

    Returns:
        A 2-D view of shape ``(x.size(0) after pooling, -1)``.
    """
    half = self.in_ch // 2
    if self.init_ds != 0:
        x = self.init_psi.forward(x)

    # Two streams: the first `half` slices of axis 1 and the remainder.
    streams = (x[:, :half, :, :], x[:, half:, :, :])
    for layer in self.stack:
        streams = layer.forward(streams)
    merged = merge(streams[0], streams[1])

    activated = F.relu(self.bn1(merged))
    pooled = F.avg_pool2d(activated, self.ds)
    return pooled.view(pooled.size(0), -1)
def forward(self, x):
    """Forward pass of a bijective (or injective, when padded) block.

    Args:
        x: pair ``(x1, x2)`` of input streams.

    Returns:
        The pair ``(x2, y1)`` where ``y1 = bottleneck_block(x2) + x1``.
        With ``stride == 2`` both ``x1`` and ``x2`` go through
        ``self.psi.forward`` after the bottleneck has been evaluated on the
        untransformed ``x2``.
    """
    if self.pad != 0 and self.stride == 1:
        # Injective case: pad the merged tensor, then split it again.
        padded = self.inj_pad.forward(merge(x[0], x[1]))
        skip_in, branch_in = split(padded)
    else:
        skip_in = x[0]
        branch_in = x[1]

    branch_out = self.bottleneck_block(branch_in)

    if self.stride == 2:
        skip_in = self.psi.forward(skip_in)
        branch_in = self.psi.forward(branch_in)

    return (branch_in, branch_out + skip_in)
def forward(self, x):
    """Run the block list on a two-way split of ``x`` and add ``x`` back.

    The input is cut in two along axis 1 at ``self.inchannel // 2``, the pair
    is passed through every block of ``self.block_list`` in order, the result
    is merged and the original input is added to it.

    Args:
        x: 4-D input (indexed as ``x[:, :, :, :]``).

    Returns:
        ``merge(...) + x``.
    """
    half = self.inchannel // 2
    streams = (x[:, :half, :, :], x[:, half:, :, :])
    for block in self.block_list:
        streams = block.forward(streams)
    # Skip connection around the whole block list.
    return merge(streams[0], streams[1]) + x
def inverse(self, x):
    """Inverse pass of a bijective (or injective, when padded) block.

    Args:
        x: pair ``(x2, y1)``, in the layout returned by this block's forward.

    Returns:
        The pair ``(x1, x2)`` with ``x1 = y1 - bottleneck_block(x2)``.  With
        ``stride == 2`` ``self.psi.inverse`` is applied to ``x2`` before the
        bottleneck and to ``x1`` after it.  When padding is active
        (``pad != 0`` and ``stride == 1``) the merged pair additionally goes
        through ``self.inj_pad.inverse`` and is split again.
    """
    branch_in, combined = x[0], x[1]

    if self.stride == 2:
        branch_in = self.psi.inverse(branch_in)

    negated_branch = -self.bottleneck_block(branch_in)
    skip_in = negated_branch + combined

    if self.stride == 2:
        skip_in = self.psi.inverse(skip_in)

    if self.pad == 0 or self.stride != 1:
        return (skip_in, branch_in)

    # Injective case: strip the padding from the merged tensor.
    unpadded = self.inj_pad.inverse(merge(skip_in, branch_in))
    first, second = split(unpadded)
    return (first, second)
def inverse(self, x):
    """Invert the blocks of ``self.block_list`` in reverse order.

    ``x`` is split into two streams with ``split``, each block's ``inverse``
    is applied starting from the last block, and the streams are merged back
    with ``merge``.

    The original indexed ``self.stack`` while taking the loop bound from
    ``self.block_list``; the blocks are now read from ``self.block_list``
    throughout.

    NOTE(review): the ``forward`` that iterates ``self.block_list`` also adds
    its input to the merged output; that addition is not undone here --
    confirm what callers expect to pass in.

    Args:
        x: merged tensor to be split and pushed back through the blocks.

    Returns:
        The merged result after all blocks have been inverted.
    """
    out = split(x)
    # Blocks must be undone last-to-first.
    for block in reversed(self.block_list):
        out = block.inverse(out)
    return merge(out[0], out[1])