def __init__(
    self,
    img_ch,
    n_ctx,
    n_hid=64,
    n_z=10,
    enc_dim=512,
    share_prior_enc=False,
    reverse_post=False,
):
    """Assemble the embedding, rendering and latent-branch sub-networks.

    Args:
        img_ch: Channel count given to the first embedding conv and to the
            last rendering conv (as its output width).
        n_ctx: Multiplier applied to the input width of every
            ``*_init_net`` / ``*_init_nets`` head (presumably the number
            of context frames -- confirm against the caller).
        n_hid: Base channel width; later stages use 2x, 4x and 8x of it.
        n_z: Accepted but not read by this constructor.
        enc_dim: Stored on the instance; here it only sizes the norm at
            the end of ``det_init_net``.
        share_prior_enc: Accepted but not read by this constructor.
        reverse_post: Accepted but not read by this constructor.
    """
    super().__init__()
    self.n_ctx = n_ctx
    self.enc_dim = enc_dim

    # Channel widths reused throughout the constructor.
    w1, w2, w4, w8 = n_hid * 1, n_hid * 2, n_hid * 4, n_hid * 8

    def init_head(width, groups, dc_args=(1,)):
        # DcConv -> TemporalConv2d -> TemporalNorm2d stack whose first two
        # layers take `width * n_ctx` channels and whose last two are
        # built with `2 * width`.
        stacked = width * self.n_ctx
        return nn.Sequential(
            layers.DcConv(stacked, stacked, *dc_args),
            layers.TemporalConv2d(stacked, width * 2, 1),
            layers.TemporalNorm2d(groups, width * 2),
        )

    def latent_branch(width, groups, extra_in=0):
        # Five-layer branch; `extra_in` widens the input of the second
        # TemporalConv2d beyond `width`.
        return nn.ModuleList([
            layers.TemporalConv2d(width, width, 1),
            layers.TemporalNorm2d(groups, width),
            layers.ConvLSTM(width, width, norm=True),
            layers.TemporalConv2d(width + extra_in, width * 2, 1),
            layers.TemporalNorm2d(groups, width * 2),
        ])

    self.emb_net = nn.ModuleList([
        nn.Conv2d(img_ch, w1, 1),
        ResnetBlock(w1, w1),
        nn.MaxPool2d(2, 2),
        ResnetBlock(w1, w2),
        ResnetBlock(w2, w2),
        nn.MaxPool2d(2, 2),
        ResnetBlock(w2, w4),
        ResnetBlock(w4, w4),
        nn.MaxPool2d(2, 2),
        ResnetBlock(w4, w4),
        ResnetBlock(w4, w8),
        nn.MaxPool2d(2, 2),
        ResnetBlock(w8, w8),
        ResnetBlock(w8, w8),
        nn.MaxPool2d(4, 1),
        ResnetBlock(w8, w8, norm_ch=1),
        ResnetBlock(w8, w8, norm_ch=1),
    ])

    # ConvLSTM entries sit at the even indices; the ones at 0, 4 and 8
    # are built with a doubled input width.
    self.render_net = nn.ModuleList([
        layers.ConvLSTM(w8 + w8, w8),
        layers.DcUpConv(w8, w8, 4, 1, 0),
        layers.ConvLSTM(w8, w8, norm=True),
        layers.DcUpConv(w8, w8, 4, 2, 1),
        layers.ConvLSTM(w8 + w8, w8, norm=True),
        layers.DcUpConv(w8, w4, 4, 2, 1),
        layers.ConvLSTM(w4, w4, norm=True),
        layers.DcUpConv(w4, w2, 4, 2, 1),
        layers.ConvLSTM(w2 + w2, w2, norm=True),
        layers.DcUpConv(w2, w1, 4, 2, 1),
        layers.ConvLSTM(w1, w1, norm=True),
        layers.DcConv(w1, w1, 3, 1, 1),
        layers.TemporalConv2d(w1, img_ch, 3, 1, 1),
    ])

    # NOTE(review): the norm below is sized with 2 * enc_dim while the
    # conv before it is built with 2 * n_hid * 8 output channels; the two
    # agree only when enc_dim == 8 * n_hid (true for the defaults).
    # Confirm this is intended before using non-default sizes.
    det_in = 2 * w8 * self.n_ctx
    self.det_init_net = nn.Sequential(
        layers.DcConv(det_in, det_in, 1),
        layers.TemporalConv2d(det_in, 2 * w8, 1),
        layers.TemporalNorm2d(1, 2 * enc_dim),
    )

    # (key, width, norm groups) for the prior / posterior init heads.
    latent_init_spec = (
        ('layer_16', w8, 1),
        ('layer_10', w8, 16),
        ('layer_4', w2, 16),
    )
    self.prior_init_nets = nn.ModuleDict({
        key: init_head(width, groups)
        for key, width, groups in latent_init_spec
    })
    self.posterior_init_nets = nn.ModuleDict({
        key: init_head(width, groups)
        for key, width, groups in latent_init_spec
    })

    # (key, width, norm groups, extra input channels) per latent branch.
    branch_spec = (
        ('layer_4', w2, 16, w8 + w8),
        ('layer_10', w8, 16, w8),
        ('layer_16', w8, 1, 0),
    )
    self.posterior_branches = nn.ModuleDict({
        key: latent_branch(width, groups, extra)
        for key, width, groups, extra in branch_spec
    })
    self.prior_branches = nn.ModuleDict({
        key: latent_branch(width, groups, extra)
        for key, width, groups, extra in branch_spec
    })

    # Prior/posterior branches norm init: the last norm of every branch
    # gets an all-zero weight and a bias drawn with std 1e-3.
    for branches in (self.posterior_branches, self.prior_branches):
        for key in ('layer_4', 'layer_10', 'layer_16'):
            last_norm = branches[key][-1].model
            nn.init.constant_(last_norm.weight, 0)
            nn.init.normal_(last_norm.bias, std=1e-3)

    # Connection list: keys are the even `render_net` indices (its
    # ConvLSTM entries); values match the `layer_<n>` suffixes used as
    # keys of `det_init_nets` below.
    self.det_init_connections = {
        0: 16,
        2: 13,
        4: 10,
        6: 7,
        8: 4,
        10: 1,
    }

    # Connection branches: (key, width, norm groups, DcConv extra args).
    det_spec = (
        ('layer_16', w8, 1, (1,)),
        ('layer_13', w8, 16, (3, 1, 1)),
        ('layer_10', w8, 16, (1,)),
        ('layer_7', w4, 16, (1,)),
        ('layer_4', w2, 16, (1,)),
        ('layer_1', w1, 16, (1,)),
    )
    self.det_init_nets = nn.ModuleDict({
        key: init_head(width, groups, dc_args)
        for key, width, groups, dc_args in det_spec
    })

    # Stochastic connection list
    # encoder -> renderer
    self.sto_branches = {16: 0, 10: 4, 4: 8}
    # renderer -> encoder
    self.rend_sto_branches = {0: 0, 4: 1, 8: 2}
