def __init__(self, mpi3d_toy):
        """Assemble the multi-branch convolutional classifier.

        A shared conv trunk (``path``) feeds three feature branches
        (``path_a``, ``path_b``, ``path_c``), each ending in a 1024-wide
        vector. Linear/MLP output heads, one per factor, are registered after
        them.

        Args:
            mpi3d_toy: object exposing a ``factors`` mapping; each head's output
                width is ``mpi3d_toy.factors[<name>]`` (presumably the number of
                values that factor takes -- confirm against the dataset class).
        """
        super(Net, self).__init__()
        self.mpi3d_toy = mpi3d_toy
        n_out = mpi3d_toy.factors

        # Shape comments are per-sample C x H x W.
        # Shared trunk.
        self.path = nn.Sequential(
            U.conv2d(3, 16, 6, 4, True, 64),
            nn.ReLU(),  # 16 x 16 x 16
            U.conv2d(16, 64, 4, 2, True, 16),
            nn.ReLU(),  # 64 x 8 x 8
        )
        # Branch A: 1x1 conv down to 16 channels, then flatten 16*8*8 -> 1024.
        self.path_a = nn.Sequential(
            U.conv2d(64, 16, 1, 1, True, 8),
            nn.ReLU(),  # 16 x 8 x 8
            U.Lambda(lambda x: x.reshape(-1, 1024)),  # 1024
        )
        # Branch B: two more convs, then flatten 64*4*4 -> 1024.
        self.path_b = nn.Sequential(
            U.conv2d(64, 64, 4, 2, True, 8),
            nn.ReLU(),  # 64 x 4 x 4
            U.conv2d(64, 64, 4, 1, True, 4),
            nn.ReLU(),  # 64 x 4 x 4
            U.Lambda(lambda x: x.reshape(-1, 1024)),  # 1024
        )
        # Branch C: 1x1 conv to 256 channels, global max-pool, project to 1024.
        self.path_c = nn.Sequential(
            U.conv2d(64, 256, 1, 1, True, 8),
            nn.ReLU(),  # 256 x 8 x 8
            nn.AdaptiveMaxPool2d((1, 1)),
            U.Lambda(lambda x: x.reshape(-1, 256)),  # 256
            nn.Linear(256, 1024),
            nn.ReLU(),  # 1024
        )

        # color / shape / size heads take 256-wide input, matching fc1's output.
        self.fc1 = nn.Sequential(nn.Linear(1024, 256), nn.ReLU())
        self.output_color = nn.Sequential(
            nn.Linear(256, n_out['color']),
        )
        self.output_shape = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, n_out['shape']),
        )
        self.output_size = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, n_out['size']),
        )

        # horizontal / vertical heads take 512-wide input, matching fc2's output.
        self.fc2 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
        self.output_horizontal = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, n_out['horizontal']),
        )
        self.output_vertical = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, n_out['vertical']),
        )

        # camera / background heads take the 1024-wide branch width directly.
        self.output_camera = nn.Sequential(
            nn.Linear(1024, n_out['camera']),
        )
        self.output_background = nn.Sequential(
            nn.Linear(1024, n_out['background']),
        )
