def __init__(self, dim_lats, dim_hids=128, num_inds=32):
    """Assemble the encoder, set-attention stacks, latent model and decoder.

    Args:
        dim_lats: Size handed to the posterior ``Normal``, to the prior
            ``MAF``/``Normal`` pair, and added to ``dim_hids`` to form the
            decoder's input width.
        dim_hids: Hidden width shared by the linear, attention and flow
            layers.
        num_inds: Forwarded to both ``StackedISAB`` stacks (presumably the
            number of inducing points -- confirm against ``StackedISAB``).
    """
    super().__init__()

    flat_pixels = 784  # 1 * 28 * 28, matching the decoder's output view
    image_shape = (1, 28, 28)

    # Encoder: view inputs as (-1, 784), then three WN-wrapped linear
    # layers separated by ELU activations.
    encoder_layers = [
        View(-1, flat_pixels),
        WN(nn.Linear(flat_pixels, dim_hids)),
        nn.ELU(),
        WN(nn.Linear(dim_hids, dim_hids)),
        nn.ELU(),
        WN(nn.Linear(dim_hids, dim_hids)),
    ]
    self.encoder = nn.Sequential(*encoder_layers)

    # First attention stack, pooling and projection.
    self.isab1 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
    self.pma = PMA(dim_hids, dim_hids, 1)
    self.fc1 = nn.Linear(dim_hids, dim_hids)

    # Posterior: a Normal whose context encoder maps 2*dim_hids -> 2*dim_lats.
    posterior_context = nn.Linear(2 * dim_hids, 2 * dim_lats)
    self.posterior = Normal(
        dim_lats, use_context=True, context_enc=posterior_context)

    # Prior: a MAF flow (context width dim_hids) over a Normal base.
    prior_flow = MAF(
        dim_lats, dim_hids, 4, dim_context=dim_hids, inv_linear=True)
    prior_base = Normal(dim_lats)
    self.prior = FlowDistribution(prior_flow, prior_base)

    # Decoder: (dim_lats + dim_hids) -> 784, viewed as (-1, 1, 28, 28).
    decoder_layers = [
        WN(nn.Linear(dim_lats + dim_hids, dim_hids)),
        nn.ELU(),
        WN(nn.Linear(dim_hids, dim_hids)),
        nn.ELU(),
        WN(nn.Linear(dim_hids, flat_pixels)),
        View(-1, *image_shape),
    ]
    self.decoder = nn.Sequential(*decoder_layers)
    self.likel = Bernoulli(image_shape, use_context=True)

    # Second attention stack ending in a single-output linear head.
    self.mab = MAB(dim_hids, dim_hids, dim_hids)
    self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
    self.fc2 = nn.Linear(dim_hids, 1)
def __init__(self, mvn, dim_hids=128, num_inds=32):
    """Create the two attention stacks and their linear heads.

    Args:
        mvn: Object stored on ``self.mvn``; its ``dim`` attribute sets the
            input width of the first stack and its ``dim_params`` attribute
            sets the output width of ``fc1``.
        dim_hids: Hidden width used by every attention module.
        num_inds: Forwarded to both ``StackedISAB`` stacks (presumably the
            number of inducing points -- confirm against ``StackedISAB``).
    """
    super().__init__()
    self.mvn = mvn

    # Fourth positional argument of StackedISAB; kept identical for both
    # stacks (presumably the number of stacked blocks -- confirm).
    isab_arg = 4

    # First stack: mvn.dim -> dim_hids, pooled by PMA, projected to
    # mvn.dim_params outputs.
    self.isab1 = StackedISAB(mvn.dim, dim_hids, num_inds, isab_arg)
    self.pma = PMA(dim_hids, dim_hids, 1)
    self.fc1 = nn.Linear(dim_hids, mvn.dim_params)

    # Second stack: MAB followed by ISAB layers and a single-output head.
    self.mab = MAB(dim_hids, dim_hids, dim_hids)
    self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, isab_arg)
    self.fc2 = nn.Linear(dim_hids, 1)
