Ejemplo n.º 1
0
    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 n_layers_enc=1,
                 n_layers_dec=1,
                 latent_dim=10):
        """Create the layers of the reversible VAE.

        Args:
            in_channels: Accepted but not read by this constructor; the first
                convolution is fixed at 3 input channels.
            out_channels: Accepted but not read by this constructor; the last
                convolution is fixed at 3 output channels.
            n_layers_enc: Number of reversible blocks in the encoder sequence.
            n_layers_dec: Number of reversible blocks in the decoder sequence.
            latent_dim: Width of the latent vector. Sets the output size of
                ``fc21``/``fc22`` and the input size of ``fc3``/``lin``.
                The default (10) matches the previously hard-coded value.
        """
        super(ReVAE, self).__init__()

        # =============================
        # Fully connected path.
        # NOTE(review): 3072 is presumably 3 * 32 * 32 (flattened RGB 32x32
        # image) -- confirm against forward().
        self.fc1 = nn.Linear(3072, 400)
        self.fc3 = nn.Linear(latent_dim, 400)
        self.fc4 = nn.Linear(400, 3072)
        # =============================

        # Encoder
        # Go to 32 channels so that reversible blocks can split them in half.
        # No padding here, so a 3x3 kernel trims one pixel from each border.
        self.conv1 = nn.Conv2d(3, 32, 3)

        # f and g must both be an nn.Module whose output has the same shape as
        # its input; each operates on one 16-channel half of the 32 channels.
        f_func_enc = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
        )
        g_func_enc = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
        )

        # NOTE(review): the same f/g instances are passed to every block, so
        # (unless rv.ReversibleBlock copies its arguments) all encoder blocks
        # share one set of weights. Kept as-is to preserve existing behaviour.
        blocks_enc = [
            rv.ReversibleBlock(f_func_enc, g_func_enc)
            for _ in range(n_layers_enc)
        ]

        self.conv2 = nn.Conv2d(32, 3, 3, padding=1)

        self.sequence_enc = rv.ReversibleSequence(nn.ModuleList(blocks_enc))

        # Two heads of identical shape producing latent_dim values each.
        # NOTE(review): 2700 is presumably 3 * 30 * 30 (a 32x32 input after the
        # unpadded conv1, reduced to 3 channels by conv2) -- confirm.
        self.fc21 = nn.Linear(2700, latent_dim)
        self.fc22 = nn.Linear(2700, latent_dim)

        # Decoder
        self.lin = nn.Linear(latent_dim, 7 * 7 * 32)
        self.conv3 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=(2, 2))
        self.conv4 = nn.ConvTranspose2d(32,
                                        32,
                                        kernel_size=3,
                                        stride=(2, 2),
                                        output_padding=1)

        # NOTE(review): the decoder blocks reuse the encoder's f/g instances,
        # which ties decoder and encoder block weights together. Kept as-is;
        # confirm whether separate decoder functions were intended.
        blocks_dec = [
            rv.ReversibleBlock(f_func_enc, g_func_enc)
            for _ in range(n_layers_dec)
        ]

        # Pack all reversible blocks into a reversible sequence.
        self.sequence_dec = rv.ReversibleSequence(nn.ModuleList(blocks_dec))

        self.last = nn.Conv2d(32, 3, 3, padding=1)
Ejemplo n.º 2
0
    def __init__(self, input_dim, output_dim, reversible_depth=3, kernel=3):
        """Build an optional channel projection followed by reversible blocks.

        Args:
            input_dim: Number of input channels.
            output_dim: Number of channels the reversible blocks operate on.
                Each block's f and g act on ``output_dim // 2`` channels.
            reversible_depth: Number of reversible blocks to stack.
            kernel: Kernel size of the convolution inside each f and g.
                Padding is derived from it so that odd kernel sizes keep the
                spatial size unchanged (``kernel=3`` gives padding 1, as
                before).
        """
        super(ReversibleSequence, self).__init__()

        # Project to output_dim with a 1x1 convolution only when the channel
        # counts differ; otherwise pass the input through untouched.
        # (Attribute name spelling kept: other code may read `inital_conv`.)
        if input_dim != output_dim:
            self.inital_conv = Conv2D(input_dim, output_dim, kernel_size=1)
        else:
            self.inital_conv = nn.Identity()

        half_dim = output_dim // 2

        # f and g must return a tensor shaped like their input. A fixed
        # padding of 1 only achieves that for kernel size 3, so derive the
        # padding from the kernel instead. Even kernel sizes still cannot
        # preserve the shape with symmetric padding.
        if isinstance(kernel, int):
            same_padding = kernel // 2
        else:
            same_padding = tuple(k // 2 for k in kernel)

        blocks = []
        for _ in range(reversible_depth):
            # Fresh f and g per block, so blocks do not share weights.
            f_func = nn.Sequential(
                Conv2D(half_dim,
                       half_dim,
                       kernel_size=kernel,
                       padding=same_padding))
            g_func = nn.Sequential(
                Conv2D(half_dim,
                       half_dim,
                       kernel_size=kernel,
                       padding=same_padding))

            # Construct a reversible block from this f/g pair.
            blocks.append(rv.ReversibleBlock(f_func, g_func))

        # Pack all reversible blocks into a reversible sequence.
        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks))
