import torch
import torch.nn as nn
import torch.optim as optim

# Set-attention blocks (StackedISAB, PMA, MAB), flow components (MAF,
# FlowDistribution, Normal, Bernoulli, Dequantize, SimpleFlowDist), and
# helpers (View, Flatten, WN, ConvResUnit, DeconvResUnit) are assumed to be
# importable from the accompanying repository modules.


class FindCluster(nn.Module):
    def __init__(self, dim_inputs, dim_hids=128, num_inds=32,
                 dim_context=128, num_blocks=4):
        super().__init__()
        # conditional density model for points in the found cluster
        self.flow = FlowDistribution(
            MAF(dim_inputs, dim_hids, num_blocks, dim_context=dim_context),
            Normal(dim_inputs, use_context=False))
        self.isab1 = StackedISAB(dim_inputs, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_context)
        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, mask=None):
        H = self.isab1(X, mask=mask)                  # encode the set
        Z = self.pma(H, mask=mask)                    # pool to a cluster summary
        context = self.fc1(Z)
        ll = self.flow.log_prob(X, context).unsqueeze(-1)
        H = self.mab(H, Z)                            # broadcast summary to points
        logits = self.fc2(self.isab2(H, mask=mask))   # membership logits
        return context, ll, logits
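# A minimal shape sketch (an assumption for illustration; the 2-d inputs and
# batch sizes below are not from the source).
model = FindCluster(dim_inputs=2)
X = torch.randn(8, 100, 2)              # 8 sets of 100 two-dimensional points
context, ll, logits = model(X)
# context: (8, 1, 128) summary of the found cluster
# ll:      (8, 100, 1) per-point log-likelihood under the conditional flow
# logits:  (8, 100, 1) per-point cluster-membership logits
members = torch.sigmoid(logits) > 0.5   # hypothetical hard assignment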
class FindCluster(nn.Module):
    def __init__(self, dim_hids=256, num_inds=32):
        super().__init__()
        self.flow = FlowDistribution(
            MAF(640, dim_hids, 4, dim_context=dim_hids, inv_linear=True),
            Normal(640, use_context=False))
        self.isab1 = StackedISAB(640, dim_hids, num_inds, 4, ln=True, p=0.2)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_hids)
        # near-zero initialization keeps the initial flow context small
        nn.init.uniform_(self.fc1.weight, a=-1e-4, b=1e-4)
        nn.init.constant_(self.fc1.bias, 0.0)
        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, mask=None):
        H = self.isab1(X, mask=mask)
        Z = self.pma(H, mask=mask)
        context = self.fc1(Z)
        # normalize the log-likelihood by the 640 feature dimensions
        ll = self.flow.log_prob(X, context).unsqueeze(-1) / 640.0
        H = self.mab(H, Z)
        logits = self.fc2(self.isab2(H, mask=mask))
        return context, ll, logits
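# One plausible way to use a FindCluster model for full clustering (this loop
# is an illustrative assumption, not taken from the source): repeatedly find a
# cluster, then mask its points out, assuming `mask` marks points to exclude.
def iterative_filtering(model, X, max_iters=10):
    B, N, _ = X.shape
    mask = torch.zeros(B, N, 1, device=X.device)
    labels = torch.full((B, N), -1, dtype=torch.long, device=X.device)
    for k in range(max_iters):
        _, _, logits = model(X, mask=mask)
        # points confidently assigned to this cluster and not yet masked
        found = (torch.sigmoid(logits) > 0.5) & (mask < 0.5)
        labels[found.squeeze(-1)] = k
        mask = torch.maximum(mask, found.float())
        if bool((mask > 0.5).all()):
            break
    return labels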
class FindCluster(nn.Module):
    def __init__(self, dim_lats, dim_hids=128, num_inds=32):
        super().__init__()
        self.encoder = nn.Sequential(
            View(-1, 784),
            WN(nn.Linear(784, dim_hids)), nn.ELU(),
            WN(nn.Linear(dim_hids, dim_hids)), nn.ELU(),
            WN(nn.Linear(dim_hids, dim_hids)))
        self.isab1 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_hids)
        self.posterior = Normal(
            dim_lats, use_context=True,
            context_enc=nn.Linear(2 * dim_hids, 2 * dim_lats))
        self.prior = FlowDistribution(
            MAF(dim_lats, dim_hids, 4, dim_context=dim_hids, inv_linear=True),
            Normal(dim_lats))
        self.decoder = nn.Sequential(
            WN(nn.Linear(dim_lats + dim_hids, dim_hids)), nn.ELU(),
            WN(nn.Linear(dim_hids, dim_hids)), nn.ELU(),
            WN(nn.Linear(dim_hids, 784)),
            View(-1, 1, 28, 28))
        self.likel = Bernoulli((1, 28, 28), use_context=True)
        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, mask=None):
        B, N, C, H, W = X.shape
        x = X.view(B * N, C, H, W)
        h_enc = self.encoder(x)
        H_enc = self.isab1(h_enc.view(B, N, -1), mask=mask)
        Z = self.pma(H_enc, mask=mask)
        # broadcast the pooled cluster context to every point
        context = self.fc1(Z).repeat(1, N, 1).view(B * N, -1)
        if self.training:
            z, log_q = self.posterior.sample(
                context=torch.cat([h_enc, context], -1))
        else:
            z, log_q = self.posterior.mean(
                context=torch.cat([h_enc, context], -1))
        log_p = self.prior.log_prob(z, context=context)
        kld = (log_q - log_p).view(B, N, -1)
        h_dec = self.decoder(torch.cat([z, context], -1))
        # per-point ELBO, normalized by the number of pixels
        ll = self.likel.log_prob(x, context=h_dec).view(B, N, -1) - kld
        ll /= C * H * W
        H_dec = self.mab(H_enc, Z)
        logits = self.fc2(self.isab2(H_dec, mask=mask))
        return context, ll, logits
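# Illustrative shapes for the image variant (assumption: binarized 1x28x28
# inputs, matching the Bernoulli likelihood; batch and set sizes are made up).
model = FindCluster(dim_lats=32)
X = torch.rand(4, 50, 1, 28, 28).bernoulli()   # 4 sets of 50 images
context, ll, logits = model(X)
# ll:     (4, 50, 1) per-point ELBO (likelihood minus KL), normalized per pixel
# logits: (4, 50, 1) cluster-membership logits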
class AnchoredFilteringNetwork(nn.Module):
    def __init__(self, dim_inputs, dim_hids=128, num_inds=32,
                 dim_context=128, num_blocks=4):
        super().__init__()
        self.flow = FlowDistribution(
            MAF(dim_inputs, dim_hids, num_blocks, dim_context=dim_context),
            Normal(dim_inputs, use_context=False))
        self.mab1 = MAB(dim_inputs, dim_inputs, dim_hids)
        self.isab1 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_context)
        self.mab2 = MAB(dim_hids, dim_context, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, anc_idxs, mask=None):
        # encode the data, attending every point to its anchor
        xa = X[torch.arange(X.shape[0]), anc_idxs].unsqueeze(-2)
        H_Xa = self.isab1(self.mab1(X, xa), mask=mask)

        # extract the parameters of the cluster containing the anchor
        H_theta = self.pma(H_Xa, mask=mask)
        theta = self.fc1(H_theta)
        ll = self.flow.log_prob(X, theta)
        theta = theta.squeeze(-2)

        # extract membership-vector logits
        H_m = self.mab2(H_Xa, H_theta)
        H_m = self.isab2(H_m, mask=mask)
        logits = self.fc2(H_m).squeeze(-1)

        return {'theta': theta, 'll': ll, 'logits': logits}
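# Sketch of the anchored interface (assumption: `anc_idxs` holds one anchor
# index per set; dimensions are illustrative).
model = AnchoredFilteringNetwork(dim_inputs=2)
X = torch.randn(8, 100, 2)
anc_idxs = torch.randint(0, 100, (8,))
out = model(X, anc_idxs)
# out['theta']:  (8, 128)  parameters of the anchor's cluster
# out['ll']:     (8, 100)  per-point log-likelihood given those parameters
# out['logits']: (8, 100)  logits for membership in the anchor's cluster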
def train_maf(X):
    # `args` (num_steps, clip) is the surrounding script's argparse namespace
    flow = FlowDistribution(MAF(2, 128, 4), Normal(2)).cuda()
    optimizer = optim.Adam(flow.parameters(), lr=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer,
        milestones=[int(r * args.num_steps) for r in [.3, .6]],
        gamma=0.2)
    for i in range(1, args.num_steps + 1):
        optimizer.zero_grad()
        loss = -flow.log_prob(X).mean()
        loss.backward()
        nn.utils.clip_grad_norm_(flow.parameters(), args.clip)
        # update parameters and the learning-rate schedule once per iteration
        optimizer.step()
        scheduler.step()
        if i % 1000 == 0:
            print('iter {}, lr {:.3e}, ll {}'.format(
                i, optimizer.param_groups[0]['lr'], -loss.item()))
    return flow.log_prob(X).mean()
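# Toy invocation (assumption: `args` is normally the script's argparse
# namespace; a stand-in with the two fields train_maf reads is built here).
import argparse
args = argparse.Namespace(num_steps=20000, clip=10.0)
X = torch.randn(10000, 2).cuda()   # any 2-d dataset
final_ll = train_maf(X)
print('final average ll: {:.4f}'.format(final_ll.item()))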
class FilteringNetwork(nn.Module):
    def __init__(self, num_filters=32, dim_lats=128, dim_hids=256,
                 dim_context=256, num_inds=32):
        super().__init__()
        C = num_filters
        self.enc = nn.Sequential(
            nn.Conv2d(3, C, 3, stride=2), nn.BatchNorm2d(C), nn.ReLU(),
            nn.Conv2d(C, 2 * C, 3, stride=2), nn.BatchNorm2d(2 * C), nn.ReLU(),
            nn.Conv2d(2 * C, 4 * C, 3),
            Flatten())
        self.isab1 = StackedISAB(4 * C * 4 * 4, dim_hids, num_inds, 4)
        self.pma = PMA(dim_hids, dim_hids, 1)
        self.fc1 = nn.Linear(dim_hids, dim_context)
        self.posterior = Normal(
            dim_lats, use_context=True,
            context_enc=nn.Linear(4 * C * 4 * 4 + dim_context, 2 * dim_lats))
        self.prior = FlowDistribution(
            MAF(dim_lats, dim_hids, 6, dim_context=dim_context, inv_linear=True),
            Normal(dim_lats))
        self.dec = nn.Sequential(
            nn.Linear(dim_lats + dim_context, 4 * C * 4 * 4), nn.ReLU(),
            View(-1, 4 * C, 4, 4),
            nn.ConvTranspose2d(4 * C, 2 * C, 3, stride=2, padding=1),
            nn.BatchNorm2d(2 * C), nn.ReLU(),
            nn.ConvTranspose2d(2 * C, C, 3, stride=2, padding=1),
            nn.BatchNorm2d(C), nn.ReLU(),
            nn.ConvTranspose2d(C, 3, 3, stride=2, output_padding=1),
            View(-1, 3, 28, 28))
        self.likel = Bernoulli((3, 28, 28), use_context=True)
        self.mab = MAB(dim_hids, dim_hids, dim_hids)
        self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
        self.fc2 = nn.Linear(dim_hids, 1)

    def forward(self, X, mask=None, return_z=False):
        B, N, C, H, W = X.shape
        x = X.view(B * N, C, H, W)
        h_enc = self.enc(x)
        H_X = self.isab1(h_enc.view(B, N, -1), mask=mask)

        # pooled cluster parameters, broadcast to every point
        H_theta = self.pma(H_X, mask=mask)
        theta = self.fc1(H_theta)
        theta_ = theta.repeat(1, N, 1).view(B * N, -1)

        # per-point ELBO under the conditional VAE
        z, logq = self.posterior.sample(context=torch.cat([h_enc, theta_], -1))
        logp = self.prior.log_prob(z, context=theta_)
        kld = (logq - logp).view(B, N)
        h_dec = self.dec(torch.cat([z, theta_], -1))
        ll = self.likel.log_prob(x, context=h_dec).view(B, N) - kld
        ll /= H * W

        # membership logits
        H_dec = self.mab(H_X, H_theta)
        logits = self.fc2(self.isab2(H_dec, mask=mask)).squeeze(-1)

        outputs = {'ll': ll, 'theta': theta, 'logits': logits}
        if return_z:
            outputs['z'] = z
        return outputs
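# Shape sketch (assumption: 3x28x28 inputs with values in [0, 1], matching the
# Bernoulli likelihood; batch and set sizes are illustrative).
model = FilteringNetwork()
out = model(torch.rand(4, 50, 3, 28, 28), return_z=True)
# out['ll']:     (4, 50)       per-point ELBO, normalized by H * W
# out['theta']:  (4, 1, 256)   cluster context
# out['logits']: (4, 50)       membership logits
# out['z']:      (4 * 50, 128) sampled latents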
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # to 32 * 32 * 32
            ConvResUnit(3, 32, stride=2, weight_norm=True),
            ConvResUnit(32, 32, weight_norm=True),
            # to 64 * 16 * 16
            ConvResUnit(32, 64, stride=2, weight_norm=True),
            ConvResUnit(64, 64, weight_norm=True),
            # to 64 * 8 * 8
            ConvResUnit(64, 64, stride=2, weight_norm=True),
            ConvResUnit(64, 64, weight_norm=True),
            # to (64 + 64) * 4 * 4
            ConvResUnit(64, 64, stride=2, weight_norm=True),
            ConvResUnit(64, 128, weight_norm=True))
        self.posterior = Normal((64, 4, 4), use_context=True)
        # self.posterior = SimpleFlowDist((64, 4, 4), 4, use_context=True)
        self.prior = SimpleFlowDist((64, 4, 4), 4, use_context=False)
        self.decoder = nn.Sequential(
            # to 64 * 8 * 8
            DeconvResUnit(64, 64, weight_norm=True),
            DeconvResUnit(64, 64, stride=2, weight_norm=True),
            # to 64 * 16 * 16
            DeconvResUnit(64, 64, weight_norm=True),
            DeconvResUnit(64, 64, stride=2, weight_norm=True),
            # to 32 * 32 * 32
            DeconvResUnit(64, 64, weight_norm=True),
            DeconvResUnit(64, 32, stride=2, weight_norm=True),
            # to (3 + 3) * 64 * 64
            DeconvResUnit(32, 32, weight_norm=True),
            DeconvResUnit(32, 6, stride=2))
        self.likel = FlowDistribution(
            Dequantize(),
            Normal((3, 64, 64), use_context=True))

    def forward(self, x):
        _, C, H, W = x.shape
        h_enc = self.encoder(x)
        z, log_q = self.posterior.sample(context=h_enc)
        log_p = self.prior.log_prob(z)
        kld = (log_q - log_p).mean() / (C * H * W)
        h_dec = self.decoder(z)
        ll = self.likel.log_prob(x, context=h_dec).mean() / (C * H * W)
        return ll, kld

    def generate(self, num_samples, device='cpu'):
        z, _ = self.prior.sample(num_samples, device=device)
        h_dec = self.decoder(z)
        x, _ = self.likel.sample(context=h_dec)
        return x

    def reconstruct(self, x):
        h_enc = self.encoder(x)
        z, _ = self.posterior.mean(context=h_enc)
        h_dec = self.decoder(z)
        x, _ = self.likel.mean(context=h_dec)
        return x
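# A hedged single training step (assumptions: integer-valued 3x64x64 pixels to
# match the Dequantize transform, and an Adam optimizer chosen for illustration).
model = VAE().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
x = torch.randint(0, 256, (16, 3, 64, 64), dtype=torch.float32).cuda()
ll, kld = model(x)
loss = kld - ll   # negative per-pixel ELBO
optimizer.zero_grad()
loss.backward()
optimizer.step()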