Beispiel #1
0
 def forward(self, x, ignore_logdet=False):
     """Apply the residual map ``x + F(x)`` and return its log-det trace term.

     Args:
         x: input tensor, passed unchanged to ``self.res_branch``.
         ignore_logdet: if True, skip the trace computation and return a
             zero trace instead.

     Returns:
         Tuple ``(x + F(x), trace)``. When the trace is skipped (caller opt-out,
         or both ``numTraceSamples`` and ``numSeriesTerms`` are 0), ``trace``
         is a 0-dim zero tensor with ``x``'s device and dtype. Otherwise it is
         whatever ``power_series_matrix_logarithm_trace`` returns (presumably
         an estimate of log|det(I + dF/dx)| -- confirm against its definition).
     """
     Fx = self.res_branch(x)
     skip_logdet = (self.numTraceSamples == 0
                    and self.numSeriesTerms == 0) or ignore_logdet
     if skip_logdet:
         # Allocate the zero on x's device/dtype: torch.tensor(0.) is always a
         # CPU float32 tensor, which mismatches GPU or float64 inputs when the
         # per-block traces are later stacked or accumulated.
         trace = x.new_zeros(())
     else:
         trace = power_series_matrix_logarithm_trace(
             Fx, x, self.numSeriesTerms, self.numTraceSamples)
     # Residual connection: output is input plus the branch output.
     return x + Fx, trace
Beispiel #2
0
    def forward(self, x, ignore_logdet=False):
        """Bijective or injective block forward pass.

        Steps: optional squeeze (when ``self.stride == 2``), optional actnorm,
        then the residual map ``y = x + F(x)`` with ``F = self.bottleneck_block``.

        Args:
            x: input tensor.
            ignore_logdet: if True, skip the residual-branch trace computation.
                Note that the actnorm log-det (if actnorm is present) is still
                added to the returned value.

        Returns:
            Tuple ``(y, logdet)`` where ``logdet = trace + an_logdet``.
            ``an_logdet`` is the second output of ``self.actnorm``, or 0.0 when
            there is no actnorm. ``trace`` is a 0-dim zero tensor with the same
            device/dtype as the (possibly squeezed) input when skipped.
            Otherwise it is the result of ``power_series_matrix_logarithm_trace``
            (presumably an estimate of log|det(I + dF/dx)| -- confirm there).
        """
        if self.stride == 2:
            x = self.squeeze.forward(x)

        if self.actnorm is not None:
            x, an_logdet = self.actnorm(x)
        else:
            an_logdet = 0.0

        Fx = self.bottleneck_block(x)
        # Compute approximate trace for use in training
        if (self.numTraceSamples == 0 and self.numSeriesTerms == 0) or ignore_logdet:
            # Allocate the zero on x's device/dtype: torch.tensor(0.) is always
            # a CPU float32 tensor, which mismatches GPU or float64 inputs.
            trace = x.new_zeros(())
        else:
            trace = power_series_matrix_logarithm_trace(Fx, x, self.numSeriesTerms, self.numTraceSamples)

        # add residual to output
        y = Fx + x
        return y, trace + an_logdet