Пример #2
0
 def __init__(self,
              n_channels,
              n_channels_sm,
              n_branches,
              ksize,
              fmap_size,
              switch_ksize,
              switch_stride,
              use_batchnorm=False):
     """Build ``n_branches`` conv branches and a branch-switching head.

     Each branch maps ``n_channels`` -> ``n_channels_sm`` -> ``n_channels``
     with two ``ksize`` stride-1 convs. There is a ReLU between the convs and,
     when ``use_batchnorm`` is set, a BatchNorm2d after each conv. The switch
     head is a single-channel conv, a ReLU, a flatten to
     ``(fmap_size // switch_stride) ** 2`` values, and a linear layer with
     ``3 * n_branches`` outputs.

     Args:
         n_channels: channels entering and leaving every branch.
         n_channels_sm: intermediate channel count inside a branch.
         n_branches: number of parallel branches stored in ``self.convs``.
         ksize: kernel size of the branch convs.
         fmap_size: spatial size passed to ``U.conv2d`` as ``in_h_or_w``.
         switch_ksize: kernel size of the switch conv.
         switch_stride: stride of the switch conv.
         use_batchnorm: insert ``BatchNorm2d`` after each branch conv.
     """
     super(EncSwitchedConv, self).__init__()
     self.n_branches = n_branches

     def make_branch():
         # Layers are created in the order conv, [BN], ReLU, conv, [BN], so the
         # Sequential indices match the two fixed layouts (5 or 3 modules).
         layers = [
             U.conv2d(n_channels, n_channels_sm, ksize, 1, in_h_or_w=fmap_size)
         ]
         if use_batchnorm:
             layers.append(nn.BatchNorm2d(n_channels_sm))
         layers.append(nn.ReLU())
         layers.append(
             U.conv2d(n_channels_sm, n_channels, ksize, 1, in_h_or_w=fmap_size))
         if use_batchnorm:
             layers.append(nn.BatchNorm2d(n_channels))
         return nn.Sequential(*layers)

     self.convs = nn.ModuleList([make_branch() for _ in range(n_branches)])

     # Number of spatial cells left after the strided switch conv.
     switch_cells = (fmap_size // switch_stride) ** 2
     self.switch = nn.Sequential(
         U.conv2d(n_channels, 1, switch_ksize, switch_stride,
                  in_h_or_w=fmap_size),
         nn.ReLU(),
         U.Lambda(lambda x: x.reshape(-1, switch_cells)),
         nn.Linear(switch_cells, 3 * n_branches),
     )
Пример #3
0
    def __init__(self, factors):
        """Build the factor-conditioned decoder (shape comments end at 3 x 64 x 64).

        Args:
            factors: mapping from factor name to the number of rows of that
                factor's embedding table (presumably the number of values the
                factor takes -- confirm).
        """
        super(Decoder, self).__init__()
        self.factors = factors

        # One learnable (factors[name], N_FACTOR_DIMS) table per factor. Names
        # are listed in a fixed order so parameters are created (and random
        # numbers drawn) in that order.
        factor_names = ('color', 'shape', 'size', 'camera', 'background',
                        'horizontal', 'vertical')
        self.factor_embeds = nn.ParameterDict({
            name: nn.Parameter(torch.randn(self.factors[name], N_FACTOR_DIMS))
            for name in factor_names
        })

        n_dims = N_EMBED_DIMS + N_FACTOR_DIMS

        # color / shape / size input projections and their 512 -> 1024 stage.
        self.input_color = nn.Linear(n_dims, 512)
        self.input_shape = nn.Linear(n_dims, 512)
        self.input_size = nn.Linear(n_dims, 512)
        self.path_col_shp_siz = nn.Sequential(
            nn.ReLU(),
            nn.Linear(512, 1024),
        )

        # horizontal / vertical input projections and their 512 -> 1024 stage.
        self.input_horizontal = nn.Linear(n_dims, 512)
        self.input_vertical = nn.Linear(n_dims, 512)
        self.path_hor_ver = nn.Sequential(
            nn.ReLU(),
            nn.Linear(512, 1024),
        )

        # camera / background each get a self-contained 2-layer MLP to 1024.
        self.input_camera = nn.Sequential(
            nn.Linear(n_dims, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
        )
        self.input_background = nn.Sequential(
            nn.Linear(n_dims, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
        )

        # First upsampling family (per-sample C x H x W in comments).
        self.path_shallow = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 1024),
            U.Lambda(lambda x: x.reshape(-1, 16, 8, 8)),  # 16 x 8 x 8
            nn.ReLU(),
            U.deconv2d(16, 64, 1, 1, True, 8),  # 64 x 8 x 8
        )
        self.path_deep = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 1024),
            U.Lambda(lambda x: x.reshape(-1, 64, 4, 4)),  # 64 x 4 x 4
            nn.ReLU(),
            U.deconv2d(64, 64, 4, 1, True, 4),  # 64 x 4 x 4
            nn.ReLU(),
            U.deconv2d(64, 64, 4, 2, True, 8),  # 64 x 8 x 8
        )
        self.path_base = nn.Sequential(
            nn.ReLU(),
            U.deconv2d(64, 16, 4, 2, True, 16),  # 16 x 16 x 16
            nn.ReLU(),
            U.deconv2d(16, 3, 6, 4, True, 64),  # 3 x 64 x 64
        )

        # Second upsampling family, built from 2x2 / stride-2 deconvs.
        self.path_shallow2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 1024),
            U.Lambda(lambda x: x.reshape(-1, 16, 8, 8)),  # 16 x 8 x 8
            nn.ReLU(),
            U.deconv2d(16, 64, 2, 2, True, 16),  # 64 x 16 x 16
        )
        self.path_deep2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 1024),
            U.Lambda(lambda x: x.reshape(-1, 64, 4, 4)),  # 64 x 4 x 4
            nn.ReLU(),
            U.deconv2d(64, 64, 4, 2, True, 8),  # 64 x 8 x 8
            nn.ReLU(),
            U.deconv2d(64, 64, 4, 2, True, 16),  # 64 x 16 x 16
        )
        self.path_base2 = nn.Sequential(
            nn.ReLU(),
            U.deconv2d(64, 16, 2, 2, True, 32),  # 16 x 32 x 32
            nn.ReLU(),
            U.deconv2d(16, 3, 2, 2, True, 64),  # 3 x 64 x 64
        )
