def __init__(self, block, keep_input, enabled=True):
    """Wrap ``block`` in an additive coupling behind an invertible wrapper.

    Args:
        block: Module handed to ``memcnn.AdditiveCoupling`` as its only
            argument.
        keep_input: Passed to the wrapper as both ``keep_input`` and
            ``keep_input_inverse``.
        enabled: The wrapper receives ``disable=not enabled``.
    """
    super().__init__()
    coupling = memcnn.AdditiveCoupling(block)
    # A single flag drives both directions: forward and inverse share the
    # same keep-input setting.
    wrapper_options = {
        "keep_input": keep_input,
        "keep_input_inverse": keep_input,
        "disable": not enabled,
    }
    self.invertible_block = memcnn.InvertibleModuleWrapper(fn=coupling, **wrapper_options)
def __init__(self, channel):
    """Build an additive coupling out of two weight-normalised linear branches.

    Args:
        channel: Used as both ``in_features`` and ``out_features`` of the
            ``nn.Linear`` layer in each branch.
    """
    super().__init__()

    def make_branch():
        # One branch = weight-normalised square linear layer + in-place LeakyReLU.
        linear = nn.utils.weight_norm(nn.Linear(channel, channel))
        return nn.Sequential(linear, nn.LeakyReLU(inplace=True))

    # Fm is constructed before Gm so parameter initialisation draws from the
    # RNG in the same order as before.
    fm_branch = make_branch()
    gm_branch = make_branch()
    coupling = memcnn.AdditiveCoupling(Fm=fm_branch, Gm=gm_branch)
    self.model = memcnn.InvertibleModuleWrapper(fn=coupling)
def __init__(self, dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), feature_factor=1, size=32, noise=0.3):
    """Assemble the embedding, the invertible block stack and the dense weights.

    ``dim``, ``out_dim`` and ``dim_mults`` are accepted but not read anywhere
    in this constructor.

    Args:
        feature_factor: Multiplier applied to the feature count returned by
            ``self._get_parameters()``; the product is truncated with ``int``.
        size: Stored as ``self.size`` before ``_get_parameters`` is called.
        noise: Forwarded as the third positional argument of every
            ``BasicBlock``.
    """
    super(Unet, self).__init__()
    # Assigned before _get_parameters() runs (presumably that helper reads
    # it -- confirm against its definition).
    self.size = size
    depth, base_features = self._get_parameters()
    features = int(base_features * feature_factor)
    self.embed = Embedding(features)
    self.feature_factor = feature_factor

    half_features = features // 2

    def make_coupling(index):
        # Each coupling gets two separately constructed blocks, first-branch
        # block built first. The last argument is index + 1 (the original
        # helper called this value "dilation").
        first = BasicBlock(half_features, True, noise, index + 1)
        second = BasicBlock(half_features, True, noise, index + 1)
        return memcnn.InvertibleModuleWrapper(memcnn.AdditiveCoupling(first, second))

    couplings = []
    for index in range(depth):
        couplings.append(make_coupling(index))
    self.blocks = torch.nn.ModuleList(couplings)

    # One (features x features) matrix per block plus two extra, each
    # initialised orthogonally in place.
    dense_count = 2 + depth
    self.dense = torch.nn.Parameter(torch.empty(dense_count, features, features))
    for index in range(dense_count):
        torch.nn.init.orthogonal_(self.dense[index])

    self.features3 = features - 3
# Reference run: apply the plain operation to X in evaluation mode.
model_normal = ExampleOperation(channels=10)
model_normal.eval()
Y = model_normal(X)

# Make the operation invertible with an additive coupling; each branch is
# built with half of the 10 channels.
fm_branch = ExampleOperation(channels=10 // 2)
gm_branch = ExampleOperation(channels=10 // 2)
invertible_module = memcnn.AdditiveCoupling(Fm=fm_branch, Gm=gm_branch)

# Sanity check: the coupling must expose a working inverse.
input_shape = X.shape
assert memcnn.is_invertible_module(invertible_module, test_input_shape=input_shape)

# Wrap the coupling in an InvertibleModuleWrapper to benefit from memory
# savings during training.
invertible_module_wrapper = memcnn.InvertibleModuleWrapper(
    fn=invertible_module,
    keep_input=True,
    keep_input_inverse=True,
)

# Modules start in training mode; switch to evaluation mode. This is required
# when feeding inputs with requires_grad=False (inference only).
invertible_module_wrapper.eval()

# The wrapped module must itself be a valid invertible module.
assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=input_shape)

# Forward pass through the wrapper.
Y2 = invertible_module_wrapper.forward(X)

# Applying the wrapper's inverse to Y2 yields X2, an approximation of the input X.
X2 = invertible_module_wrapper.inverse(Y2)