Exemple #1
0
    def __init__(self, dim_lats, dim_hids=128, num_inds=32):
        """Create every sub-module of the model.

        Args:
            dim_lats: size handed to the posterior ``Normal`` and to the
                ``MAF`` prior.
            dim_hids: hidden width used by all linear and attention layers.
            num_inds: forwarded as the third argument of both
                ``StackedISAB`` stacks.
        """
        super().__init__()

        # Construction order matters: it fixes the parameter registration
        # order and the order in which random initial weights are drawn.
        n_pix = 784  # the encoder flattens to 784; the decoder emits 1 x 28 x 28

        # Element encoder: flatten, then three WN-wrapped linear layers with
        # ELU between them (no activation after the last one).
        encoder_layers = [
            View(-1, n_pix),
            WN(nn.Linear(n_pix, dim_hids)),
            nn.ELU(),
            WN(nn.Linear(dim_hids, dim_hids)),
            nn.ELU(),
            WN(nn.Linear(dim_hids, dim_hids)),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.isab1 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_hids)

        # Posterior: its context encoder maps 2 * dim_hids features to
        # 2 * dim_lats outputs.
        posterior_ctx_enc = nn.Linear(2 * dim_hids, 2 * dim_lats)
        self.posterior = Normal(dim_lats,
                                use_context=True,
                                context_enc=posterior_ctx_enc)

        # Prior: a context-conditioned MAF on top of a Normal base.
        prior_transform = MAF(dim_lats,
                              dim_hids,
                              4,
                              dim_context=dim_hids,
                              inv_linear=True)
        prior_base = Normal(dim_lats)
        self.prior = FlowDistribution(prior_transform, prior_base)

        # Decoder: takes dim_lats + dim_hids features and reshapes its
        # 784 outputs to (1, 28, 28).
        decoder_layers = [
            WN(nn.Linear(dim_lats + dim_hids, dim_hids)),
            nn.ELU(),
            WN(nn.Linear(dim_hids, dim_hids)),
            nn.ELU(),
            WN(nn.Linear(dim_hids, n_pix)),
            View(-1, 1, 28, 28),
        ]
        self.decoder = nn.Sequential(*decoder_layers)
        self.likel = Bernoulli((1, 28, 28), use_context=True)

        # Second branch ending in a single-output linear head.
        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)
Exemple #2
0
    def __init__(self, dim_hids=256, num_inds=32):
        """Assemble the flow, the two attention stacks and their linear heads.

        Args:
            dim_hids: hidden width; also used as the flow's ``dim_context``.
            num_inds: forwarded as the third argument of both
                ``StackedISAB`` stacks.
        """
        super().__init__()

        # Construction order matters: it fixes the parameter registration
        # order and the order in which random initial weights are drawn.
        in_dim = 640  # input feature size is fixed by this constructor

        flow_transform = MAF(in_dim,
                             dim_hids,
                             4,
                             dim_context=dim_hids,
                             inv_linear=True)
        flow_base = Normal(in_dim, use_context=False)
        self.flow = FlowDistribution(flow_transform, flow_base)

        self.isab1 = StackedISAB(in_dim, dim_hids, num_inds, 4,
                                 ln=True, p=0.2)
        self.pma = PMA(dim_hids, dim_hids, 1)

        # fc1 starts out almost at zero: weights are drawn uniformly from
        # [-1e-4, 1e-4] and the bias is set to exactly 0.
        init_bound = 1e-4
        context_head = nn.Linear(dim_hids, dim_hids)
        self.fc1 = context_head
        nn.init.uniform_(context_head.weight, a=-init_bound, b=init_bound)
        nn.init.constant_(context_head.bias, 0.0)

        # Second branch ending in a single-output linear head.
        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)