Пример #4
0
    def __init__(self, factors):
        """Build the two-trunk encoder that emits per-factor codes and embeddings.

        Two conv trunks (``path_base`` and ``path_base2``) each feed shallow,
        deep and pooled branches ending in 1024-wide features. Factor groups
        get MLP stacks, and each factor gets a head with ``code_logits``,
        ``embed_mu`` and ``embed_logvar`` linear outputs. The ``weight_*``
        attributes are plain Python lists of constants, not learnable
        parameters. Presumably they weight the branches in ``forward()``;
        confirm there.

        Args:
            factors: mapping from factor name to the width of that factor's
                ``code_logits`` layer (presumably its number of values).
        """
        super(Encoder, self).__init__()
        self.factors = factors

        def make_head(in_features, factor):
            # code_logits is factors[factor] wide; embed_mu / embed_logvar are
            # N_EMBED_DIMS wide. Keys and creation order match the original.
            return nn.ModuleDict({'code_logits': nn.Linear(in_features, self.factors[factor]),
                                  'embed_mu': nn.Linear(in_features, N_EMBED_DIMS),
                                  'embed_logvar': nn.Linear(in_features, N_EMBED_DIMS)})

        # Trunk 1 and its branches (comments give per-sample C x H x W).
        self.path_base = nn.Sequential(U.conv2d(3, 16, 6, 4, True, 64), nn.ReLU(),  # 16 x 16 x 16
                                       U.conv2d(16, 64, 4, 2, True, 16), nn.ReLU())  # 64 x 8 x 8
        self.path_shallow = nn.Sequential(U.conv2d(64, 16, 1, 1, True, 8), nn.ReLU(),  # 16 x 8 x 8
                                          U.Lambda(lambda x: x.reshape(-1, 1024)),  # 1024
                                          nn.Linear(1024, 1024), nn.ReLU())
        self.path_deep = nn.Sequential(U.conv2d(64, 64, 4, 2, True, 8), nn.ReLU(),  # 64 x 4 x 4
                                       U.conv2d(64, 64, 4, 1, True, 4), nn.ReLU(),  # 64 x 4 x 4
                                       U.Lambda(lambda x: x.reshape(-1, 1024)),  # 1024
                                       nn.Linear(1024, 1024), nn.ReLU())
        self.path_pool = nn.Sequential(U.conv2d(64, 256, 1, 1, True, 8), nn.ReLU(),  # 256 x 8 x 8
                                       nn.AdaptiveMaxPool2d((1, 1)), U.Lambda(lambda x: x.reshape(-1, 256)),  # 256
                                       nn.Linear(256, 1024), nn.ReLU())  # 1024

        # Trunk 2 (2x2 / stride-2 convs) and its branches.
        self.path_base2 = nn.Sequential(U.conv2d(3, 16, 2, 2, True, 64), nn.ReLU(),  # 16 x 32 x 32
                                        # Every other U.conv2d call in this method passes its
                                        # *input* H/W last; this layer's input is 32 x 32, not 16.
                                        U.conv2d(16, 64, 2, 2, True, 32), nn.ReLU())  # 64 x 16 x 16
        self.path_shallow2 = nn.Sequential(nn.MaxPool2d(2, 2),  # 64 x 8 x 8
                                           U.conv2d(64, 16, 1, 1, True, 8), nn.ReLU(),  # 16 x 8 x 8
                                           U.Lambda(lambda x: x.reshape(-1, 1024)),
                                           nn.Linear(1024, 1024), nn.ReLU())  # 1024
        self.path_deep2 = nn.Sequential(U.conv2d(64, 64, 4, 2, True, 16), nn.ReLU(),  # 64 x 8 x 8
                                        U.conv2d(64, 64, 4, 2, True, 8), nn.ReLU(),  # 64 x 4 x 4
                                        U.Lambda(lambda x: x.reshape(-1, 1024)),
                                        nn.Linear(1024, 1024), nn.ReLU())  # 1024
        self.path_pool2 = nn.Sequential(U.conv2d(64, 256, 2, 2, True, 16), nn.ReLU(),  # 256 x 8 x 8
                                        nn.AdaptiveMaxPool2d((1, 1)), U.Lambda(lambda x: x.reshape(-1, 256)),  # 256
                                        nn.Linear(256, 1024), nn.ReLU())  # 1024

        # color / shape / size group. Six weights, one per 1024-wide branch above
        # (presumably; confirm in forward()).
        self.weight_col_shp_siz = [0.5, 0.0, 1.5, 1.0, 0.0, 2.0]
        self.path_col_shp_siz_shallow = nn.Sequential(nn.Linear(1024, 256), nn.ReLU())
        self.path_col_shp_siz_deep = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                                                   nn.Linear(512, 256), nn.ReLU())
        self.path_col_shp_siz_deeper = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                                                     nn.Linear(512, 256), nn.ReLU(),
                                                     nn.Linear(256, 256), nn.ReLU())

        # Three weights each, matching the shallow / deep / deeper stacks above.
        self.weight_col = [1.0, 0.0, 0.5]
        self.weight_shp = [0.0, 0.5, 1.0]
        self.weight_siz = [0.0, 1.0, 0.5]
        self.output_color = make_head(256, 'color')
        self.output_shape = make_head(256, 'shape')
        self.output_size = make_head(256, 'size')

        # horizontal / vertical group.
        self.weight_hor_ver = [1.0, 0.0, 0.0, 2.0, 1.5, 0.5]
        self.path_hor_ver_shallow = nn.Sequential(nn.Linear(1024, 256), nn.ReLU())
        self.path_hor_ver_deep = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                                               nn.Linear(512, 256), nn.ReLU())
        self.path_hor_ver_deeper = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                                                 nn.Linear(512, 256), nn.ReLU(),
                                                 nn.Linear(256, 256), nn.ReLU())

        self.weight_hor = [0.5, 0.0, 1.0]
        self.weight_ver = [0.5, 1.0, 0.0]
        self.output_horizontal = make_head(256, 'horizontal')
        self.output_vertical = make_head(256, 'vertical')

        # camera / background heads take 1024-wide input (no group MLP).
        self.weight_cam = [0.0, 0.5, 0.0, 1.0, 1.5, 0.0]
        self.weight_bkg = [0.5, 1.5, 0.0, 0.0, 0.0, 1.0]
        self.output_camera = make_head(1024, 'camera')
        self.output_background = make_head(1024, 'background')
