def feature_extract(self, z, f):
    """Run ``f`` on ``z`` and derive an affine (scale, shift) pair from its output.

    Args:
        z: Input passed straight to ``f``.
        f: Callable (e.g. a network module) whose output is split into a
            shift half and a raw-scale half by
            ``thops.split_feature(..., "cross")``. Presumably "cross" means
            an interleaved channel split; confirm against ``thops``.

    Returns:
        Tuple ``(scale, shift)``. Note that this order is the reverse of the
        ``(shift, scale)`` order produced by the split.
        ``scale = sigmoid(raw_scale + 2) + self.affine_eps``.
    """
    # TODO: The original OpenAI Glow code uses gradient checkpointing to cut
    # peak memory. Try replacing the call below with
    # torch.utils.checkpoint.checkpoint(f, z) and measure memory consumption.
    hidden = f(z)
    shift, raw_scale = thops.split_feature(hidden, "cross")

    # sigmoid(x + 2) makes the scale start close to sigmoid(2) ~= 0.88 when
    # raw_scale ~= 0. affine_eps is added afterwards; presumably it is a small
    # positive value that keeps the scale away from zero (verify at
    # construction).
    # TODO: test tanh instead of sigmoid, as in RealNVP.
    scale = torch.sigmoid(raw_scale + 2.0) + self.affine_eps
    return scale, shift
def split2d_prior(self, z, ft):
    """Predict prior parameters from ``z``, optionally conditioned on ``ft``.

    Args:
        z: Input tensor fed to ``self.conv``.
        ft: Optional conditioning tensor. When it is not ``None``, it is
            concatenated onto ``z`` along dim 1 (presumably the channel
            axis; confirm the tensor layout with callers) before the conv.

    Returns:
        Whatever ``thops.split_feature(self.conv(...), "cross")`` returns.
        Presumably this is a pair of prior parameters, such as mean and log
        scale; verify against callers.
    """
    conv_input = z if ft is None else torch.cat([z, ft], dim=1)
    return thops.split_feature(self.conv(conv_input), "cross")