def forward(self, x, ignore_logdet=False):
    """Invertible residual step: return (x + F(x), log-det trace estimate).

    The trace term approximates log|det(I + J_F)| via a stochastic power
    series; it is a scalar zero tensor when estimation is disabled or the
    caller opts out with ``ignore_logdet``.
    """
    residual = self.res_branch(x)

    # Skip the stochastic estimator when it is configured off entirely,
    # or when the caller explicitly does not need the log-det term.
    estimator_disabled = (self.numTraceSamples == 0
                          and self.numSeriesTerms == 0)
    if estimator_disabled or ignore_logdet:
        logdet_trace = torch.tensor(0.)
    else:
        logdet_trace = power_series_matrix_logarithm_trace(
            residual, x, self.numSeriesTerms, self.numTraceSamples)

    # Residual connection: y = x + F(x).
    return x + residual, logdet_trace
def forward(self, x, ignore_logdet=False):
    """Bijective or injective block forward pass.

    Applies an optional invertible squeeze (when stride is 2), optional
    actnorm, then a residual bottleneck. Returns the residual output and
    the accumulated log-det contribution (actnorm log-det plus an
    approximate trace term; the trace is zero when estimation is
    disabled or ``ignore_logdet`` is set).
    """
    # Optional downsampling via the invertible squeeze operation.
    if self.stride == 2:
        x = self.squeeze.forward(x)

    # Optional activation normalization; it reports its own log-det.
    if self.actnorm is None:
        an_logdet = 0.0
    else:
        x, an_logdet = self.actnorm(x)

    Fx = self.bottleneck_block(x)

    # Approximate trace of log(I + J_F), used during training only.
    estimator_disabled = (self.numTraceSamples == 0
                          and self.numSeriesTerms == 0)
    if estimator_disabled or ignore_logdet:
        trace = torch.tensor(0.)
    else:
        trace = power_series_matrix_logarithm_trace(
            Fx, x, self.numSeriesTerms, self.numTraceSamples)

    # Residual connection: y = F(x) + x.
    y = Fx + x
    return y, trace + an_logdet