Exemple #3
0
class FindCluster(nn.Module):
    """Set network returning a context, a flow log-likelihood and logits.

    The input is processed by ``isab1`` and pooled by ``pma``; ``fc1``
    projects the pooled result into the context that conditions ``flow``.
    A second branch (``mab`` -> ``isab2`` -> ``fc2``) yields logits with a
    single output channel.
    """

    def __init__(self,
                 dim_inputs,
                 dim_hids=128,
                 num_inds=32,
                 dim_context=128,
                 num_blocks=4):
        """Build all sub-modules.

        Args:
            dim_inputs: feature size of the input; first argument of the
                flow, its base distribution and ``isab1``.
            dim_hids: hidden width of the attention blocks and the flow.
            num_inds: forwarded as the third argument of both
                ``StackedISAB`` stacks.
            dim_context: output size of ``fc1`` and the flow's
                ``dim_context``.
            num_blocks: third argument of ``MAF``.
        """
        super().__init__()

        # Construction order matters: it fixes the parameter registration
        # order and the order in which random initial weights are drawn.
        flow_transform = MAF(dim_inputs,
                             dim_hids,
                             num_blocks,
                             dim_context=dim_context)
        flow_base = Normal(dim_inputs, use_context=False)
        self.flow = FlowDistribution(flow_transform, flow_base)

        self.isab1 = StackedISAB(dim_inputs, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_context)

        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, mask=None):
        """Run both branches of the network.

        Args:
            X: input tensor passed to ``isab1`` and scored by the flow.
            mask: optional mask forwarded to ``isab1``, ``pma`` and
                ``isab2``.

        Returns:
            A tuple ``(context, ll, logits)`` where ``ll`` is the flow
            log-probability of ``X`` given ``context`` with a trailing
            singleton axis appended.
        """
        encoded = self.isab1(X, mask=mask)
        pooled = self.pma(encoded, mask=mask)

        # Context branch: condition the flow on the pooled summary.
        context = self.fc1(pooled)
        log_prob = self.flow.log_prob(X, context)
        ll = log_prob.unsqueeze(-1)

        # Logit branch: attend from the encoded input to the pooled summary.
        attended = self.mab(encoded, pooled)
        refined = self.isab2(attended, mask=mask)
        logits = self.fc2(refined)

        return context, ll, logits
Exemple #4
0
class FindCluster(nn.Module):
    """Set network returning a context, a per-dimension flow log-likelihood
    and logits.

    The input is processed by ``isab1`` and pooled by ``pma``; ``fc1``
    projects the pooled result into the context that conditions ``flow``.
    A second branch (``mab`` -> ``isab2`` -> ``fc2``) yields logits with a
    single output channel.
    """

    def __init__(self, dim_hids=256, num_inds=32, dim_inputs=640):
        """Build all sub-modules.

        Args:
            dim_hids: hidden width; also used as the flow's ``dim_context``.
            num_inds: forwarded as the third argument of both
                ``StackedISAB`` stacks.
            dim_inputs: feature size of the input. Defaults to 640, the
                value that used to be hard-coded, so existing callers are
                unaffected.
        """
        super().__init__()

        # Kept so that forward() can normalise by the same size the
        # sub-modules were built with.
        self.dim_inputs = dim_inputs

        # Construction order matters: it fixes the parameter registration
        # order and the order in which random initial weights are drawn.
        self.flow = FlowDistribution(
            MAF(dim_inputs, dim_hids, 4, dim_context=dim_hids,
                inv_linear=True),
            Normal(dim_inputs, use_context=False))
        self.isab1 = StackedISAB(dim_inputs, dim_hids, num_inds, 4,
                                 ln=True, p=0.2)
        self.pma = PMA(dim_hids, dim_hids, 1)

        # fc1 starts out almost at zero: weights are drawn uniformly from
        # [-1e-4, 1e-4] and the bias is set to exactly 0.
        self.fc1 = nn.Linear(dim_hids, dim_hids)
        nn.init.uniform_(self.fc1.weight, a=-1e-4, b=1e-4)
        nn.init.constant_(self.fc1.bias, 0.0)

        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, mask=None):
        """Run both branches of the network.

        Args:
            X: input tensor passed to ``isab1`` and scored by the flow.
            mask: optional mask forwarded to ``isab1``, ``pma`` and
                ``isab2``.

        Returns:
            A tuple ``(context, ll, logits)`` where ``ll`` is the flow
            log-probability of ``X`` given ``context``, with a trailing
            singleton axis appended and divided by ``dim_inputs``.
        """
        H = self.isab1(X, mask=mask)
        Z = self.pma(H, mask=mask)
        context = self.fc1(Z)

        # Normalise the log-likelihood by the input dimensionality
        # (640.0 with the default configuration).
        ll = self.flow.log_prob(X, context).unsqueeze(-1)
        ll = ll / float(self.dim_inputs)

        H = self.mab(H, Z)
        logits = self.fc2(self.isab2(H, mask=mask))
        return context, ll, logits
