Example #1
import memcnn
import torch.nn as nn

# class name assumed for illustration; the source snippet starts at __init__
class InvertibleBlock(nn.Module):
    def __init__(self, block, keep_input, enabled=True):
        super().__init__()
        # wrap `block` in an additive coupling (memcnn defaults Gm to a
        # copy of Fm) and make it memory-efficient; enabled=False skips
        # the invertibility machinery entirely
        self.invertible_block = memcnn.InvertibleModuleWrapper(
            fn=memcnn.AdditiveCoupling(block),
            keep_input=keep_input,
            keep_input_inverse=keep_input,
            disable=not enabled,
        )
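
A minimal usage sketch (the class name InvertibleBlock and the convolutional sub-network below are assumptions, not part of the source): the coupling splits its input along the channel dimension, so a sub-network acting on C/2 channels yields a block acting on C channels.

    import torch
    import torch.nn as nn

    block = nn.Sequential(nn.Conv2d(5, 5, kernel_size=3, padding=1), nn.ReLU())
    layer = InvertibleBlock(block, keep_input=True)
    layer.eval()  # needed to feed requires_grad=False tensors (inference only)

    x = torch.rand(2, 10, 8, 8)             # 10 channels -> two halves of 5
    y = layer.invertible_block(x)           # invertible forward pass
    x2 = layer.invertible_block.inverse(y)  # reconstruct the input from y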
Example #2
import memcnn
import torch.nn as nn

# class name assumed for illustration; the source snippet starts at __init__
class InvertibleLinearBlock(nn.Module):
    def __init__(self, channel):
        super().__init__()

        # additive coupling: inputs are split into two halves along dim 1,
        # so Fm and Gm each map `channel` features to `channel` features
        invertible_module = memcnn.AdditiveCoupling(
            Fm=nn.Sequential(
                nn.utils.weight_norm(nn.Linear(channel, channel)),
                nn.LeakyReLU(inplace=True),
            ),
            Gm=nn.Sequential(
                nn.utils.weight_norm(nn.Linear(channel, channel)),
                nn.LeakyReLU(inplace=True),
            ),
        )
        self.model = memcnn.InvertibleModuleWrapper(fn=invertible_module)
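
A usage sketch (class name InvertibleLinearBlock assumed from the rewrite above). Because the coupling halves its input along dim 1, the wrapped model consumes tensors with 2 * channel features:

    import torch

    net = InvertibleLinearBlock(channel=16)
    net.model.eval()  # needed for requires_grad=False inputs

    x = torch.rand(4, 32)      # 32 features -> two halves of 16
    y = net.model(x)           # invertible forward pass
    x2 = net.model.inverse(y)  # reconstructs x up to numerical precision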
Example #3
import memcnn
import torch

# Embedding and BasicBlock come from the surrounding codebase (not shown here)
class Unet(torch.nn.Module):
    def __init__(self, dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), feature_factor=1, size=32, noise=0.3):
        super(Unet, self).__init__()

        self.size = size
        depth, features = self._get_parameters()
        features = int(features * feature_factor)
        self.embed = Embedding(features)
        self.feature_factor = feature_factor

        # each coupling gets two independent BasicBlocks as Fm and Gm
        def layer(dilation):
            return BasicBlock(features // 2, True, noise, dilation + 1)

        # keep_input defaults to False, so each wrapper frees its input on the
        # forward pass and reconstructs it during the backward pass
        self.blocks = torch.nn.ModuleList([memcnn.InvertibleModuleWrapper(memcnn.AdditiveCoupling(layer(i), layer(i)))
                                           for i in range(depth)])

        # a bank of (2 + depth) square matrices, each orthogonally initialized
        self.dense = torch.nn.Parameter(torch.empty(2 + depth, features, features))
        for i in range(2 + depth):
            torch.nn.init.orthogonal_(self.dense[i])
        self.features3 = features - 3
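
The snippet only shows the constructor; a forward pass is not part of the source, but it would presumably chain the wrapped couplings along these lines (hypothetical sketch):

    def forward(self, x):
        out = self.embed(x)
        for block in self.blocks:
            # each wrapper frees its input and reconstructs it on backward
            out = block(out)
        return out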
Example #4
import torch
import torch.nn as nn
import memcnn

# ExampleOperation and the input X are not shown in this snippet; a minimal
# stand-in block (Conv2d -> BatchNorm2d -> ReLU), assumed for completeness
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(num_features=channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.seq(x)

# some random input data: (batch_size, num_channels, height, width)
X = torch.rand(2, 10, 8, 8)

# application of the operation(s) the normal way
model_normal = ExampleOperation(channels=10)
model_normal.eval()

Y = model_normal(X)

# turn the ExampleOperation invertible using an additive coupling
invertible_module = memcnn.AdditiveCoupling(
    Fm=ExampleOperation(channels=10 // 2),
    Gm=ExampleOperation(channels=10 // 2))

# test that it is actually a valid invertible module (has a valid inverse method)
assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape)

# wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training
invertible_module_wrapper = memcnn.InvertibleModuleWrapper(
    fn=invertible_module, keep_input=True, keep_input_inverse=True)
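# note: keep_input=True retains the inputs after the forward pass; set
# keep_input=False to let the wrapper free them and realize the memory savings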

# by default the module is set to training, the following sets this to evaluation
# note that this is required to pass input tensors to the model with requires_grad=False (inference only)
invertible_module_wrapper.eval()

# test that the wrapped module is also a valid invertible module
assert memcnn.is_invertible_module(invertible_module_wrapper,
                                   test_input_shape=X.shape)

# compute the forward pass using the wrapper
Y2 = invertible_module_wrapper(X)

# the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2
X2 = invertible_module_wrapper.inverse(Y2)
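
# check (not in the source snippet) that the input and its reconstruction
# agree up to numerical precision
assert torch.allclose(X, X2, atol=1e-06)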