Пример #5
0
    def __init__(self, mpi3d_toy):
        """Build the two-trunk multi-branch classifier with learnable branch weights.

        Two conv trunks (``path_base`` and ``path_base2``) each feed shallow,
        deep and pooled branches ending in 1024-wide features. Factor-group
        MLP stacks and per-factor output heads sit on top. Every ``weight_*``
        attribute is an ``nn.Parameter`` initialised to ones.

        Args:
            mpi3d_toy: object exposing a ``factors`` mapping; each head's output
                width is ``mpi3d_toy.factors[<name>]`` (presumably the number of
                values per factor -- confirm).
        """
        super(Net, self).__init__()
        self.mpi3d_toy = mpi3d_toy
        n_out = self.mpi3d_toy.factors

        # Trunk 1 and its branches (comments give per-sample C x H x W).
        self.path_base = nn.Sequential(
            U.conv2d(3, 16, 6, 4, True, 64),
            nn.ReLU(),  # 16 x 16 x 16
            U.conv2d(16, 64, 4, 2, True, 16),
            nn.ReLU())  # 64 x 8 x 8
        self.path_shallow = nn.Sequential(
            U.conv2d(64, 16, 1, 1, True, 8),
            nn.ReLU(),  # 16 x 8 x 8
            U.Lambda(lambda x: x.reshape(-1, 1024)),
            nn.Linear(1024, 1024),
            nn.ReLU())  # 1024
        self.path_deep = nn.Sequential(
            U.conv2d(64, 64, 4, 2, True, 8),
            nn.ReLU(),  # 64 x 4 x 4
            U.conv2d(64, 64, 4, 1, True, 4),
            nn.ReLU(),  # 64 x 4 x 4
            U.Lambda(lambda x: x.reshape(-1, 1024)),
            nn.Linear(1024, 1024),
            nn.ReLU())  # 1024
        self.path_pool = nn.Sequential(
            U.conv2d(64, 256, 1, 1, True, 8),
            nn.ReLU(),  # 256 x 8 x 8
            nn.AdaptiveMaxPool2d((1, 1)),
            U.Lambda(lambda x: x.reshape(-1, 256)),  # 256
            nn.Linear(256, 1024),
            nn.ReLU())  # 1024

        # Trunk 2 (2x2 / stride-2 convs) and its branches.
        self.path_base2 = nn.Sequential(
            U.conv2d(3, 16, 2, 2, True, 64),
            nn.ReLU(),  # 16 x 32 x 32
            # Every other U.conv2d call in this method passes its *input* H/W
            # last; this layer's input is 32 x 32, not 16.
            U.conv2d(16, 64, 2, 2, True, 32),
            nn.ReLU())  # 64 x 16 x 16
        self.path_shallow2 = nn.Sequential(
            nn.MaxPool2d(2, 2),  # 64 x 8 x 8
            U.conv2d(64, 16, 1, 1, True, 8),
            nn.ReLU(),  # 16 x 8 x 8
            U.Lambda(lambda x: x.reshape(-1, 1024)),
            nn.Linear(1024, 1024),
            nn.ReLU())  # 1024
        self.path_deep2 = nn.Sequential(
            U.conv2d(64, 64, 4, 2, True, 16),
            nn.ReLU(),  # 64 x 8 x 8
            U.conv2d(64, 64, 4, 2, True, 8),
            nn.ReLU(),  # 64 x 4 x 4
            U.Lambda(lambda x: x.reshape(-1, 1024)),
            nn.Linear(1024, 1024),
            nn.ReLU())  # 1024
        self.path_pool2 = nn.Sequential(
            U.conv2d(64, 256, 2, 2, True, 16),
            nn.ReLU(),  # 256 x 8 x 8
            nn.AdaptiveMaxPool2d((1, 1)),
            U.Lambda(lambda x: x.reshape(-1, 256)),  # 256
            nn.Linear(256, 1024),
            nn.ReLU())  # 1024

        # color / shape / size group. Six learnable weights, presumably one per
        # 1024-wide branch above (confirm in forward()).
        self.weight_col_shp_siz = nn.Parameter(torch.ones(6))
        self.path_col_shp_siz_shallow = nn.Sequential(nn.Linear(1024, 256),
                                                      nn.ReLU())
        self.path_col_shp_siz_deep = nn.Sequential(nn.Linear(1024, 512),
                                                   nn.ReLU(),
                                                   nn.Linear(512, 256),
                                                   nn.ReLU())
        self.path_col_shp_siz_deeper = nn.Sequential(nn.Linear(1024, 512),
                                                     nn.ReLU(),
                                                     nn.Linear(512, 256),
                                                     nn.ReLU(),
                                                     nn.Linear(256, 256),
                                                     nn.ReLU())

        # Three weights each, matching the shallow / deep / deeper stacks above.
        self.weight_col = nn.Parameter(torch.ones(3))
        self.weight_shp = nn.Parameter(torch.ones(3))
        self.weight_siz = nn.Parameter(torch.ones(3))
        self.output_color = nn.Sequential(
            nn.Linear(256, n_out['color']))
        self.output_shape = nn.Sequential(
            nn.Linear(256, n_out['shape']))
        self.output_size = nn.Sequential(
            nn.Linear(256, n_out['size']))

        # horizontal / vertical group.
        self.weight_hor_ver = nn.Parameter(torch.ones(6))
        self.path_hor_ver_shallow = nn.Sequential(nn.Linear(1024, 256),
                                                  nn.ReLU())
        self.path_hor_ver_deep = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                                               nn.Linear(512, 256), nn.ReLU())
        self.path_hor_ver_deeper = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                                                 nn.Linear(512, 256), nn.ReLU(),
                                                 nn.Linear(256, 256),
                                                 nn.ReLU())

        self.weight_hor = nn.Parameter(torch.ones(3))
        self.weight_ver = nn.Parameter(torch.ones(3))
        self.output_horizontal = nn.Sequential(
            nn.Linear(256, n_out['horizontal']))
        self.output_vertical = nn.Sequential(
            nn.Linear(256, n_out['vertical']))

        # camera / background heads: small MLPs on 1024-wide input.
        self.weight_cam = nn.Parameter(torch.ones(6))
        self.weight_bkg = nn.Parameter(torch.ones(6))
        self.output_camera = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU(),
            nn.Linear(256, n_out['camera']))
        self.output_background = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU(),
            nn.Linear(256, n_out['background']))