Exemple #5
0
    def __init__(self):
        """Build the convolutional encoder/decoder and the three distributions.

        The shape notes below are written as channels x height x width.
        They assume a 3 x 64 x 64 input (the event shape given to the
        likelihood) and that each stride-2 unit changes the resolution by a
        factor of two -- TODO confirm against ConvResUnit / DeconvResUnit.
        """
        super().__init__()

        def down(c_in, c_out):
            # Stride-2 residual conv unit.
            return ConvResUnit(c_in, c_out, stride=2, weight_norm=True)

        def conv(c_in, c_out):
            # Residual conv unit with the default stride.
            return ConvResUnit(c_in, c_out, weight_norm=True)

        def up(c_in, c_out):
            # Stride-2 residual deconv unit.
            return DeconvResUnit(c_in, c_out, stride=2, weight_norm=True)

        def deconv(c_in, c_out):
            # Residual deconv unit with the default stride.
            return DeconvResUnit(c_in, c_out, weight_norm=True)

        # Units are created strictly left to right, top to bottom, so the
        # registration and initialisation order is unchanged.
        self.encoder = nn.Sequential(
            down(3, 32), conv(32, 32),      # -> 32 x 32 x 32
            down(32, 64), conv(64, 64),     # -> 64 x 16 x 16
            down(64, 64), conv(64, 64),     # -> 64 x 8 x 8
            down(64, 64), conv(64, 128),    # -> (64 + 64) x 4 x 4
        )

        self.posterior = Normal((64, 4, 4), use_context=True)
        # Alternative kept for reference:
        #   SimpleFlowDist((64, 4, 4), 4, use_context=True)
        self.prior = SimpleFlowDist((64, 4, 4), 4, use_context=False)

        self.decoder = nn.Sequential(
            deconv(64, 64), up(64, 64),     # -> 64 x 8 x 8
            deconv(64, 64), up(64, 64),     # -> 64 x 16 x 16
            deconv(64, 64), up(64, 32),     # -> 32 x 32 x 32
            # -> (3 + 3) x 64 x 64; the last unit is built without
            # weight_norm.
            deconv(32, 32), DeconvResUnit(32, 6, stride=2),
        )

        dequantizer = Dequantize()
        pixel_dist = Normal((3, 64, 64), use_context=True)
        self.likel = FlowDistribution(dequantizer, pixel_dist)
Exemple #6
0
    def __init__(self,
                 dim_inputs,
                 dim_hids=128,
                 num_inds=32,
                 dim_context=128,
                 num_blocks=4):
        """Build the conditional flow and the two attention branches.

        Args:
            dim_inputs: feature size of the input; first argument of the
                flow, its base distribution and ``isab1``.
            dim_hids: hidden width of the attention blocks and the flow.
            num_inds: forwarded as the third argument of both
                ``StackedISAB`` stacks.
            dim_context: output size of ``fc1`` and the flow's
                ``dim_context``.
            num_blocks: third argument of ``MAF``.
        """
        super().__init__()

        # Construction order matters: it fixes the parameter registration
        # order and the order in which random initial weights are drawn.
        flow_transform = MAF(dim_inputs,
                             dim_hids,
                             num_blocks,
                             dim_context=dim_context)
        flow_base = Normal(dim_inputs, use_context=False)
        self.flow = FlowDistribution(flow_transform, flow_base)

        # First branch: attention stack, pooling, projection to the context.
        self.isab1 = StackedISAB(dim_inputs, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_context)

        # Second branch ending in a single-output linear head.
        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)