def __init__(self, dim_hids=256, num_inds=32):
    """Build a context-conditioned flow plus two attention stacks.

    The input feature width is hard-coded to 640.

    Args:
        dim_hids: Hidden width of the flow and attention modules; also the
            flow's context width.
        num_inds: Forwarded to both ``StackedISAB`` stacks (presumably the
            number of inducing points -- confirm against ``StackedISAB``).
    """
    super().__init__()

    feat_dim = 640

    # Flow: MAF with a dim_hids-wide context over a context-free Normal base.
    maf = MAF(feat_dim, dim_hids, 4, dim_context=dim_hids, inv_linear=True)
    base = Normal(feat_dim, use_context=False)
    self.flow = FlowDistribution(maf, base)

    # First stack, built with ln=True and p=0.2, then pooled by PMA.
    self.isab1 = StackedISAB(feat_dim, dim_hids, num_inds, 4, ln=True, p=0.2)
    self.pma = PMA(dim_hids, dim_hids, 1)

    # fc1 starts with weights drawn uniformly from [-1e-4, 1e-4] and a
    # zero bias.
    context_head = nn.Linear(dim_hids, dim_hids)
    nn.init.uniform_(context_head.weight, a=-1e-4, b=1e-4)
    nn.init.constant_(context_head.bias, 0.0)
    self.fc1 = context_head

    # Second stack ending in a single-output linear head.
    self.mab = MAB(dim_hids, dim_hids, dim_hids)
    self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
    self.fc2 = nn.Linear(dim_hids, 1)
def forward(self, X, anchor_idxs):
    """Encode every item of each set, attend to one anchor item, and project.

    Args:
        X: Tensor whose shape unpacks into five dimensions ``(B, N, C, H, W)``.
        anchor_idxs: Index used together with ``torch.arange(B)`` to pick one
            encoded item per batch element (presumably a length-``B`` integer
            tensor -- TODO confirm against callers).

    Returns:
        The output of ``self.fc`` applied to ``self.isab`` of the
        anchor-attended encodings.
    """
    B, N, C, H, W = X.shape
    # Fold the set dimension into the batch for the encoder, then restore
    # the (B, N, features) layout.
    H_enc = self.encoder(X.view(B * N, C, H, W)).view(B, N, -1)
    # Pick one encoded item per batch row and add a singleton dim at axis 1.
    anchors = H_enc[torch.arange(B), anchor_idxs].unsqueeze(1)
    H_enc = self.mab(H_enc, anchors)
    return self.fc(self.isab(H_enc))
    # NOTE(review): the two statements below follow the return and use
    # names (dim_hids, num_inds) that are not defined in this method, so
    # they can never run here. They look like the tail of a constructor
    # that was fused onto this line -- confirm where they belong before
    # removing them.
    self.isab = StackedISAB(dim_hids, dim_hids, num_inds, 4)
    self.fc = nn.Linear(dim_hids, 1)
def __init__(self, mvn, dim_hids=128, num_inds=32):
    """Create the ISAB stack, aPMA, SAB stack and output head.

    Args:
        mvn: Object stored on ``self.mvn``; its ``dim`` attribute sets the
            input width of the ISAB stack and its ``dim_params`` attribute
            contributes to the output width of ``fc``.
        dim_hids: Hidden width used by every attention module.
        num_inds: Forwarded to ``StackedISAB`` (presumably the number of
            inducing points -- confirm against ``StackedISAB``).
    """
    super().__init__()
    self.mvn = mvn

    # Attention pipeline: ISAB stack -> aPMA -> SAB stack.
    self.isab = StackedISAB(mvn.dim, dim_hids, num_inds, 4)
    self.apma = aPMA(dim_hids, dim_hids)
    self.sab = StackedSAB(dim_hids, dim_hids, 2)

    # The head emits two extra outputs on top of mvn.dim_params.
    head_width = 2 + mvn.dim_params
    self.fc = nn.Linear(dim_hids, head_width)
def __init__(self, dim_hids=128, num_inds=32):
    """Build a strided residual encoder, an ISAB stack and a linear head.

    Args:
        dim_hids: Channel count of the last residual unit and hidden width
            of the attention stack.
        num_inds: Forwarded to ``StackedISAB`` (presumably the number of
            inducing points -- confirm against ``StackedISAB``).
    """
    super().__init__()

    # Channel progression 1 -> 16 -> 32 -> dim_hids; every unit is built
    # with stride=2, and the stack ends with adaptive average pooling to
    # output size 1.
    widths = (1, 16, 32, dim_hids)
    res_units = [
        FixupResUnit(c_in, c_out, stride=2)
        for c_in, c_out in zip(widths[:-1], widths[1:])
    ]
    self.encoder = nn.Sequential(*res_units, nn.AdaptiveAvgPool2d(1))

    self.isab = StackedISAB(dim_hids, dim_hids, num_inds, 4)
    self.fc = nn.Linear(dim_hids, 1)
