Example #1
    def __init__(self):
        super().__init__()

        # CONV VERSION
        self.img_h = 16
        inter_h = 4
        inter_ch = 32
        z_ch = 32
        z_dim = inter_h * inter_h * z_ch
        self.inter_h = inter_h
        self.inter_ch = inter_ch

        self.encoder = make_conv_net(
            3, self.img_h, {
                'kernel_sizes': [4, 4],
                'num_channels': [inter_ch, inter_ch],
                'strides': [2, 2],
                'paddings': [1, 1],
                'use_bn': True,
            })[0]
        # self.z_mean_fc = nn.Linear(inter_h*inter_h*inter_ch, z_dim, bias=True)
        # self.z_log_cov_fc = nn.Linear(inter_h*inter_h*inter_ch, z_dim, bias=True)
        self.z_mean_conv = nn.Conv2d(inter_ch,
                                     z_ch,
                                     3,
                                     stride=1,
                                     padding=1,
                                     bias=True)
        self.z_log_cov_conv = nn.Conv2d(inter_ch,
                                        z_ch,
                                        3,
                                        stride=1,
                                        padding=1,
                                        bias=True)

        # self.decoder_fc = nn.Linear(z_dim, inter_h*inter_h*inter_ch)
        self.decoder = make_upconv_net(
            z_ch, inter_h, {
                'kernel_sizes': [4, 4],
                'num_channels': [inter_ch, inter_ch],
                'strides': [2, 2],
                'paddings': [1, 1],
                'output_paddings': [0, 0],
                'use_bn': True
            })[0]
        self.recon_mean_conv = nn.Conv2d(inter_ch,
                                         3,
                                         1,
                                         stride=1,
                                         padding=0,
                                         bias=True)
        self.recon_log_cov_conv = nn.Conv2d(inter_ch,
                                            3,
                                            1,
                                            stride=1,
                                            padding=0,
                                            bias=True)

        print(self.encoder)
        print(self.decoder)
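
The constructor above only builds the modules; a plausible forward pass for this fully convolutional VAE is sketched below. The function name and the shape comments (two stride-2, kernel-4 convolutions taking 16x16 inputs down to 4x4 maps and back) are my assumptions, not code from the source.

import torch

def vae_forward(model, x):
    # encode a (B, 3, 16, 16) batch into (B, inter_ch, 4, 4) feature maps
    h = model.encoder(x)
    z_mean = model.z_mean_conv(h)            # (B, z_ch, 4, 4)
    z_log_cov = model.z_log_cov_conv(h)      # (B, z_ch, 4, 4)

    # reparameterization trick: z = mean + std * eps
    z = z_mean + torch.exp(0.5 * z_log_cov) * torch.randn_like(z_mean)

    # decode back up to (B, inter_ch, 16, 16) and emit per-pixel statistics
    d = model.decoder(z)
    return model.recon_mean_conv(d), model.recon_log_cov_conv(d), z_mean, z_log_cov
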
Example #2
    def __init__(self, maze_dims, action_dim, act_proc_dim, z_dim,
                 x_encoder_specs, pre_gru_specs, gru_specs, prior_part_specs,
                 inference_part_specs, decoder_part_specs, masked_latent):
        super().__init__()

        self.act_proc_dim = act_proc_dim
        self.action_fc = nn.Linear(action_dim, self.act_proc_dim, bias=True)

        in_ch = maze_dims[0]
        in_h = maze_dims[1]
        self.x_encoder, out_ch, out_h = make_conv_net(in_ch, in_h,
                                                      x_encoder_specs)
        x_enc_channels = out_ch
        x_enc_h = out_h
        self.x_enc_ch = out_ch
        self.x_enc_h = out_h
        flat_x_enc_dim = x_enc_channels * x_enc_h * x_enc_h

        self.prior_fc_seq, hidden_dim = make_fc_net(
            self.act_proc_dim + gru_specs['hidden_size'], prior_part_specs)
        self.prior_mean_fc = nn.Linear(hidden_dim, z_dim, bias=True)
        self.prior_log_cov_fc = nn.Linear(hidden_dim, z_dim, bias=True)

        self.posterior_fc_seq, hidden_dim = make_fc_net(
            self.act_proc_dim + gru_specs['hidden_size'] + flat_x_enc_dim,
            inference_part_specs)
        self.posterior_mean_fc = nn.Linear(hidden_dim, z_dim, bias=True)
        self.posterior_log_cov_fc = nn.Linear(hidden_dim, z_dim, bias=True)

        self.pre_gru_seq, hidden_dim = make_fc_net(
            self.act_proc_dim + flat_x_enc_dim + z_dim, pre_gru_specs)

        self.gru_cell = nn.GRUCell(hidden_dim,
                                   gru_specs['hidden_size'],
                                   bias=True)
        self.h_dim = [gru_specs['hidden_size']]

        # models for the decoding/generation
        self.recon_fc_seq, out_h = make_fc_net(z_dim + self.h_dim[0],
                                               decoder_part_specs['fc_part'])
        assert out_h == x_enc_h * x_enc_h * x_enc_channels
        # just for convenience we use these dims
        self.recon_upconv_seq, out_ch, out_h = make_upconv_net(
            x_enc_channels, x_enc_h, decoder_part_specs['conv_part'])
        self.recon_mean_conv = nn.Conv2d(out_ch,
                                         3,
                                         3,
                                         stride=1,
                                         padding=1,
                                         bias=True)
        self.recon_log_cov_conv = nn.Conv2d(out_ch,
                                            3,
                                            3,
                                            stride=1,
                                            padding=1,
                                            bias=True)
        assert out_h == maze_dims[1]
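
One illustrative time-step of the recurrent model above is sketched next. The dimensions follow the constructor (prior: processed action plus previous GRU state; posterior: those plus the flattened observation encoding), but the function and variable names are mine and this is a sketch, not the original rollout code.

import torch

def recurrent_step(model, obs, act, prev_h):
    a = model.action_fc(act)                              # (B, act_proc_dim)
    x_enc = model.x_encoder(obs).flatten(start_dim=1)     # (B, flat_x_enc_dim)

    prior_h = model.prior_fc_seq(torch.cat([a, prev_h], dim=1))
    prior_mean = model.prior_mean_fc(prior_h)
    prior_log_cov = model.prior_log_cov_fc(prior_h)

    post_h = model.posterior_fc_seq(torch.cat([a, prev_h, x_enc], dim=1))
    post_mean = model.posterior_mean_fc(post_h)
    post_log_cov = model.posterior_log_cov_fc(post_h)

    # sample z from the posterior and advance the deterministic GRU state
    z = post_mean + torch.exp(0.5 * post_log_cov) * torch.randn_like(post_mean)
    gru_in = model.pre_gru_seq(torch.cat([a, x_enc, z], dim=1))
    new_h = model.gru_cell(gru_in, prev_h)
    return z, new_h, (prior_mean, prior_log_cov), (post_mean, post_log_cov)
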
Example #3
    def __init__(
        self,
        maze_dims,
        action_proc_dim,
        z_dim,
        x_encoder_specs,
        pre_lstm_dim,
        lstm_dim,
        prior_part_specs,
        inference_part_specs,
        decoder_part_specs,
    ):
        super().__init__()

        in_ch = maze_dims[0]
        in_h = maze_dims[1]
        self.x_encoder, out_ch, out_h = make_conv_net(in_ch, in_h, x_encoder_specs)
        x_enc_channels = out_ch
        x_enc_h = out_h

        self.prior_action_fc = nn.Linear(4, action_proc_dim, bias=True)
        self.post_action_fc = nn.Linear(4, action_proc_dim, bias=True)
        self.recon_action_fc = nn.Linear(4, action_proc_dim, bias=True)
        self.pre_lstm_action_fc = nn.Linear(4, action_proc_dim, bias=True)

        self.lstm = nn.LSTMCell(
            pre_lstm_dim, lstm_dim, bias=True
        )

        self.attention_seq = nn.Sequential(
            nn.Linear(lstm_dim + action_proc_dim, lstm_dim, bias=False),
            nn.BatchNorm1d(lstm_dim),
            nn.ReLU(),
            nn.Linear(lstm_dim, lstm_dim),
            # nn.Sigmoid()
            # nn.Softmax()
        )

        self.prior_fc_seq, hidden_dim = make_fc_net(lstm_dim + action_proc_dim, prior_part_specs)
        self.prior_mean_fc = nn.Linear(hidden_dim, z_dim, bias=True)
        self.prior_log_cov_fc = nn.Linear(hidden_dim, z_dim, bias=True)

        # models for the posterior
        self.posterior_fc_seq, hidden_dim = make_fc_net(lstm_dim + x_enc_channels*x_enc_h*x_enc_h + action_proc_dim, inference_part_specs)
        self.posterior_mean_fc = nn.Linear(hidden_dim, z_dim, bias=True)
        self.posterior_log_cov_fc = nn.Linear(hidden_dim, z_dim, bias=True)

        # models for the decoding/generation
        self.recon_fc_seq, out_h = make_fc_net(z_dim + lstm_dim + action_proc_dim, decoder_part_specs['fc_part_specs'])
        self.recon_upconv_seq, out_ch, out_h = make_upconv_net(x_enc_channels, x_enc_h, decoder_part_specs['upconv_part_specs'])
        self.recon_mean_conv = nn.Conv2d(out_ch, 3, 3, stride=1, padding=1, bias=True)
        self.recon_log_cov_conv = nn.Conv2d(out_ch, 3, 3, stride=1, padding=1, bias=True)
        assert out_h == maze_dims[1]
Example #4
    def __init__(
        self,
        maze_dims,
        z_dim,
        encoder_specs,
        decoder_specs
    ):
        super().__init__()

        in_ch = maze_dims[0]
        in_h = maze_dims[1]
        
        # make the encoder
        self.encoder_conv_seq, x_enc_ch, x_enc_h = make_conv_net(in_ch, in_h, encoder_specs['conv_part_specs'])
        self.x_enc_ch = x_enc_ch
        self.x_enc_h = x_enc_h
        flat_inter_img_dim = x_enc_ch * x_enc_h * x_enc_h

        self.z_mask_conv_seq, _, _ = make_conv_net(
            x_enc_ch, x_enc_h,
            {
                'kernel_sizes': [3],
                'num_channels': [64],
                'strides': [1],
                'paddings': [1],
                'use_bn': True
            }
        )
        self.z_mask_fc_seq, _ = make_fc_net(64*x_enc_h*x_enc_h, {'hidden_sizes': [1024], 'use_bn':True})
        self.z_mask_fc = nn.Linear(1024, 128, bias=True)

        self.z_mask_gen_fc_seq, _ = make_fc_net(128, {'hidden_sizes': [1024, 4*x_enc_h*x_enc_h], 'use_bn':True})
        self.z_mask_gen_conv = nn.Conv2d(4, 1, 3, stride=1, padding=1, bias=True)

        self.encoder_fc_seq, h_dim = make_fc_net(flat_inter_img_dim, encoder_specs['fc_part_specs'])

        self.z_mean_fc = nn.Linear(h_dim, z_dim, bias=True)
        self.z_log_cov_fc = nn.Linear(h_dim, z_dim, bias=True)

        # make the decoder
        self.decoder_fc_seq, h_dim = make_fc_net(z_dim, decoder_specs['fc_part_specs'])
        # assert h_dim == flat_inter_img_dim
        self.decoder_upconv_seq, out_ch, out_h = make_upconv_net(x_enc_ch, x_enc_h, decoder_specs['upconv_part_specs'])

        self.recon_mean_conv = nn.Conv2d(out_ch, 1, 1, stride=1, padding=0, bias=True)
        self.recon_log_cov_conv = nn.Conv2d(out_ch, 1, 1, stride=1, padding=0, bias=True)
        assert out_h == maze_dims[1], str(out_h) + ' != ' + str(maze_dims[1])
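
The mask pathway in this example compresses the encoder features into a 128-d code and regenerates a single spatial mask from it. A sketch of that path is below; the function name is illustrative, the reshape to 4 channels follows the 4*x_enc_h*x_enc_h output of z_mask_gen_fc_seq, and how the resulting mask is consumed is not shown in the constructor.

def mask_from_features(model, h):
    # h: (B, x_enc_ch, x_enc_h, x_enc_h) from encoder_conv_seq
    m = model.z_mask_conv_seq(h)                           # (B, 64, x_enc_h, x_enc_h)
    code = model.z_mask_fc(model.z_mask_fc_seq(m.flatten(start_dim=1)))   # (B, 128)
    g = model.z_mask_gen_fc_seq(code)                      # (B, 4 * x_enc_h * x_enc_h)
    g = g.view(-1, 4, model.x_enc_h, model.x_enc_h)
    return model.z_mask_gen_conv(g)                        # (B, 1, x_enc_h, x_enc_h)
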
Example #5
    def __init__(self, maze_dims, z_dim, encoder_specs, decoder_specs):
        super().__init__()

        in_ch = maze_dims[0]
        in_h = maze_dims[1]

        # make the encoder
        self.encoder_conv_seq, x_enc_ch, x_enc_h = make_conv_net(
            in_ch, in_h, encoder_specs['conv_part_specs'])
        self.x_enc_ch = x_enc_ch
        self.x_enc_h = x_enc_h
        flat_inter_img_dim = x_enc_ch * x_enc_h * x_enc_h

        self.encoder_fc_seq, h_dim = make_fc_net(
            flat_inter_img_dim, encoder_specs['fc_part_specs'])

        self.z_mean_fc = nn.Linear(h_dim, z_dim, bias=True)
        self.z_log_cov_fc = nn.Linear(h_dim, z_dim, bias=True)

        # make the decoder
        self.decoder_fc_seq, h_dim = make_fc_net(
            z_dim, decoder_specs['fc_part_specs'])
        assert h_dim == flat_inter_img_dim
        self.decoder_upconv_seq, out_ch, out_h = make_upconv_net(
            x_enc_ch, x_enc_h, decoder_specs['upconv_part_specs'])

        self.recon_mean_conv = nn.Conv2d(out_ch,
                                         1,
                                         1,
                                         stride=1,
                                         padding=0,
                                         bias=True)
        self.recon_log_cov_conv = nn.Conv2d(out_ch,
                                            1,
                                            1,
                                            stride=1,
                                            padding=0,
                                            bias=True)
        assert out_h == maze_dims[1], str(out_h) + ' != ' + str(maze_dims[1])
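
All of these examples lean on make_conv_net, make_fc_net, and make_upconv_net, which are defined elsewhere in the surrounding repository and are not shown here. A minimal sketch of what the first two could look like, assuming only the spec keys visible in the calls above ('kernel_sizes', 'num_channels', 'strides', 'paddings', 'use_bn', 'hidden_sizes') and the (module, out_channels, out_height) / (module, out_dim) return conventions, is:

import torch.nn as nn

def make_conv_net(in_ch, in_h, specs):
    # Conv2d -> (BatchNorm2d) -> ReLU blocks; returns the net plus its output shape
    layers, h = [], in_h
    for k, c, s, p in zip(specs['kernel_sizes'], specs['num_channels'],
                          specs['strides'], specs['paddings']):
        layers.append(nn.Conv2d(in_ch, c, k, stride=s, padding=p,
                                bias=not specs.get('use_bn', False)))
        if specs.get('use_bn', False):
            layers.append(nn.BatchNorm2d(c))
        layers.append(nn.ReLU())
        h = (h + 2 * p - k) // s + 1      # standard conv output-size formula
        in_ch = c
    return nn.Sequential(*layers), in_ch, h

def make_fc_net(in_dim, specs):
    # Linear -> (BatchNorm1d) -> ReLU blocks; returns the net plus its output width
    layers = []
    for hidden in specs['hidden_sizes']:
        layers.append(nn.Linear(in_dim, hidden, bias=not specs.get('use_bn', False)))
        if specs.get('use_bn', False):
            layers.append(nn.BatchNorm1d(hidden))
        layers.append(nn.ReLU())
        in_dim = hidden
    return nn.Sequential(*layers), in_dim

make_upconv_net would presumably follow the same pattern with nn.ConvTranspose2d and the 'output_paddings' key.
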
Example #6
    def __init__(self, maze_dims, z_dim, encoder_specs, decoder_specs):
        super().__init__()

        in_ch = maze_dims[0]
        in_h = maze_dims[1]

        # make the encoder
        self.encoder_conv_seq, x_enc_ch, x_enc_h = make_conv_net(
            in_ch, in_h, encoder_specs['conv_part_specs'])
        self.x_enc_ch = x_enc_ch
        self.x_enc_h = x_enc_h
        flat_inter_img_dim = x_enc_ch * x_enc_h * x_enc_h

        self.enc_mask_seq, _, _ = make_conv_net(
            x_enc_ch, x_enc_h, {
                'kernel_sizes': [3],
                'num_channels': [64],
                'strides': [2],
                'paddings': [1],
                'use_bn': True
            })
        self.enc_mask_conv = nn.Conv2d(64,
                                       1,
                                       1,
                                       stride=1,
                                       padding=0,
                                       bias=True)

        # meshgrid
        xv, yv = np.meshgrid(np.linspace(-1., 1., x_enc_h),
                             np.linspace(-1., 1., x_enc_h))
        xv, yv = xv[None, None, ...], yv[None, None, ...]
        xv, yv = torch.FloatTensor(xv), torch.FloatTensor(yv)
        self.mesh = torch.cat([xv, yv], 1)
        self.mesh = Variable(self.mesh, requires_grad=False).cuda()

        self.encoder_fc_seq, h_dim = make_fc_net(
            flat_inter_img_dim, encoder_specs['fc_part_specs'])

        self.z_mean_fc = nn.Linear(h_dim, z_dim, bias=True)
        self.z_log_cov_fc = nn.Linear(h_dim, z_dim, bias=True)

        # make the decoder
        self.decoder_fc_seq, h_dim = make_fc_net(
            z_dim, decoder_specs['fc_part_specs'])
        # assert h_dim == flat_inter_img_dim
        self.decoder_upconv_seq, out_ch, out_h = make_upconv_net(
            130, x_enc_h, decoder_specs['upconv_part_specs'])

        self.recon_mean_conv = nn.Conv2d(out_ch,
                                         1,
                                         1,
                                         stride=1,
                                         padding=0,
                                         bias=True)
        self.recon_log_cov_conv = nn.Conv2d(out_ch,
                                            1,
                                            1,
                                            stride=1,
                                            padding=0,
                                            bias=True)
        assert out_h == maze_dims[1], str(out_h) + ' != ' + str(maze_dims[1])
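
The meshgrid built in this constructor is a pair of normalized (x, y) coordinate channels. It is typically concatenated onto a feature map before further convolutions (CoordConv-style); the sketch below shows that step. The hard-coded 130 input channels of the decoder are consistent with a 128-channel feature map plus the two mesh channels, though that is only an inference from the constructor.

import torch

def append_mesh(model, feat):
    # feat: (B, C, x_enc_h, x_enc_h) on the same (CUDA) device as model.mesh
    mesh = model.mesh.expand(feat.size(0), -1, -1, -1)     # (B, 2, x_enc_h, x_enc_h)
    return torch.cat([feat, mesh], dim=1)                  # (B, C + 2, x_enc_h, x_enc_h)
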
Example #7
    def __init__(self):
        super().__init__()

        # CONV VERSION
        self.img_h = 16
        inter_h = 4
        inter_ch = 32
        z_dim = inter_h * inter_h * inter_ch
        self.inter_h = inter_h
        self.inter_ch = inter_ch

        self.encoder = make_conv_net(
            3, self.img_h, {
                'kernel_sizes': [4, 4],
                'num_channels': [inter_ch, inter_ch],
                'strides': [2, 2],
                'paddings': [1, 1],
                'use_bn': True,
            })[0]

        # self.mask_net = make_conv_net(
        #     inter_ch, self.img_h, {
        #         'kernel_sizes': [3, 3],
        #         'num_channels': [4, 4],
        #         'strides': [1, 1],
        #         'paddings': [1, 1],
        #         'use_bn': True,
        #     }
        # )[0]
        self.mask_net = nn.Sequential(
            # self.mask_net,
            nn.Conv2d(inter_ch, 1, 1, 1, 0, bias=True),
            nn.Sigmoid())

        # self.z_mean_fc = nn.Linear(inter_h*inter_h*inter_ch, z_dim, bias=True)
        # self.z_log_cov_fc = nn.Linear(inter_h*inter_h*inter_ch, z_dim, bias=True)
        self.z_mean_conv = nn.Conv2d(inter_ch,
                                     inter_ch,
                                     3,
                                     stride=1,
                                     padding=1,
                                     bias=True)
        self.z_log_cov_conv = nn.Conv2d(inter_ch,
                                        inter_ch,
                                        3,
                                        stride=1,
                                        padding=1,
                                        bias=True)

        # self.decoder_fc = nn.Linear(z_dim, inter_h*inter_h*inter_ch)
        self.decoder = make_upconv_net(
            2 * inter_ch, inter_h, {
                'kernel_sizes': [4, 4],
                'num_channels': [inter_ch, inter_ch],
                'strides': [2, 2],
                'paddings': [1, 1],
                'output_paddings': [0, 0],
                'use_bn': True
            })[0]
        # self.decoder = nn.Sequential(
        #     self.decoder,
        #     nn.Conv2d(inter_ch, inter_ch, 5, stride=1, padding=2, bias=False),
        #     nn.BatchNorm2d(inter_ch),
        #     nn.ReLU()
        # )
        self.recon_mean_conv = nn.Conv2d(inter_ch,
                                         3,
                                         1,
                                         stride=1,
                                         padding=0,
                                         bias=True)
        self.recon_log_cov_conv = nn.Conv2d(inter_ch,
                                            3,
                                            1,
                                            stride=1,
                                            padding=0,
                                            bias=True)

        print(self.encoder)
        print(self.decoder)
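
One plausible use of the mask branch above is to gate the encoder features with the predicted soft mask before computing the latent statistics; the exact gating scheme is an assumption based on the module names, and the helper below is only a sketch.

def masked_encode(model, x):
    h = model.encoder(x)                     # (B, inter_ch, 4, 4) for 16x16 inputs
    mask = model.mask_net(h)                 # (B, 1, 4, 4), values in (0, 1)
    h = h * mask                             # broadcast the mask over channels
    return model.z_mean_conv(h), model.z_log_cov_conv(h), mask
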
Example #8
    def __init__(self, maze_dims, z_dim, x_encoder_specs, z_seg_conv_specs,
                 z_seg_fc_specs, z_obj_conv_specs, z_obj_fc_specs,
                 z_seg_recon_fc_specs, z_seg_recon_upconv_specs,
                 z_obj_recon_fc_specs, z_obj_recon_upconv_specs,
                 recon_upconv_part_specs):
        super().__init__()

        in_ch = maze_dims[0]
        in_h = maze_dims[1]
        self.x_encoder, x_enc_ch, x_enc_h = make_conv_net(
            in_ch, in_h, x_encoder_specs)
        self.x_enc_ch = x_enc_ch
        self.x_enc_h = x_enc_h
        flat_inter_img_dim = x_enc_ch * x_enc_h * x_enc_h

        # self.convgru = ConvGRUCell(x_enc_ch, gru_specs['channels'], gru_specs['kernel_size'])
        # self.gru_ch = gru_specs['channels']

        self.z_seg_conv_seq, out_ch, out_h = make_conv_net(
            x_enc_ch + 1, x_enc_h, z_seg_conv_specs)
        self.z_seg_fc_seq, out_h = make_fc_net(out_ch * out_h * out_h,
                                               z_seg_fc_specs)
        self.z_seg_mean_fc = nn.Linear(out_h, z_dim, bias=True)
        self.z_seg_log_cov_fc = nn.Linear(out_h, z_dim, bias=True)

        # self.z_obj_conv_seq, z_conv_ch, z_conv_h = make_conv_net(x_enc_ch, x_enc_h, z_obj_conv_specs)
        # flat_dim = z_conv_ch*z_conv_h*z_conv_h
        # self.z_conv_ch, self.z_conv_h = z_conv_ch, z_conv_h
        self.z_obj_fc_seq, out_h = make_fc_net(flat_inter_img_dim,
                                               z_obj_fc_specs)
        self.z_obj_mean_fc = nn.Linear(out_h, z_dim, bias=True)
        self.z_obj_log_cov_fc = nn.Linear(out_h, z_dim, bias=True)

        self.z_seg_mask_fc_seq, out_h = make_fc_net(z_dim,
                                                    z_seg_recon_fc_specs)
        # print(out_h)
        # print(z_conv_ch, z_conv_h)
        # assert out_h == z_conv_h*z_conv_h*z_conv_ch
        self.z_seg_mask_upconv_seq, out_ch, out_h = make_upconv_net(
            x_enc_ch, x_enc_h, z_seg_recon_upconv_specs)
        self.z_seg_mask_conv = nn.Conv2d(out_ch,
                                         1,
                                         3,
                                         stride=1,
                                         padding=1,
                                         bias=True)
        print(out_h)

        self.z_obj_recon_fc_seq, z_recon_dim = make_fc_net(
            z_dim, z_obj_recon_fc_specs)
        # self.z_obj_recon_upconv_seq, out_ch, out_h = make_upconv_net(z_conv_ch, z_conv_h, z_obj_recon_upconv_specs)
        self.recon_upconv_seq, out_ch, out_h = make_upconv_net(
            x_enc_ch, x_enc_h, recon_upconv_part_specs)
        print(out_h)
        self.recon_mean_conv = nn.Conv2d(out_ch,
                                         1,
                                         1,
                                         stride=1,
                                         padding=0,
                                         bias=True)
        self.recon_log_cov_conv = nn.Conv2d(out_ch,
                                            1,
                                            1,
                                            stride=1,
                                            padding=0,
                                            bias=True)
        assert out_h == maze_dims[1], str(out_h) + ' != ' + str(maze_dims[1])
Example #9
    def __init__(
        self,
        maze_dims,
        action_proc_dim,
        z_dim,
        pre_post_gru_dim,
        x_encoder_specs,
        decoder_part_specs,
    ):
        super().__init__()

        in_ch = maze_dims[0]
        in_h = maze_dims[1]
        self.x_encoder, x_enc_ch, x_enc_h = make_conv_net(in_ch, in_h, x_encoder_specs)
        self.x_enc_ch = x_enc_ch
        self.x_enc_h = x_enc_h
        flat_inter_img_dim = x_enc_ch * x_enc_h * x_enc_h

        lstm_dim = z_dim
        self.lstm_dim = z_dim

        self.prior_action_fc = nn.Linear(4, action_proc_dim, bias=True)
        self.post_action_fc = nn.Linear(4, action_proc_dim, bias=True)
        self.recon_action_fc = nn.Linear(4, action_proc_dim, bias=True)
        self.mask_action_fc = nn.Linear(4, action_proc_dim, bias=True)

        print(self.prior_action_fc)

        self.prior_pre_gru_fc = nn.Sequential(
            nn.Linear(z_dim + action_proc_dim, flat_inter_img_dim, bias=False),
            nn.BatchNorm1d(flat_inter_img_dim),
            nn.ReLU()
        )
        self.prior_mean_gru = nn.GRUCell(flat_inter_img_dim, z_dim, bias=True)
        self.prior_log_cov_seq = nn.Sequential(
            nn.Linear(flat_inter_img_dim, z_dim, bias=False),
            nn.BatchNorm1d(z_dim),
            nn.ReLU(),
            nn.Linear(z_dim, z_dim, bias=True)
        )
        self.post_mean_gru = nn.GRUCell(pre_post_gru_dim, z_dim, bias=True)
        self.post_log_cov_seq = nn.Sequential(
            nn.Linear(pre_post_gru_dim + z_dim, z_dim, bias=False),
            nn.BatchNorm1d(z_dim),
            nn.ReLU(),
            nn.Linear(z_dim, z_dim, bias=True)
        )

        print(self.prior_mean_gru)
        print(self.post_mean_gru)

        self.attention_seq = nn.Sequential(
            nn.Linear(lstm_dim + action_proc_dim, lstm_dim, bias=False),
            nn.BatchNorm1d(lstm_dim),
            nn.ReLU(),
            nn.Linear(lstm_dim, lstm_dim),
            # nn.Sigmoid()
            # nn.Softmax()
        )
        print(self.attention_seq)

        self.pre_post_gru_fc = nn.Sequential(
            nn.Linear(flat_inter_img_dim + action_proc_dim, pre_post_gru_dim, bias=False),
            nn.BatchNorm1d(pre_post_gru_dim),
            nn.ReLU(),
        )
        print(self.pre_post_gru_fc)

        # models for the decoding/generation
        self.recon_fc_seq = nn.Sequential(
            nn.Linear(lstm_dim, flat_inter_img_dim, bias=False),
            nn.BatchNorm1d(flat_inter_img_dim),
            nn.ReLU(),
        )
        self.recon_upconv_seq, out_ch, out_h = make_upconv_net(x_enc_ch, x_enc_h, decoder_part_specs)
        self.recon_mean_conv = nn.Conv2d(out_ch, 3, 1, stride=1, padding=0, bias=True)
        self.recon_log_cov_conv = nn.Conv2d(out_ch, 3, 1, stride=1, padding=0, bias=True)
        assert out_h == maze_dims[1], str(out_h) + ' != ' + str(maze_dims[1])
Example #10
    def __init__(self, maze_dims, action_dim, z_dim, x_encoder_specs,
                 gru_specs, prior_part_specs, inference_part_specs,
                 decoder_part_specs, masked_latent):
        super().__init__()

        in_ch = maze_dims[0]
        in_h = maze_dims[1]
        self.x_encoder, out_ch, out_h = make_conv_net(in_ch, in_h,
                                                      x_encoder_specs)
        x_enc_channels = out_ch
        x_enc_h = out_h

        # self.gru_cell = nn.GRUCell(
        #     out_ch * out_h * out_h, gru_specs['hidden_size'], bias=True
        # )
        # self.h_dim = gru_specs['hidden_size']
        # self.gru_proc_fc = nn.Linear(gru_specs['hidden_size'], out_ch * out_h * out_h, bias=True)

        self.conv_gru_cell = nn.Conv2d(x_enc_channels + z_dim + 1,
                                       gru_specs['num_channels'],
                                       gru_specs['kernel_size'],
                                       stride=1,
                                       padding=1,
                                       bias=True)
        # self.conv_gru_cell = ConvGRUCell(
        #     x_enc_channels + z_dim + 1,
        #     gru_specs['num_channels'],
        #     gru_specs['kernel_size']
        # )
        out_ch = gru_specs['num_channels']

        # with padding=1 this conv cell keeps the spatial size unchanged (assuming a 3x3 kernel)
        self.h_dim = [out_ch, out_h, out_h]
        self.action_fc = nn.Linear(action_dim, out_h * out_h, bias=True)

        self.masked_latent = masked_latent
        if masked_latent:
            self.mask_seq = nn.Sequential(
                nn.Conv2d(z_dim, 1, 3, stride=1, padding=1, bias=True),
                nn.Sigmoid())

        # models for the prior
        self.prior_conv_seq, out_ch, out_h = make_conv_net(
            gru_specs['num_channels'] + 1, out_h, prior_part_specs)
        self.prior_mean_conv = nn.Conv2d(out_ch,
                                         z_dim,
                                         3,
                                         stride=1,
                                         padding=1,
                                         bias=True)
        self.prior_log_cov_conv = nn.Conv2d(out_ch,
                                            z_dim,
                                            3,
                                            stride=1,
                                            padding=1,
                                            bias=True)

        # self.prior_fc_seq, hidden_dim = make_fc_net(int(np.prod(self.h_dim)) + action_dim, prior_part_specs)
        # self.prior_mean_fc = nn.Linear(hidden_dim, z_dim, bias=True)
        # self.prior_log_cov_fc = nn.Linear(hidden_dim, z_dim, bias=True)

        # models for the posterior
        self.posterior_conv_seq, out_ch, out_h = make_conv_net(
            x_enc_channels + gru_specs['num_channels'] + 1, x_enc_h,
            inference_part_specs['conv_part_specs'])
        self.posterior_mean_conv = nn.Conv2d(gru_specs['num_channels'],
                                             z_dim,
                                             3,
                                             stride=1,
                                             padding=1,
                                             bias=True)
        self.posterior_log_cov_conv = nn.Conv2d(gru_specs['num_channels'],
                                                z_dim,
                                                3,
                                                stride=1,
                                                padding=1,
                                                bias=True)
        # hidden_dim = out_ch * out_h * out_h
        # hidden_dim = x_enc_h * x_enc_h * x_enc_channels + self.h_dim

        # self.posterior_fc_seq, hidden_dim = make_fc_net(hidden_dim, inference_part_specs['fc_part_specs'])
        # self.posterior_mean_fc = nn.Linear(hidden_dim, z_dim, bias=True)
        # self.posterior_log_cov_fc = nn.Linear(hidden_dim, z_dim, bias=True)

        # models for the decoding/generation
        # self.z_encoder, z_enc_dim = make_fc_net(z_dim, z_encoder_specs)
        # assert z_enc_dim == int(np.prod(self.h_dim[1:])), 'z not encoded to right size'
        # self.recon_fc_seq, out_h = make_fc_net(z_dim + self.h_dim, decoder_part_specs['fc_part_specs'])
        self.recon_upconv_seq, out_ch, out_h = make_upconv_net(
            gru_specs['num_channels'] + z_dim, self.h_dim[1],
            decoder_part_specs)
        self.recon_mean_conv = nn.Conv2d(out_ch,
                                         3,
                                         3,
                                         stride=1,
                                         padding=1,
                                         bias=True)
        self.recon_log_cov_conv = nn.Conv2d(out_ch,
                                            3,
                                            3,
                                            stride=1,
                                            padding=1,
                                            bias=True)
        assert out_h == maze_dims[1]
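
All of these constructors emit (mean, log_cov) pairs for a prior and a posterior, which in a variational objective are compared with a Gaussian KL term. For reference, the standard KL divergence between two diagonal Gaussians parameterized this way is (a textbook formula, not code from these snippets):

import torch

def gaussian_kl(q_mean, q_log_cov, p_mean, p_log_cov):
    # KL( N(q_mean, exp(q_log_cov)) || N(p_mean, exp(p_log_cov)) ), summed over the last dim
    return 0.5 * torch.sum(
        p_log_cov - q_log_cov
        + (torch.exp(q_log_cov) + (q_mean - p_mean) ** 2) / torch.exp(p_log_cov)
        - 1.0,
        dim=-1,
    )
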