Exemple #7
0
class FindCluster(nn.Module):
    """Set network with a per-element VAE likelihood.

    Each element is encoded by an MLP, the set is processed by ``isab1``
    and pooled by ``pma``, and the pooled context conditions the posterior,
    the flow prior and the decoder. ``forward`` returns the context, a
    normalised ELBO-style score per element and logits.
    """

    def __init__(self, dim_lats, dim_hids=128, num_inds=32):
        """Create every sub-module of the model.

        Args:
            dim_lats: size handed to the posterior ``Normal`` and to the
                ``MAF`` prior.
            dim_hids: hidden width used by all linear and attention layers.
            num_inds: forwarded as the third argument of both
                ``StackedISAB`` stacks.
        """
        super().__init__()

        # Construction order matters: it fixes the parameter registration
        # order and the order in which random initial weights are drawn.
        n_pix = 784  # the encoder flattens to 784; the decoder emits 1 x 28 x 28

        encoder_layers = [
            View(-1, n_pix),
            WN(nn.Linear(n_pix, dim_hids)),
            nn.ELU(),
            WN(nn.Linear(dim_hids, dim_hids)),
            nn.ELU(),
            WN(nn.Linear(dim_hids, dim_hids)),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.isab1 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_hids)

        # The posterior is conditioned on [element encoding, context],
        # hence the 2 * dim_hids input of its context encoder.
        posterior_ctx_enc = nn.Linear(2 * dim_hids, 2 * dim_lats)
        self.posterior = Normal(dim_lats,
                                use_context=True,
                                context_enc=posterior_ctx_enc)

        prior_transform = MAF(dim_lats,
                              dim_hids,
                              4,
                              dim_context=dim_hids,
                              inv_linear=True)
        prior_base = Normal(dim_lats)
        self.prior = FlowDistribution(prior_transform, prior_base)

        # The decoder consumes [latent, context].
        decoder_layers = [
            WN(nn.Linear(dim_lats + dim_hids, dim_hids)),
            nn.ELU(),
            WN(nn.Linear(dim_hids, dim_hids)),
            nn.ELU(),
            WN(nn.Linear(dim_hids, n_pix)),
            View(-1, 1, 28, 28),
        ]
        self.decoder = nn.Sequential(*decoder_layers)
        self.likel = Bernoulli((1, 28, 28), use_context=True)

        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, mask=None):
        """Score every element of every set and compute logits.

        Args:
            X: 5-D tensor unpacked as ``(B, N, C, H, W)``.
            mask: optional mask forwarded to ``isab1``, ``pma`` and
                ``isab2``.

        Returns:
            A tuple ``(context, ll, logits)``. ``context`` is the pooled
            context repeated ``N`` times and flattened to ``B * N`` rows;
            ``ll`` is (likelihood log-prob - KL term) reshaped to
            ``(B, N, -1)`` and divided by ``C * H * W``.
        """
        batch, set_size, chans, height, width = X.shape

        # Encode each element independently, then let the set attend.
        flat_x = X.view(batch * set_size, chans, height, width)
        elem_code = self.encoder(flat_x)
        set_code = self.isab1(elem_code.view(batch, set_size, -1), mask=mask)
        pooled = self.pma(set_code, mask=mask)

        # Give every element a copy of the pooled context.
        context = self.fc1(pooled)
        context = context.repeat(1, set_size, 1).view(batch * set_size, -1)

        # Sample the latent while training; use the posterior mean otherwise.
        posterior_ctx = torch.cat([elem_code, context], -1)
        if self.training:
            draw_latent = self.posterior.sample
        else:
            draw_latent = self.posterior.mean
        z, log_q = draw_latent(context=posterior_ctx)

        log_p = self.prior.log_prob(z, context=context)
        kld = (log_q - log_p).view(batch, set_size, -1)

        decoded = self.decoder(torch.cat([z, context], -1))
        recon = self.likel.log_prob(flat_x, context=decoded)
        ll = recon.view(batch, set_size, -1) - kld
        ll /= chans * height * width

        attended = self.mab(set_code, pooled)
        logits = self.fc2(self.isab2(attended, mask=mask))

        return context, ll, logits