Ejemplo n.º 3
0
    def _make_layer(self, block, planes, num_blocks, down):
        """Assemble one stage: the `down` modules followed by reversible blocks.

        Args:
            block: Callable invoked as ``block(c, c)`` to create each f and g
                module, where ``c`` is half of ``planes``.
            planes: Channel count of this stage; also stored on
                ``self.in_planes``.
            num_blocks: Number of reversible blocks in the stage.
            down: Iterable of modules placed before the reversible sequence.
                It is not modified.

        Returns:
            An ``nn.Sequential`` holding the modules from ``down`` followed by
            a single ``rv.ReversibleSequence``.
        """
        self.in_planes = planes
        # Each of f and g is built for half of the stage's channels.
        half_planes = self.in_planes // 2

        rev_blocks = []
        for _ in range(num_blocks):
            fblock = block(half_planes, half_planes)
            gblock = block(half_planes, half_planes)
            rev_blocks.append(rv.ReversibleBlock(fblock, gblock))

        revseq = rv.ReversibleSequence(nn.ModuleList(rev_blocks))

        # Unpack `down` into the new container rather than appending to it, so
        # the caller's list is left untouched.
        return nn.Sequential(*down, revseq)
Ejemplo n.º 4
0
    def __init__(self,
                 dim,
                 depth,
                 heads,
                 mlp_dim,
                 dropout,
                 rezero=False,
                 attn='XCA'):
        """Stack `depth` reversible blocks of attention (f) and feed-forward (g).

        Args:
            dim: Feature width passed to the attention, feed-forward and
                PreNorm constructors.
            depth: Number of reversible blocks.
            heads: Forwarded to the attention constructor as ``heads``.
            mlp_dim: Forwarded to ``FeedForward`` as its second argument.
            dropout: Forwarded to both sub-layers as ``dropout``.
            rezero: When true, each sub-layer is wrapped in ``RevZero``;
                otherwise in ``PreNorm(dim, ...)``.
            attn: ``'XCA'`` selects ``XCA``; any other value selects
                ``Attention``.
        """
        super().__init__()

        if attn == 'XCA':
            attention_cls = XCA
        else:
            attention_cls = Attention

        # One-element float parameter initialised to zero; it is not used by
        # anything in this constructor.
        self.entry = nn.Parameter(torch.FloatTensor([0]))

        def wrap(sublayer):
            # Pick the wrapper once per sub-layer based on `rezero`.
            if rezero:
                return RevZero(sublayer)
            return PreNorm(dim, sublayer)

        def build_block():
            # Attention is constructed before feed-forward, matching the
            # order in which parameters are initialised.
            attention = wrap(attention_cls(dim, heads=heads, dropout=dropout))
            feed_forward = wrap(FeedForward(dim, mlp_dim, dropout=dropout))
            return rv.ReversibleBlock(attention, feed_forward)

        stacked = nn.ModuleList([build_block() for _ in range(depth)])
        self.layers = rv.ReversibleSequence(stacked)
Ejemplo n.º 5
0
    def __init__(self, n_layers):
        """Build a reversible sequence of `n_layers` blocks on 128-channel halves.

        Args:
            n_layers: Number of reversible blocks in ``self.sequence_enc``.
        """
        super(Sequence, self).__init__()

        def conv_relu_conv():
            # Two padded 3x3 convolutions around a ReLU; 128 channels in and
            # out, so the output shape equals the input shape as f and g
            # require.
            return nn.Sequential(
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 128, 3, padding=1),
            )

        shared_f = conv_relu_conv()
        shared_g = conv_relu_conv()

        # Every block receives the same f and g instances.
        stacked = nn.ModuleList()
        for _ in range(n_layers):
            stacked.append(rv.ReversibleBlock(shared_f, shared_g))

        self.sequence_enc = rv.ReversibleSequence(stacked)
Ejemplo n.º 6
0
def RevConv(ni, nf, ks, stride, padding):
    """Return a one-block reversible sequence built from two conv layers.

    Args:
        ni: Number of input channels; must equal ``nf`` and be even.
        nf: Number of output channels.
        ks: Kernel size forwarded to ``conv_layer``.
        stride: Must be 1.
        padding: Padding forwarded to ``conv_layer``.

    Returns:
        An ``rv.ReversibleSequence`` with a single ``rv.ReversibleBlock``,
        created with ``eagerly_discard_variables=True``.

    Raises:
        ValueError: If ``ni != nf``, ``stride != 1`` or ``ni`` is odd.
    """
    # Raise explicitly instead of using `assert`, which is removed when Python
    # runs with -O and would let an invalid configuration through.
    if ni != nf or stride != 1:
        raise ValueError(
            f"RevConv requires ni == nf and stride == 1, "
            f"got ni={ni}, nf={nf}, stride={stride}")
    # f and g are each built for ni // 2 channels, so an odd channel count
    # cannot be divided into two halves of that width.
    if ni % 2 != 0:
        raise ValueError(f"RevConv requires an even channel count, got {ni}")

    half = ni // 2
    f_func = conv_layer(half, half, ks, stride=stride, padding=padding)
    g_func = conv_layer(half, half, ks, stride=stride, padding=padding)
    layers = nn.ModuleList([rv.ReversibleBlock(f_func, g_func)])
    return rv.ReversibleSequence(layers, eagerly_discard_variables=True)