def __init__(self, dim_inputs, dim_hids=128, num_inds=32,
             dim_context=128, num_blocks=4):
    """Build a context-conditioned flow plus two attention stacks.

    Args:
        dim_inputs: Input width of the flow, its Normal base and the first
            ISAB stack.
        dim_hids: Hidden width of the flow and attention modules.
        num_inds: Forwarded to both ``StackedISAB`` stacks (presumably the
            number of inducing points -- confirm against ``StackedISAB``).
        dim_context: Context width of the flow and output width of ``fc1``.
        num_blocks: Third positional argument of ``MAF``.
    """
    super().__init__()

    # Flow: MAF over a context-free Normal base.
    maf = MAF(dim_inputs, dim_hids, num_blocks, dim_context=dim_context)
    base = Normal(dim_inputs, use_context=False)
    self.flow = FlowDistribution(maf, base)

    # First stack: dim_inputs -> dim_hids, pooled by PMA, projected to
    # dim_context outputs.
    self.isab1 = StackedISAB(dim_inputs, dim_hids, num_inds, 4)
    self.pma = PMA(dim_hids, dim_hids, 1)
    self.fc1 = nn.Linear(dim_hids, dim_context)

    # Second stack ending in a single-output linear head.
    self.mab = MAB(dim_hids, dim_hids, dim_hids)
    self.isab2 = StackedISAB(dim_hids, dim_hids, num_inds, 4)
    self.fc2 = nn.Linear(dim_hids, 1)
def __init__(self, dim_hids=128, num_inds=32):
    """Build a max-pooled residual encoder, an ISAB stack and a linear head.

    Args:
        dim_hids: Channel count of the last residual unit and hidden width
            of the attention stack.
        num_inds: Forwarded to ``StackedISAB`` (presumably the number of
            inducing points -- confirm against ``StackedISAB``).
    """
    super().__init__()

    # Residual units 1 -> 32 -> 64 -> dim_hids with a MaxPool2d(2) after
    # each of the first two, ending with adaptive average pooling to
    # output size 1.
    encoder_layers = [
        FixupResUnit(1, 32),
        nn.MaxPool2d(2),
        FixupResUnit(32, 64),
        nn.MaxPool2d(2),
        FixupResUnit(64, dim_hids),
        nn.AdaptiveAvgPool2d(1),
    ]
    self.encoder = nn.Sequential(*encoder_layers)

    # ISAB stack built with p=0.3, followed by a single-output head.
    self.isab = StackedISAB(dim_hids, dim_hids, num_inds, 4, p=0.3)
    self.fc = nn.Linear(dim_hids, 1)
def __init__(self, dim_hids=256, num_inds=32):
    """Build a six-unit residual encoder, a MAB, an ISAB stack and a head.

    Args:
        dim_hids: Channel count of the last residual unit and hidden width
            of the attention modules.
        num_inds: Forwarded to ``StackedISAB`` (presumably the number of
            inducing points -- confirm against ``StackedISAB``).
    """
    super().__init__()

    # (in_channels, out_channels, extra kwargs) per residual unit. Only the
    # units that change the channel count up to 128 are given stride=2; the
    # others rely on FixupResUnit's own default.
    unit_specs = (
        (3, 32, {"stride": 2}),
        (32, 32, {}),
        (32, 64, {"stride": 2}),
        (64, 64, {}),
        (64, 128, {"stride": 2}),
        (128, dim_hids, {}),
    )
    res_units = [
        FixupResUnit(c_in, c_out, **extra)
        for c_in, c_out, extra in unit_specs
    ]
    self.encoder = nn.Sequential(*res_units, nn.AdaptiveAvgPool2d(1))

    self.mab = MAB(dim_hids, dim_hids, dim_hids)
    self.isab = StackedISAB(dim_hids, dim_hids, num_inds, 4)
    self.fc = nn.Linear(dim_hids, 1)
def __init__(self, dim_hids=256, num_inds=32):
    """Build an ISAB stack over 640-wide inputs and a single-output head.

    Args:
        dim_hids: Hidden width of the attention stack and input width of
            the linear head.
        num_inds: Forwarded to ``StackedISAB`` (presumably the number of
            inducing points -- confirm against ``StackedISAB``).
    """
    super().__init__()

    feat_dim = 640  # hard-coded input feature width

    # Stack built with 6 as its fourth positional argument, p=0.2 and
    # ln=True.
    self.isab = StackedISAB(feat_dim, dim_hids, num_inds, 6, p=0.2, ln=True)
    self.fc = nn.Linear(dim_hids, 1)