Exemple #8
0
    def __init__(self,
                 num_filters=32,
                 dim_lats=128,
                 dim_hids=256,
                 dim_context=256,
                 num_inds=32):
        """Build the conv encoder/decoder, attention stacks and distributions.

        Args:
            num_filters: base channel count of the convolutional stacks.
            dim_lats: size handed to the posterior ``Normal`` and to the
                ``MAF`` prior.
            dim_hids: hidden width of the attention blocks and the flow.
            dim_context: output size of ``fc1``; also the flow's
                ``dim_context``.
            num_inds: forwarded as the third argument of both
                ``StackedISAB`` stacks.
        """
        super().__init__()

        # Construction order matters: it fixes the parameter registration
        # order and the order in which random initial weights are drawn.
        nf = num_filters
        # Flattened encoder output size. With unpadded 3 x 3 convolutions a
        # 28 x 28 input shrinks 28 -> 13 -> 6 -> 4, leaving 4 * nf channels
        # on a 4 x 4 map.
        enc_dim = 4 * nf * 4 * 4

        encoder_layers = [
            nn.Conv2d(3, nf, 3, stride=2),
            nn.BatchNorm2d(nf),
            nn.ReLU(),
            nn.Conv2d(nf, 2 * nf, 3, stride=2),
            nn.BatchNorm2d(2 * nf),
            nn.ReLU(),
            nn.Conv2d(2 * nf, 4 * nf, 3),
            Flatten(),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        self.isab1 = StackedISAB(enc_dim, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_context)

        # The posterior's context encoder takes enc_dim + dim_context
        # features and produces 2 * dim_lats outputs.
        posterior_ctx_enc = nn.Linear(enc_dim + dim_context, 2 * dim_lats)
        self.posterior = Normal(dim_lats,
                                use_context=True,
                                context_enc=posterior_ctx_enc)

        prior_transform = MAF(dim_lats,
                              dim_hids,
                              6,
                              dim_context=dim_context,
                              inv_linear=True)
        prior_base = Normal(dim_lats)
        self.prior = FlowDistribution(prior_transform, prior_base)

        # Decoder: linear layer back to a 4 x 4 map, then transposed
        # convolutions growing it 4 -> 7 -> 13 -> 28.
        decoder_layers = [
            nn.Linear(dim_lats + dim_context, enc_dim),
            nn.ReLU(),
            View(-1, 4 * nf, 4, 4),
            nn.ConvTranspose2d(4 * nf, 2 * nf, 3, stride=2, padding=1),
            nn.BatchNorm2d(2 * nf),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * nf, nf, 3, stride=2, padding=1),
            nn.BatchNorm2d(nf),
            nn.ReLU(),
            nn.ConvTranspose2d(nf, 3, 3, stride=2, output_padding=1),
            View(-1, 3, 28, 28),
        ]
        self.dec = nn.Sequential(*decoder_layers)
        self.likel = Bernoulli((3, 28, 28), use_context=True)

        # Second branch ending in a single-output linear head.
        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)
Exemple #9
0
class AnchoredFilteringNetwork(nn.Module):
    """Filtering network conditioned on one anchor element per batch row.

    The anchor selected by ``anc_idxs`` is attended to by every element
    (``mab1``), the result goes through ``isab1`` and is pooled by ``pma``.
    ``fc1`` turns the pooled summary into ``theta``, which conditions the
    flow; ``mab2`` -> ``isab2`` -> ``fc2`` produces the logits.
    """

    def __init__(self,
                 dim_inputs,
                 dim_hids=128,
                 num_inds=32,
                 dim_context=128,
                 num_blocks=4):
        """Build all sub-modules.

        Args:
            dim_inputs: feature size of the input; first argument of the
                flow, its base distribution and ``mab1``.
            dim_hids: hidden width of the attention blocks and the flow.
            num_inds: forwarded as the third argument of both
                ``StackedISAB`` stacks.
            dim_context: output size of ``fc1`` and the flow's
                ``dim_context``.
            num_blocks: third argument of ``MAF``.
        """
        super().__init__()

        # Construction order matters: it fixes the parameter registration
        # order and the order in which random initial weights are drawn.
        flow_transform = MAF(dim_inputs,
                             dim_hids,
                             num_blocks,
                             dim_context=dim_context)
        flow_base = Normal(dim_inputs, use_context=False)
        self.flow = FlowDistribution(flow_transform, flow_base)

        self.mab1 = MAB(dim_inputs, dim_inputs, dim_hids)
        self.isab1 = StackedISAB(dim_hids, dim_hids, num_inds, 4)

        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_context)

        self.mab2 = MAB(dim_hids, dim_context, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, anc_idxs, mask=None):
        """Compute cluster parameters, log-likelihoods and membership logits.

        Args:
            X: input tensor; its first axis is indexed together with
                ``anc_idxs`` to pick one anchor per row.
            anc_idxs: index of the anchor element for every row of ``X``.
            mask: optional mask forwarded to ``isab1``, ``pma`` and
                ``isab2``.

        Returns:
            A dict with keys ``'theta'`` (``fc1`` output with axis -2
            squeezed), ``'ll'`` (flow log-probability of ``X`` given the
            unsqueezed theta) and ``'logits'`` (``fc2`` output with the last
            axis squeezed).
        """
        # Encode the data relative to the anchor of each row.
        row_idx = torch.arange(X.shape[0])
        anchors = X[row_idx, anc_idxs].unsqueeze(-2)
        anchored = self.mab1(X, anchors)
        set_code = self.isab1(anchored, mask=mask)

        # Extract the parameters of the cluster that contains the anchor.
        pooled = self.pma(set_code, mask=mask)
        theta_full = self.fc1(pooled)
        ll = self.flow.log_prob(X, theta_full)
        theta = theta_full.squeeze(-2)

        # Extract the membership-vector logits.
        member_code = self.mab2(set_code, pooled)
        member_code = self.isab2(member_code, mask=mask)
        logits = self.fc2(member_code).squeeze(-1)

        return {'theta': theta, 'll': ll, 'logits': logits}
