Example No. 1
def create_model_brick():
    decoder = MLP(
        dims=[NLAT, GEN_HIDDEN, GEN_HIDDEN, GEN_HIDDEN, GEN_HIDDEN, INPUT_DIM],
        activations=[Sequence([BatchNormalization(GEN_HIDDEN).apply,
                               GEN_ACTIVATION().apply],
                              name='decoder_h1'),
                     Sequence([BatchNormalization(GEN_HIDDEN).apply,
                               GEN_ACTIVATION().apply],
                              name='decoder_h2'),
                     Sequence([BatchNormalization(GEN_HIDDEN).apply,
                               GEN_ACTIVATION().apply],
                              name='decoder_h3'),
                     Sequence([BatchNormalization(GEN_HIDDEN).apply,
                               GEN_ACTIVATION().apply],
                              name='decoder_h4'),
                     Identity(name='decoder_out')],
        use_bias=False,
        name='decoder')

    discriminator = Sequence(
        application_methods=[
            LinearMaxout(
                input_dim=INPUT_DIM,
                output_dim=DISC_HIDDEN,
                num_pieces=MAXOUT_PIECES,
                weights_init=GAUSSIAN_INIT,
                biases_init=ZERO_INIT,
                name='discriminator_h1').apply,
            LinearMaxout(
                input_dim=DISC_HIDDEN,
                output_dim=DISC_HIDDEN,
                num_pieces=MAXOUT_PIECES,
                weights_init=GAUSSIAN_INIT,
                biases_init=ZERO_INIT,
                name='discriminator_h2').apply,
            LinearMaxout(
                input_dim=DISC_HIDDEN,
                output_dim=DISC_HIDDEN,
                num_pieces=MAXOUT_PIECES,
                weights_init=GAUSSIAN_INIT,
                biases_init=ZERO_INIT,
                name='discriminator_h3').apply,
            Linear(
                input_dim=DISC_HIDDEN,
                output_dim=1,
                weights_init=GAUSSIAN_INIT,
                biases_init=ZERO_INIT,
                name='discriminator_out').apply],
        name='discriminator')

    gan = GAN(decoder=decoder, discriminator=discriminator,
              weights_init=GAUSSIAN_INIT, biases_init=ZERO_INIT, name='gan')
    gan.push_allocation_config()
    decoder.linear_transformations[-1].use_bias = True
    gan.initialize()

    return gan
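
These snippets come from Blocks-based GAN/ALI projects and rely on bricks and module-level constants defined elsewhere in their repositories. A minimal sketch of what Example No. 1 assumes is shown below; the Blocks imports are standard, while the GAN brick is project-specific and every constant value here is purely illustrative.

from blocks.bricks import MLP, Linear, LinearMaxout, Identity, Rectifier, Sequence
from blocks.bricks.bn import BatchNormalization
from blocks.initialization import IsotropicGaussian, Constant
# The GAN brick is defined by the surrounding project, not by Blocks itself.

# Illustrative stand-ins for the module-level constants.
NLAT = 2                    # latent dimensionality fed to the decoder
INPUT_DIM = 2               # dimensionality of the generated samples
GEN_HIDDEN = 400            # width of the generator's hidden layers
DISC_HIDDEN = 400           # width of the discriminator's hidden layers
MAXOUT_PIECES = 5           # pieces per LinearMaxout unit
GEN_ACTIVATION = Rectifier  # activation brick class used by the generator
GAUSSIAN_INIT = IsotropicGaussian(0.02)
ZERO_INIT = Constant(0)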
Example No. 2
def create_model_bricks(z_dim, image_size, depth):

    g_image_size = image_size
    # Integer division keeps the intermediate sizes as ints (plain "/" would
    # produce floats under Python 3 and break the transpose target shapes).
    g_image_size2 = g_image_size // 2
    g_image_size3 = g_image_size // 4
    g_image_size4 = g_image_size // 8
    g_image_size5 = g_image_size // 16

    encoder_layers = []
    if depth > 0:
        encoder_layers = encoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=32,
                          name='conv1'),
            SpatialBatchNormalization(name='batch_norm1'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=32,
                          name='conv2'),
            SpatialBatchNormalization(name='batch_norm2'),
            Rectifier(),
            Convolutional(
                filter_size=(2, 2), step=(2, 2), num_filters=32, name='conv3'),
            SpatialBatchNormalization(name='batch_norm3'),
            Rectifier()
        ]
    if depth > 1:
        encoder_layers = encoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=64,
                          name='conv4'),
            SpatialBatchNormalization(name='batch_norm4'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=64,
                          name='conv5'),
            SpatialBatchNormalization(name='batch_norm5'),
            Rectifier(),
            Convolutional(
                filter_size=(2, 2), step=(2, 2), num_filters=64, name='conv6'),
            SpatialBatchNormalization(name='batch_norm6'),
            Rectifier()
        ]
    if depth > 2:
        encoder_layers = encoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=128,
                          name='conv7'),
            SpatialBatchNormalization(name='batch_norm7'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=128,
                          name='conv8'),
            SpatialBatchNormalization(name='batch_norm8'),
            Rectifier(),
            Convolutional(
                filter_size=(2, 2), step=(2, 2), num_filters=128,
                name='conv9'),
            SpatialBatchNormalization(name='batch_norm9'),
            Rectifier()
        ]
    if depth > 3:
        encoder_layers = encoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=256,
                          name='conv10'),
            SpatialBatchNormalization(name='batch_norm10'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=256,
                          name='conv11'),
            SpatialBatchNormalization(name='batch_norm11'),
            Rectifier(),
            Convolutional(filter_size=(2, 2),
                          step=(2, 2),
                          num_filters=256,
                          name='conv12'),
            SpatialBatchNormalization(name='batch_norm12'),
            Rectifier(),
        ]
    if depth > 4:
        encoder_layers = encoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=512,
                          name='conv13'),
            SpatialBatchNormalization(name='batch_norm13'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=512,
                          name='conv14'),
            SpatialBatchNormalization(name='batch_norm14'),
            Rectifier(),
            Convolutional(filter_size=(2, 2),
                          step=(2, 2),
                          num_filters=512,
                          name='conv15'),
            SpatialBatchNormalization(name='batch_norm15'),
            Rectifier()
        ]

    decoder_layers = []
    if depth > 4:
        decoder_layers = decoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=512,
                          name='conv_n3'),
            SpatialBatchNormalization(name='batch_norm_n3'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=512,
                          name='conv_n2'),
            SpatialBatchNormalization(name='batch_norm_n2'),
            Rectifier(),
            ConvolutionalTranspose(
                filter_size=(2, 2),
                step=(2, 2),
                original_image_size=(g_image_size5, g_image_size5),
                num_filters=512,
                name='conv_n1'),
            SpatialBatchNormalization(name='batch_norm_n1'),
            Rectifier()
        ]

    if depth > 3:
        decoder_layers = decoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=256,
                          name='conv1'),
            SpatialBatchNormalization(name='batch_norm1'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=256,
                          name='conv2'),
            SpatialBatchNormalization(name='batch_norm2'),
            Rectifier(),
            ConvolutionalTranspose(
                filter_size=(2, 2),
                step=(2, 2),
                original_image_size=(g_image_size4, g_image_size4),
                num_filters=256,
                name='conv3'),
            SpatialBatchNormalization(name='batch_norm3'),
            Rectifier()
        ]

    if depth > 2:
        decoder_layers = decoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=128,
                          name='conv4'),
            SpatialBatchNormalization(name='batch_norm4'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=128,
                          name='conv5'),
            SpatialBatchNormalization(name='batch_norm5'),
            Rectifier(),
            ConvolutionalTranspose(
                filter_size=(2, 2),
                step=(2, 2),
                original_image_size=(g_image_size3, g_image_size3),
                num_filters=128,
                name='conv6'),
            SpatialBatchNormalization(name='batch_norm6'),
            Rectifier()
        ]

    if depth > 1:
        decoder_layers = decoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=64,
                          name='conv7'),
            SpatialBatchNormalization(name='batch_norm7'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=64,
                          name='conv8'),
            SpatialBatchNormalization(name='batch_norm8'),
            Rectifier(),
            ConvolutionalTranspose(
                filter_size=(2, 2),
                step=(2, 2),
                original_image_size=(g_image_size2, g_image_size2),
                num_filters=64,
                name='conv9'),
            SpatialBatchNormalization(name='batch_norm9'),
            Rectifier()
        ]

    if depth > 0:
        decoder_layers = decoder_layers + [
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=32,
                          name='conv10'),
            SpatialBatchNormalization(name='batch_norm10'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=32,
                          name='conv11'),
            SpatialBatchNormalization(name='batch_norm11'),
            Rectifier(),
            ConvolutionalTranspose(
                filter_size=(2, 2),
                step=(2, 2),
                original_image_size=(g_image_size, g_image_size),
                num_filters=32,
                name='conv12'),
            SpatialBatchNormalization(name='batch_norm12'),
            Rectifier()
        ]

    decoder_layers = decoder_layers + [
        Convolutional(filter_size=(1, 1), num_filters=3, name='conv_out'),
        Logistic()
    ]

    print("creating model of depth {} with {} encoder and {} decoder layers".
          format(depth, len(encoder_layers), len(decoder_layers)))

    encoder_convnet = ConvolutionalSequence(
        layers=encoder_layers,
        num_channels=3,
        image_size=(g_image_size, g_image_size),
        use_bias=False,
        weights_init=IsotropicGaussian(0.033),
        biases_init=Constant(0),
        name='encoder_convnet')
    encoder_convnet.initialize()

    encoder_filters = numpy.prod(encoder_convnet.get_dim('output'))

    encoder_mlp = MLP(
        dims=[encoder_filters, 1000, z_dim],
        activations=[
            Sequence([BatchNormalization(1000).apply,
                      Rectifier().apply],
                     name='activation1'),
            Identity().apply
        ],
        weights_init=IsotropicGaussian(0.033),
        biases_init=Constant(0),
        name='encoder_mlp')
    encoder_mlp.initialize()

    decoder_mlp = BatchNormalizedMLP(
        activations=[Rectifier(), Rectifier()],
        dims=[encoder_mlp.output_dim // 2, 1000, encoder_filters],
        weights_init=IsotropicGaussian(0.033),
        biases_init=Constant(0),
        name='decoder_mlp')
    decoder_mlp.initialize()

    decoder_convnet = ConvolutionalSequence(
        layers=decoder_layers,
        num_channels=encoder_convnet.get_dim('output')[0],
        image_size=encoder_convnet.get_dim('output')[1:],
        use_bias=False,
        weights_init=IsotropicGaussian(0.033),
        biases_init=Constant(0),
        name='decoder_convnet')
    decoder_convnet.initialize()

    return encoder_convnet, encoder_mlp, decoder_convnet, decoder_mlp
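
Example No. 2 generalizes the same construction into a depth-configurable convolutional encoder/decoder pair. The sketch below lists the Blocks imports the function appears to assume and shows one plausible call; the argument values are illustrative (an image size divisible by 2**depth keeps the stride-2 stages consistent with the hard-coded transpose targets, and z_dim is later halved by decoder_mlp, which suggests it packs a mean and a log-variance).

import numpy
from blocks.bricks import (MLP, Identity, Rectifier, Logistic, Sequence)
from blocks.bricks.bn import (BatchNormalization, SpatialBatchNormalization,
                              BatchNormalizedMLP)
from blocks.bricks.conv import (Convolutional, ConvolutionalSequence,
                                ConvolutionalTranspose)
from blocks.initialization import IsotropicGaussian, Constant

# Illustrative call: 64x64 RGB input, four conv blocks, 1000-dim latent
# (assumed to hold the mean and log-variance halves of a VAE posterior).
bricks = create_model_bricks(z_dim=1000, image_size=64, depth=4)
encoder_convnet, encoder_mlp, decoder_convnet, decoder_mlp = bricks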
Example No. 3
def create_model_bricks():
    encoder_convnet = ConvolutionalSequence(
        layers=[
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=32,
                          name='conv1'),
            SpatialBatchNormalization(name='batch_norm1'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=32,
                          name='conv2'),
            SpatialBatchNormalization(name='batch_norm2'),
            Rectifier(),
            Convolutional(filter_size=(2, 2),
                          step=(2, 2),
                          num_filters=32,
                          name='conv3'),
            SpatialBatchNormalization(name='batch_norm3'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=64,
                          name='conv4'),
            SpatialBatchNormalization(name='batch_norm4'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=64,
                          name='conv5'),
            SpatialBatchNormalization(name='batch_norm5'),
            Rectifier(),
            Convolutional(filter_size=(2, 2),
                          step=(2, 2),
                          num_filters=64,
                          name='conv6'),
            SpatialBatchNormalization(name='batch_norm6'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=128,
                          name='conv7'),
            SpatialBatchNormalization(name='batch_norm7'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=128,
                          name='conv8'),
            SpatialBatchNormalization(name='batch_norm8'),
            Rectifier(),
            Convolutional(filter_size=(2, 2),
                          step=(2, 2),
                          num_filters=128,
                          name='conv9'),
            SpatialBatchNormalization(name='batch_norm9'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=256,
                          name='conv10'),
            SpatialBatchNormalization(name='batch_norm10'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=256,
                          name='conv11'),
            SpatialBatchNormalization(name='batch_norm11'),
            Rectifier(),
            Convolutional(filter_size=(2, 2),
                          step=(2, 2),
                          num_filters=256,
                          name='conv12'),
            SpatialBatchNormalization(name='batch_norm12'),
            Rectifier(),
        ],
        num_channels=3,
        image_size=(64, 64),
        use_bias=False,
        weights_init=IsotropicGaussian(0.033),
        biases_init=Constant(0),
        name='encoder_convnet')
    encoder_convnet.initialize()

    encoder_filters = numpy.prod(encoder_convnet.get_dim('output'))

    encoder_mlp = MLP(
        dims=[encoder_filters, 1000, 1000],
        activations=[
            Sequence([BatchNormalization(1000).apply,
                      Rectifier().apply],
                     name='activation1'),
            Identity().apply
        ],
        weights_init=IsotropicGaussian(0.033),
        biases_init=Constant(0),
        name='encoder_mlp')
    encoder_mlp.initialize()

    decoder_mlp = BatchNormalizedMLP(
        activations=[Rectifier(), Rectifier()],
        dims=[encoder_mlp.output_dim // 2, 1000, encoder_filters],
        weights_init=IsotropicGaussian(0.033),
        biases_init=Constant(0),
        name='decoder_mlp')
    decoder_mlp.initialize()

    decoder_convnet = ConvolutionalSequence(
        layers=[
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=256,
                          name='conv1'),
            SpatialBatchNormalization(name='batch_norm1'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=256,
                          name='conv2'),
            SpatialBatchNormalization(name='batch_norm2'),
            Rectifier(),
            ConvolutionalTranspose(filter_size=(2, 2),
                                   step=(2, 2),
                                   original_image_size=(8, 8),
                                   num_filters=256,
                                   name='conv3'),
            SpatialBatchNormalization(name='batch_norm3'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=128,
                          name='conv4'),
            SpatialBatchNormalization(name='batch_norm4'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=128,
                          name='conv5'),
            SpatialBatchNormalization(name='batch_norm5'),
            Rectifier(),
            ConvolutionalTranspose(filter_size=(2, 2),
                                   step=(2, 2),
                                   original_image_size=(16, 16),
                                   num_filters=128,
                                   name='conv6'),
            SpatialBatchNormalization(name='batch_norm6'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=64,
                          name='conv7'),
            SpatialBatchNormalization(name='batch_norm7'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=64,
                          name='conv8'),
            SpatialBatchNormalization(name='batch_norm8'),
            Rectifier(),
            ConvolutionalTranspose(filter_size=(2, 2),
                                   step=(2, 2),
                                   original_image_size=(32, 32),
                                   num_filters=64,
                                   name='conv9'),
            SpatialBatchNormalization(name='batch_norm9'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=32,
                          name='conv10'),
            SpatialBatchNormalization(name='batch_norm10'),
            Rectifier(),
            Convolutional(filter_size=(3, 3),
                          border_mode=(1, 1),
                          num_filters=32,
                          name='conv11'),
            SpatialBatchNormalization(name='batch_norm11'),
            Rectifier(),
            ConvolutionalTranspose(filter_size=(2, 2),
                                   step=(2, 2),
                                   original_image_size=(64, 64),
                                   num_filters=32,
                                   name='conv12'),
            SpatialBatchNormalization(name='batch_norm12'),
            Rectifier(),
            Convolutional(filter_size=(1, 1), num_filters=3, name='conv_out'),
            Logistic(),
        ],
        num_channels=encoder_convnet.get_dim('output')[0],
        image_size=encoder_convnet.get_dim('output')[1:],
        use_bias=False,
        weights_init=IsotropicGaussian(0.033),
        biases_init=Constant(0),
        name='decoder_convnet')
    decoder_convnet.initialize()

    return encoder_convnet, encoder_mlp, decoder_convnet, decoder_mlp
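
Example No. 3 is the fixed-size variant of the previous function: the same bricks hard-wired for 64x64 RGB inputs, with four stride-2 stages in the encoder and matching 8/16/32/64 transpose targets in the decoder. The encoder output should therefore be 256 feature maps of size 4x4 (4096 flattened features), and encoder_mlp.output_dim is 1000, presumably split into mean and log-variance halves by the surrounding VAE code. A small sanity-check sketch, assuming the same imports as in the previous example:

encoder_convnet, encoder_mlp, decoder_convnet, decoder_mlp = \
    create_model_bricks()

# Once initialized, the convolutional sequences can report their shapes.
print(encoder_convnet.get_dim('output'))   # expected (256, 4, 4)
print(decoder_convnet.get_dim('output'))   # expected (3, 64, 64)
print(encoder_mlp.output_dim)              # 1000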
Example No. 4
        def create_model_brick():
            encoder_mapping = MLP(
                dims=[2 * INPUT_DIM, GEN_HIDDEN, GEN_HIDDEN, NLAT],
                activations=[
                    Sequence([
                        BatchNormalization(GEN_HIDDEN).apply,
                        GEN_ACTIVATION().apply
                    ],
                             name='encoder_h1'),
                    Sequence([
                        BatchNormalization(GEN_HIDDEN).apply,
                        GEN_ACTIVATION().apply
                    ],
                             name='encoder_h2'),
                    Identity(name='encoder_out')
                ],
                use_bias=False,
                name='encoder_mapping')
            encoder = COVConditional(encoder_mapping, (INPUT_DIM, ),
                                     name='encoder')

            decoder_mapping = MLP(
                dims=[NLAT, GEN_HIDDEN, GEN_HIDDEN, GEN_HIDDEN, GEN_HIDDEN,
                      INPUT_DIM],
                activations=[
                    Sequence([BatchNormalization(GEN_HIDDEN).apply,
                              GEN_ACTIVATION().apply],
                             name='decoder_h1'),
                    Sequence([BatchNormalization(GEN_HIDDEN).apply,
                              GEN_ACTIVATION().apply],
                             name='decoder_h2'),
                    Sequence([BatchNormalization(GEN_HIDDEN).apply,
                              GEN_ACTIVATION().apply],
                             name='decoder_h3'),
                    Sequence([BatchNormalization(GEN_HIDDEN).apply,
                              GEN_ACTIVATION().apply],
                             name='decoder_h4'),
                    Identity(name='decoder_out')
                ],
                use_bias=False,
                name='decoder_mapping')
            decoder = DeterministicConditional(decoder_mapping, name='decoder')

            x_discriminator = Identity(name='x_discriminator')
            z_discriminator = Identity(name='z_discriminator')
            joint_discriminator = Sequence(application_methods=[
                LinearMaxout(input_dim=INPUT_DIM + NLAT,
                             output_dim=DISC_HIDDEN,
                             num_pieces=MAXOUT_PIECES,
                             weights_init=GAUSSIAN_INIT,
                             biases_init=ZERO_INIT,
                             name='discriminator_h1').apply,
                LinearMaxout(input_dim=DISC_HIDDEN,
                             output_dim=DISC_HIDDEN,
                             num_pieces=MAXOUT_PIECES,
                             weights_init=GAUSSIAN_INIT,
                             biases_init=ZERO_INIT,
                             name='discriminator_h2').apply,
                LinearMaxout(input_dim=DISC_HIDDEN,
                             output_dim=DISC_HIDDEN,
                             num_pieces=MAXOUT_PIECES,
                             weights_init=GAUSSIAN_INIT,
                             biases_init=ZERO_INIT,
                             name='discriminator_h3').apply,
                Linear(input_dim=DISC_HIDDEN,
                       output_dim=1,
                       weights_init=GAUSSIAN_INIT,
                       biases_init=ZERO_INIT,
                       name='discriminator_out').apply
            ], name='joint_discriminator')
            discriminator = XZJointDiscriminator(x_discriminator,
                                                 z_discriminator,
                                                 joint_discriminator,
                                                 name='discriminator')

            ali = ALI(encoder=encoder,
                      decoder=decoder,
                      discriminator=discriminator,
                      weights_init=GAUSSIAN_INIT,
                      biases_init=ZERO_INIT,
                      name='ali')
            ali.push_allocation_config()
            encoder_mapping.linear_transformations[-1].use_bias = True
            decoder_mapping.linear_transformations[-1].use_bias = True
            ali.initialize()

            print("Number of parameters in discriminator: {}".format(
                numpy.sum([
                    numpy.prod(v.shape.eval()) for v in Selector(
                        ali.discriminator).get_parameters().values()
                ])))
            print("Number of parameters in encoder: {}".format(
                numpy.sum([
                    numpy.prod(v.shape.eval())
                    for v in Selector(ali.encoder).get_parameters().values()
                ])))
            print("Number of parameters in decoder: {}".format(
                numpy.sum([
                    numpy.prod(v.shape.eval())
                    for v in Selector(ali.decoder).get_parameters().values()
                ])))

            return ali
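
Example No. 4 wires the same MLP and LinearMaxout building blocks into an ALI model: a conditional encoder, a deterministic decoder and a joint (x, z) discriminator. The ALI, COVConditional, DeterministicConditional and XZJointDiscriminator bricks come from the surrounding ALI codebase rather than from Blocks; the sketch below shows the imports and stand-in constants the snippet appears to assume (the ali.bricks module path and all constant values are guesses for illustration).

import numpy
from blocks.bricks import (MLP, Linear, LinearMaxout, Identity, Rectifier,
                           Sequence)
from blocks.bricks.bn import BatchNormalization
from blocks.initialization import IsotropicGaussian, Constant
from blocks.select import Selector

# Project-specific bricks; this module path is an assumption, not Blocks API.
from ali.bricks import (ALI, COVConditional, DeterministicConditional,
                        XZJointDiscriminator)

# Illustrative stand-ins for the module-level constants.
INPUT_DIM, NLAT = 2, 2
GEN_HIDDEN, DISC_HIDDEN, MAXOUT_PIECES = 400, 400, 5
GEN_ACTIVATION = Rectifier
GAUSSIAN_INIT = IsotropicGaussian(0.02)
ZERO_INIT = Constant(0)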
Example No. 5
    def create_model_brick(self):
        decoder = MLP(
            dims=[
                self._config["num_zdim"], self._config["gen_hidden_size"],
                self._config["gen_hidden_size"],
                self._config["gen_hidden_size"],
                self._config["gen_hidden_size"], self._config["num_xdim"]
            ],
            activations=[
                Sequence([
                    BatchNormalization(self._config["gen_hidden_size"]).apply,
                    self._config["gen_activation"]().apply
                ],
                         name='decoder_h1'),
                Sequence([
                    BatchNormalization(self._config["gen_hidden_size"]).apply,
                    self._config["gen_activation"]().apply
                ],
                         name='decoder_h2'),
                Sequence([
                    BatchNormalization(self._config["gen_hidden_size"]).apply,
                    self._config["gen_activation"]().apply
                ],
                         name='decoder_h3'),
                Sequence([
                    BatchNormalization(self._config["gen_hidden_size"]).apply,
                    self._config["gen_activation"]().apply
                ],
                         name='decoder_h4'),
                Identity(name='decoder_out')
            ],
            use_bias=False,
            name='decoder')

        discriminator = Sequence(application_methods=[
            LinearMaxout(input_dim=self._config["num_xdim"] *
                         self._config["num_packing"],
                         output_dim=self._config["disc_hidden_size"],
                         num_pieces=self._config["disc_maxout_pieces"],
                         weights_init=IsotropicGaussian(
                             self._config["weights_init_std"]),
                         biases_init=self._config["biases_init"],
                         name='discriminator_h1').apply,
            LinearMaxout(input_dim=self._config["disc_hidden_size"],
                         output_dim=self._config["disc_hidden_size"],
                         num_pieces=self._config["disc_maxout_pieces"],
                         weights_init=IsotropicGaussian(
                             self._config["weights_init_std"]),
                         biases_init=self._config["biases_init"],
                         name='discriminator_h2').apply,
            LinearMaxout(input_dim=self._config["disc_hidden_size"],
                         output_dim=self._config["disc_hidden_size"],
                         num_pieces=self._config["disc_maxout_pieces"],
                         weights_init=IsotropicGaussian(
                             self._config["weights_init_std"]),
                         biases_init=self._config["biases_init"],
                         name='discriminator_h3').apply,
            Linear(input_dim=self._config["disc_hidden_size"],
                   output_dim=1,
                   weights_init=IsotropicGaussian(
                       self._config["weights_init_std"]),
                   biases_init=self._config["biases_init"],
                   name='discriminator_out').apply
        ], name='discriminator')

        gan = PacGAN(decoder=decoder,
                     discriminator=discriminator,
                     weights_init=IsotropicGaussian(
                         self._config["weights_init_std"]),
                     biases_init=self._config["biases_init"],
                     name='gan')
        gan.push_allocation_config()
        decoder.linear_transformations[-1].use_bias = True
        gan.initialize()

        print("Number of parameters in discriminator: {}".format(
            numpy.sum([
                numpy.prod(v.shape.eval())
                for v in Selector(gan.discriminator).get_parameters().values()
            ])))
        print("Number of parameters in decoder: {}".format(
            numpy.sum([
                numpy.prod(v.shape.eval())
                for v in Selector(gan.decoder).get_parameters().values()
            ])))

        return gan
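
Example No. 5 is the same generator/discriminator pair driven by a configuration dictionary, with the PacGAN brick (defined by the surrounding PacGAN project, not by Blocks) packing num_packing samples into each discriminator input, which is why discriminator_h1 takes num_xdim * num_packing features. Below is a hedged sketch of a configuration the method could be handed; only the key names are taken from the snippet, the values are invented.

from blocks.bricks import Rectifier
from blocks.initialization import Constant

# Illustrative configuration; key names match what create_model_brick reads.
config = {
    "num_zdim": 2,                # latent dimensionality of the generator input
    "num_xdim": 2,                # dimensionality of a single data sample
    "num_packing": 3,             # samples concatenated per discriminator input
    "gen_hidden_size": 400,
    "gen_activation": Rectifier,  # activation brick class for the generator
    "disc_hidden_size": 400,
    "disc_maxout_pieces": 5,
    "weights_init_std": 0.02,
    "biases_init": Constant(0),
}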