def __init__(
    self,
    img_ch,
    n_ctx,
    n_hid=64,
    n_z=10,
    enc_dim=512,
    share_prior_enc=False,
    reverse_post=False,
):
    """Assemble the two embedding nets, the renderer and the latent branch.

    Args:
        img_ch: Channel count given to the first conv of both embedding
            nets and to the last rendering conv (as its output width).
        n_ctx: Multiplier applied to the input width of every
            ``*_init_net`` / ``*_init_nets`` head (presumably the number
            of context frames -- confirm against the caller).
        n_hid: Base channel width; later stages use 2x, 4x and 8x of it.
        n_z: Width used by the prior/posterior branches; the first
            renderer ConvLSTM is built with ``n_z + enc_dim`` inputs.
        enc_dim: Output width of the last ``sto_emb_net`` layer and
            hidden width of the first renderer ConvLSTM.
        share_prior_enc: Accepted but not read by this constructor.
        reverse_post: Accepted but not read by this constructor.
    """
    super().__init__()
    self.n_ctx = n_ctx
    self.enc_dim = enc_dim

    # Channel widths reused throughout the constructor.
    w1, w2, w4, w8 = n_hid * 1, n_hid * 2, n_hid * 4, n_hid * 8

    def init_head(width, groups, dc_args=(1,)):
        # DcConv -> TemporalConv2d -> TemporalNorm2d stack whose first two
        # layers take `width * n_ctx` channels and whose last two are
        # built with `2 * width`.
        stacked = width * self.n_ctx
        return nn.Sequential(
            layers.DcConv(stacked, stacked, *dc_args),
            layers.TemporalConv2d(stacked, width * 2, 1),
            layers.TemporalNorm2d(groups, width * 2),
        )

    def latent_branch():
        # Four-layer branch going enc_dim -> n_z -> enc_dim -> 2 * n_z
        # (as passed to the layer constructors).
        return nn.ModuleList([
            layers.TemporalConv2d(enc_dim, n_z, 1),
            layers.TemporalNorm2d(1, n_z),
            layers.ConvLSTM(n_z, enc_dim, norm=True),
            layers.TemporalConv2d(enc_dim, n_z * 2, 1),
        ])

    self.sto_emb_net = nn.ModuleList([
        layers.DcConv(img_ch, w1, 4, 2, 1),
        layers.DcConv(w1, w2, 4, 2, 1),
        layers.DcConv(w2, w4, 4, 2, 1),
        layers.DcConv(w4, w8, 4, 2, 1),
        layers.DcConv(w8, enc_dim, 4, 1, 0, norm=partial(nn.GroupNorm, 1)),
    ])

    self.det_emb_net = nn.ModuleList([
        nn.Conv2d(img_ch, w1, 1),
        ResnetBlock(w1, w1),
        nn.MaxPool2d(2, 2),
        ResnetBlock(w1, w2),
        ResnetBlock(w2, w2),
        nn.MaxPool2d(2, 2),
        ResnetBlock(w2, w4),
        ResnetBlock(w4, w4),
        nn.MaxPool2d(2, 2),
        ResnetBlock(w4, w4),
        ResnetBlock(w4, w8),
        nn.MaxPool2d(2, 2),
        ResnetBlock(w8, w8),
        ResnetBlock(w8, w8),
        nn.MaxPool2d(4, 1),
        ResnetBlock(w8, w8, norm_ch=1),
        ResnetBlock(w8, w8, norm_ch=1),
    ])

    # ConvLSTM entries sit at the even indices; only the first one is
    # built with an enlarged (n_z + enc_dim) input width.
    self.render_net = nn.ModuleList([
        layers.ConvLSTM(n_z + enc_dim, enc_dim),
        layers.DcUpConv(enc_dim, w8, 4, 1, 0),
        layers.ConvLSTM(w8, w8, norm=True),
        layers.DcUpConv(w8, w8, 4, 2, 1),
        layers.ConvLSTM(w8, w8, norm=True),
        layers.DcUpConv(w8, w4, 4, 2, 1),
        layers.ConvLSTM(w4, w4, norm=True),
        layers.DcUpConv(w4, w2, 4, 2, 1),
        layers.ConvLSTM(w2, w2, norm=True),
        layers.DcUpConv(w2, w1, 4, 2, 1),
        layers.ConvLSTM(w1, w1, norm=True),
        layers.DcConv(w1, w1, 3, 1, 1),
        layers.TemporalConv2d(w1, img_ch, 3, 1, 1),
    ])

    det_in = 2 * enc_dim * self.n_ctx
    self.det_init_net = nn.Sequential(
        layers.DcConv(det_in, det_in, 1),
        layers.TemporalConv2d(det_in, 2 * enc_dim, 1),
        layers.TemporalNorm2d(1, 2 * enc_dim),
    )

    # Single latent level, keyed 'layer_4', for both prior and posterior.
    self.prior_init_nets = nn.ModuleDict({
        'layer_4': init_head(enc_dim, 1),
    })
    self.posterior_init_nets = nn.ModuleDict({
        'layer_4': init_head(enc_dim, 1),
    })

    self.posterior_branches = nn.ModuleDict({
        'layer_4': latent_branch(),
    })
    self.prior_branches = nn.ModuleDict({
        'layer_4': latent_branch(),
    })

    # Connection list: keys are the even `render_net` indices (its
    # ConvLSTM entries); values match the `layer_<n>` suffixes used as
    # keys of `det_init_nets` below.
    self.det_init_connections = {
        0: 16,
        2: 13,
        4: 10,
        6: 7,
        8: 4,
        10: 1,
    }

    # Connection branches: (key, width, norm groups, DcConv extra args).
    det_spec = (
        ('layer_16', w8, 1, (1,)),
        ('layer_13', w8, 16, (3, 1, 1)),
        ('layer_10', w8, 16, (1,)),
        ('layer_7', w4, 16, (1,)),
        ('layer_4', w2, 16, (1,)),
        ('layer_1', w1, 16, (1,)),
    )
    self.det_init_nets = nn.ModuleDict({
        key: init_head(width, groups, dc_args)
        for key, width, groups, dc_args in det_spec
    })

    # Stochastic connection list
    # encoder -> renderer
    self.sto_branches = {4: 0}
    # renderer -> encoder
    self.rend_sto_branches = {0: 0}