Exemple #10
0
def train_maf(X):
    """Fit a ``MAF(2, 128, 4)`` flow to ``X`` and return its mean log-prob.

    The flow is moved to CUDA and trained with Adam (lr 5e-4) by minimising
    the negative mean log-probability of ``X``. The learning rate is
    multiplied by 0.2 at 30% and 60% of the run. The number of steps and
    the gradient-norm clipping threshold are read from the module-level
    ``args`` (``args.num_steps`` and ``args.clip``).

    Args:
        X: data passed to ``flow.log_prob`` on every step.

    Returns:
        ``flow.log_prob(X).mean()`` evaluated after the last update.
    """
    flow = FlowDistribution(MAF(2, 128, 4), Normal(2)).cuda()
    optimizer = optim.Adam(flow.parameters(), lr=5e-4)

    total_steps = args.num_steps
    decay_steps = [int(frac * total_steps) for frac in (.3, .6)]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                               milestones=decay_steps,
                                               gamma=0.2)

    for step in range(1, total_steps + 1):
        optimizer.zero_grad()
        nll = -flow.log_prob(X).mean()
        nll.backward()
        nn.utils.clip_grad_norm_(flow.parameters(), args.clip)

        # Progress is reported before the update, so the printed learning
        # rate is the one about to be applied.
        if step % 1000 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print('iter {}, lr {:.3e}, ll {}'.format(
                step, current_lr, -nll.item()))

        optimizer.step()
        scheduler.step()

    final_ll = flow.log_prob(X).mean()
    return final_ll