Пример #6
0
    def __init__(self):
        """Build the single-trunk encoder with hard-coded factor sizes.

        One conv trunk feeds shallow, deep and pooled branches, each
        1024-wide. Factor groups get MLP stacks, and each factor gets a head
        with ``code_logits``, ``embed_mu`` and ``embed_logvar`` linear
        outputs. The ``weight_*`` attributes are plain Python lists of
        constants, not learnable parameters.
        """
        super(Encoder, self).__init__()
        self.factors = {
            'color': 4,
            'shape': 4,
            'size': 2,
            'camera': 3,
            'background': 3,
            'horizontal': 40,
            'vertical': 40,
        }
        embed_dims = 32  # width of every embed_mu / embed_logvar projection

        def make_head(in_features, factor):
            # Keys and layer creation order are identical for every head.
            return nn.ModuleDict({
                'code_logits': nn.Linear(in_features, self.factors[factor]),
                'embed_mu': nn.Linear(in_features, embed_dims),
                'embed_logvar': nn.Linear(in_features, embed_dims),
            })

        # Conv trunk and branches (comments give per-sample C x H x W).
        self.path_base = nn.Sequential(
            U.conv2d(3, 16, 6, 4, True, 64), nn.ReLU(),  # 16 x 16 x 16
            U.conv2d(16, 64, 4, 2, True, 16), nn.ReLU())  # 64 x 8 x 8
        self.path_shallow = nn.Sequential(
            U.conv2d(64, 16, 1, 1, True, 8), nn.ReLU(),  # 16 x 8 x 8
            U.Lambda(lambda x: x.reshape(-1, 1024)),  # 1024
            nn.Linear(1024, 1024), nn.ReLU())
        self.path_deep = nn.Sequential(
            U.conv2d(64, 64, 4, 2, True, 8), nn.ReLU(),  # 64 x 4 x 4
            U.conv2d(64, 64, 4, 1, True, 4), nn.ReLU(),  # 64 x 4 x 4
            U.Lambda(lambda x: x.reshape(-1, 1024)),  # 1024
            nn.Linear(1024, 1024), nn.ReLU())
        self.path_pool = nn.Sequential(
            U.conv2d(64, 256, 1, 1, True, 8), nn.ReLU(),  # 256 x 8 x 8
            nn.AdaptiveMaxPool2d((1, 1)),
            U.Lambda(lambda x: x.reshape(-1, 256)),  # 256
            nn.Linear(256, 1024), nn.ReLU())  # 1024

        # color / shape / size group.
        self.weight_col_shp_siz = [0.95, 0.23, 1.33]
        self.path_col_shp_siz_shallow = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU())
        self.path_col_shp_siz_deep = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU())
        self.path_col_shp_siz_deeper = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU())

        self.weight_col = [1.12, 0.53, 0.29]
        self.weight_shp = [0.34, 0.26, 0.29]
        self.weight_siz = [0.36, 0.36, 0.49]
        self.output_color = make_head(256, 'color')
        self.output_shape = make_head(256, 'shape')
        self.output_size = make_head(256, 'size')

        # horizontal / vertical group.
        self.weight_hor_ver = [1.36, 0.78, 0.67]
        self.path_hor_ver_shallow = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU())
        self.path_hor_ver_deep = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU())
        self.path_hor_ver_deeper = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU())

        self.weight_hor = [0.24, 0.0, 0.14]
        self.weight_ver = [0.30, 0.58, 0.0]
        self.output_horizontal = make_head(256, 'horizontal')
        self.output_vertical = make_head(256, 'vertical')

        # camera / background heads take 1024-wide input.
        self.weight_cam = [1.35, 1.15, 0.56]
        self.weight_bkg = [1.11, 0.0, 1.54]
        self.output_camera = make_head(1024, 'camera')
        self.output_background = make_head(1024, 'background')