Exemple #11
0
class FilteringNetwork(nn.Module):
    """Set network with a convolutional per-element VAE likelihood.

    Each image is encoded by ``enc``, the set is processed by ``isab1`` and
    pooled by ``pma``, and ``fc1`` produces ``theta``. ``theta`` conditions
    the posterior, the flow prior and the decoder; a second branch
    (``mab`` -> ``isab2`` -> ``fc2``) produces the logits.
    """

    def __init__(self,
                 num_filters=32,
                 dim_lats=128,
                 dim_hids=256,
                 dim_context=256,
                 num_inds=32):
        """Build all sub-modules.

        Args:
            num_filters: base channel count of the convolutional stacks.
            dim_lats: size handed to the posterior ``Normal`` and to the
                ``MAF`` prior.
            dim_hids: hidden width of the attention blocks and the flow.
            dim_context: output size of ``fc1``; also the flow's
                ``dim_context``.
            num_inds: forwarded as the third argument of both
                ``StackedISAB`` stacks.
        """
        super().__init__()

        # Construction order matters: it fixes the parameter registration
        # order and the order in which random initial weights are drawn.
        nf = num_filters
        # Flattened encoder output size. With unpadded 3 x 3 convolutions a
        # 28 x 28 input shrinks 28 -> 13 -> 6 -> 4, leaving 4 * nf channels
        # on a 4 x 4 map.
        enc_dim = 4 * nf * 4 * 4

        encoder_layers = [
            nn.Conv2d(3, nf, 3, stride=2),
            nn.BatchNorm2d(nf),
            nn.ReLU(),
            nn.Conv2d(nf, 2 * nf, 3, stride=2),
            nn.BatchNorm2d(2 * nf),
            nn.ReLU(),
            nn.Conv2d(2 * nf, 4 * nf, 3),
            Flatten(),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        self.isab1 = StackedISAB(enc_dim, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_context)

        # The posterior is conditioned on [image encoding, theta].
        posterior_ctx_enc = nn.Linear(enc_dim + dim_context, 2 * dim_lats)
        self.posterior = Normal(dim_lats,
                                use_context=True,
                                context_enc=posterior_ctx_enc)

        prior_transform = MAF(dim_lats,
                              dim_hids,
                              6,
                              dim_context=dim_context,
                              inv_linear=True)
        prior_base = Normal(dim_lats)
        self.prior = FlowDistribution(prior_transform, prior_base)

        # Decoder consumes [latent, theta]; the transposed convolutions
        # grow the 4 x 4 map 4 -> 7 -> 13 -> 28.
        decoder_layers = [
            nn.Linear(dim_lats + dim_context, enc_dim),
            nn.ReLU(),
            View(-1, 4 * nf, 4, 4),
            nn.ConvTranspose2d(4 * nf, 2 * nf, 3, stride=2, padding=1),
            nn.BatchNorm2d(2 * nf),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * nf, nf, 3, stride=2, padding=1),
            nn.BatchNorm2d(nf),
            nn.ReLU(),
            nn.ConvTranspose2d(nf, 3, 3, stride=2, output_padding=1),
            View(-1, 3, 28, 28),
        ]
        self.dec = nn.Sequential(*decoder_layers)
        self.likel = Bernoulli((3, 28, 28), use_context=True)

        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, mask=None, return_z=False):
        """Score every element of every set and compute logits.

        Args:
            X: 5-D tensor unpacked as ``(B, N, C, H, W)``.
            mask: optional mask forwarded to ``isab1``, ``pma`` and
                ``isab2``.
            return_z: when true, the sampled latents are added to the
                result under the key ``'z'``.

        Returns:
            A dict with keys ``'ll'`` (likelihood log-prob minus the KL
            term, shaped ``(B, N)`` and divided by ``H * W``), ``'theta'``
            (``fc1`` output) and ``'logits'`` (``fc2`` output with the last
            axis squeezed), plus ``'z'`` if requested.
        """
        batch, set_size, chans, height, width = X.shape

        # Encode each image independently, then let the set attend.
        flat_x = X.view(batch * set_size, chans, height, width)
        img_code = self.enc(flat_x)
        set_code = self.isab1(img_code.view(batch, set_size, -1), mask=mask)
        pooled = self.pma(set_code, mask=mask)

        # theta is returned as is; a repeated, flattened copy conditions
        # the per-image distributions.
        theta = self.fc1(pooled)
        theta_rows = theta.repeat(1, set_size, 1).view(batch * set_size, -1)

        # The latent is always sampled here, regardless of self.training.
        posterior_ctx = torch.cat([img_code, theta_rows], -1)
        z, logq = self.posterior.sample(context=posterior_ctx)
        logp = self.prior.log_prob(z, context=theta_rows)
        kld = (logq - logp).view(batch, set_size)

        decoded = self.dec(torch.cat([z, theta_rows], -1))
        recon = self.likel.log_prob(flat_x, context=decoded)
        ll = recon.view(batch, set_size) - kld
        # Normalised by the spatial size only; the channel count is not
        # part of the divisor.
        ll /= height * width

        attended = self.mab(set_code, pooled)
        logits = self.fc2(self.isab2(attended, mask=mask)).squeeze(-1)

        outputs = {'ll': ll, 'theta': theta, 'logits': logits}
        if return_z:
            outputs['z'] = z
        return outputs
Exemple #12
0
class VAE(nn.Module):
    """Convolutional VAE with a flow prior and a dequantised likelihood.

    The encoder output conditions a ``Normal`` posterior over a
    ``(64, 4, 4)`` latent; the prior is a ``SimpleFlowDist`` and the decoder
    output conditions a ``Normal((3, 64, 64))`` wrapped with ``Dequantize``.
    """

    def __init__(self):
        """Build the encoder, decoder and the three distributions.

        The shape notes below are written as channels x height x width.
        They assume a 3 x 64 x 64 input (the event shape given to the
        likelihood) and that each stride-2 unit changes the resolution by a
        factor of two -- TODO confirm against ConvResUnit / DeconvResUnit.
        """
        super().__init__()

        def down(c_in, c_out):
            # Stride-2 residual conv unit.
            return ConvResUnit(c_in, c_out, stride=2, weight_norm=True)

        def conv(c_in, c_out):
            # Residual conv unit with the default stride.
            return ConvResUnit(c_in, c_out, weight_norm=True)

        def up(c_in, c_out):
            # Stride-2 residual deconv unit.
            return DeconvResUnit(c_in, c_out, stride=2, weight_norm=True)

        def deconv(c_in, c_out):
            # Residual deconv unit with the default stride.
            return DeconvResUnit(c_in, c_out, weight_norm=True)

        # Units are created strictly left to right, top to bottom, so the
        # registration and initialisation order is unchanged.
        self.encoder = nn.Sequential(
            down(3, 32), conv(32, 32),      # -> 32 x 32 x 32
            down(32, 64), conv(64, 64),     # -> 64 x 16 x 16
            down(64, 64), conv(64, 64),     # -> 64 x 8 x 8
            down(64, 64), conv(64, 128),    # -> (64 + 64) x 4 x 4
        )

        self.posterior = Normal((64, 4, 4), use_context=True)
        # Alternative kept for reference:
        #   SimpleFlowDist((64, 4, 4), 4, use_context=True)
        self.prior = SimpleFlowDist((64, 4, 4), 4, use_context=False)

        self.decoder = nn.Sequential(
            deconv(64, 64), up(64, 64),     # -> 64 x 8 x 8
            deconv(64, 64), up(64, 64),     # -> 64 x 16 x 16
            deconv(64, 64), up(64, 32),     # -> 32 x 32 x 32
            # -> (3 + 3) x 64 x 64; the last unit is built without
            # weight_norm.
            deconv(32, 32), DeconvResUnit(32, 6, stride=2),
        )

        dequantizer = Dequantize()
        pixel_dist = Normal((3, 64, 64), use_context=True)
        self.likel = FlowDistribution(dequantizer, pixel_dist)

    def forward(self, x):
        """Return the per-dimension log-likelihood and KL term for a batch.

        Args:
            x: 4-D tensor whose shape is unpacked as ``(_, C, H, W)``.

        Returns:
            ``(ll, kld)``: the batch means of the likelihood log-prob and of
            ``log_q - log_p``, each divided by ``C * H * W``.
        """
        _, chans, height, width = x.shape
        num_dims = chans * height * width

        enc_out = self.encoder(x)
        z, log_q = self.posterior.sample(context=enc_out)
        log_p = self.prior.log_prob(z)
        kld = (log_q - log_p).mean() / num_dims

        dec_out = self.decoder(z)
        ll = self.likel.log_prob(x, context=dec_out).mean() / num_dims
        return ll, kld

    def generate(self, num_samples, device='cpu'):
        """Draw ``num_samples`` latents from the prior and decode them.

        Args:
            num_samples: number of prior samples to draw.
            device: forwarded to ``prior.sample``.

        Returns:
            The first element returned by ``likel.sample`` for the decoded
            latents.
        """
        latents, _ = self.prior.sample(num_samples, device=device)
        dec_out = self.decoder(latents)
        samples, _ = self.likel.sample(context=dec_out)
        return samples

    def reconstruct(self, x):
        """Encode ``x`` and decode via the posterior and likelihood means.

        Args:
            x: input batch passed to the encoder.

        Returns:
            The first element returned by ``likel.mean`` for the decoded
            posterior mean.
        """
        enc_out = self.encoder(x)
        latents, _ = self.posterior.mean(context=enc_out)
        dec_out = self.decoder(latents)
        recon, _ = self.likel.mean(context=dec_out)
        return recon