Пример #7
0
    def __init__(self):
        """Build a decoder with hard-coded factor cardinalities.

        Registers per-factor learnable ``mu`` (randn-initialised) and
        ``logvar`` (zero-initialised) tables with 128 columns, linear input
        projections from 160 features, and ``U.deconv2d`` stacks whose shape
        comments end at 3 x 64 x 64.

        NOTE(review): 160 presumably equals the 128-wide mu/logvar rows plus a
        32-wide embedding; confirm against forward().
        """
        super(Decoder, self).__init__()
        # Number of values per factor; the mu/logvar table row counts read it.
        self.factors = {
            'color': 4,
            'shape': 4,
            'size': 2,
            'camera': 3,
            'background': 3,
            'horizontal': 40,
            'vertical': 40
        }

        # One learnable (n_values, 128) table per factor, randn-initialised.
        self.mu = nn.ParameterDict({
            'color':
            nn.Parameter(torch.randn(self.factors['color'], 128)),
            'shape':
            nn.Parameter(torch.randn(self.factors['shape'], 128)),
            'size':
            nn.Parameter(torch.randn(self.factors['size'], 128)),
            'camera':
            nn.Parameter(torch.randn(self.factors['camera'], 128)),
            'background':
            nn.Parameter(torch.randn(self.factors['background'], 128)),
            'horizontal':
            nn.Parameter(torch.randn(self.factors['horizontal'], 128)),
            'vertical':
            nn.Parameter(torch.randn(self.factors['vertical'], 128))
        })
        # Same-shaped tables initialised to zeros (named logvar; presumably a
        # log-variance -- confirm in forward()).
        self.logvar = nn.ParameterDict({
            'color':
            nn.Parameter(torch.zeros(self.factors['color'], 128)),
            'shape':
            nn.Parameter(torch.zeros(self.factors['shape'], 128)),
            'size':
            nn.Parameter(torch.zeros(self.factors['size'], 128)),
            'camera':
            nn.Parameter(torch.zeros(self.factors['camera'], 128)),
            'background':
            nn.Parameter(torch.zeros(self.factors['background'], 128)),
            'horizontal':
            nn.Parameter(torch.zeros(self.factors['horizontal'], 128)),
            'vertical':
            nn.Parameter(torch.zeros(self.factors['vertical'], 128))
        })

        # color / shape / size: separate 160 -> 256 projections, plus a
        # ReLU + 256 -> 1024 stage (path_col_shp_siz).
        self.input_color = nn.Linear(160, 256)
        self.input_shape = nn.Linear(160, 256)
        self.input_size = nn.Linear(160, 256)
        self.path_col_shp_siz = nn.Sequential(nn.ReLU(), nn.Linear(256, 1024))

        # horizontal / vertical: same pattern with their own modules.
        self.input_horizontal = nn.Linear(160, 256)
        self.input_vertical = nn.Linear(160, 256)
        self.path_hor_ver = nn.Sequential(nn.ReLU(), nn.Linear(256, 1024))

        # camera / background project 160 -> 1024 in a single linear layer.
        self.input_camera = nn.Linear(160, 1024)
        self.input_background = nn.Linear(160, 1024)

        # Shape comments below are per-sample C x H x W.
        self.path_shallow = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 1024),
            U.Lambda(lambda x: x.reshape(-1, 16, 8, 8)),  # 16 x 8 x 8
            nn.ReLU(),
            U.deconv2d(16, 64, 1, 1, True, 8))  # 64 x 8 x 8
        self.path_deep = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 1024),
            U.Lambda(lambda x: x.reshape(-1, 64, 4, 4)),  # 64 x 4 x 4
            nn.ReLU(),
            U.deconv2d(64, 64, 4, 1, True, 4),  # 64 x 4 x 4
            nn.ReLU(),
            U.deconv2d(64, 64, 4, 2, True, 8))  # 64 x 8 x 8
        # Upsamples 64 x 8 x 8 features to a 3 x 64 x 64 output.
        self.path_base = nn.Sequential(
            nn.ReLU(),
            U.deconv2d(64, 16, 4, 2, True, 16),  # 16 x 16 x 16
            nn.ReLU(),
            U.deconv2d(16, 3, 6, 4, True, 64))  # 3 x 64 x 64