def encoder(self, x):
        """Apply the dataset-specific convolutional encoder to ``x``.

        Both supported datasets share one three-layer stack (``conv1``,
        ``conv2``, ``conv3``) of VALID-padded ReLU convolutions; they differ
        only in the kernel size of ``conv3``. The input and every
        intermediate tensor are printed as the graph is built.

        Args:
            x: Input tensor handed to the first convolution.

        Returns:
            The output of ``conv3``. When ``self.dataset`` is neither
            ``'mnist'`` nor ``'breakout'`` no layer is applied and ``x`` is
            returned unchanged.
        """
        print('Encoder')
        print(x)

        # Only the last layer's kernel depends on the dataset.
        if self.dataset == 'mnist':
            final_kernel = (2, 2)
        elif self.dataset == 'breakout':
            final_kernel = (3, 3)
        else:
            final_kernel = None

        # Each entry: (scope name, number of filters, kernel size, stride).
        layer_specs = []
        if final_kernel is not None:
            layer_specs = [
                ('conv1', 32, (8, 8), (4, 4)),
                ('conv2', 64, (4, 4), (2, 2)),
                ('conv3', 64, final_kernel, (1, 1)),
            ]

        for layer_name, filters, kernel, stride in layer_specs:
            x = conv2d(layer_name,
                       x,
                       num_filters=filters,
                       kernel_size=kernel,
                       padding='VALID',
                       stride=stride,
                       initializer=orthogonal_initializer(np.sqrt(2)),
                       activation=tf.nn.relu,
                       is_training=self.is_training)
            print(x)

        print()
        return x
    def decoder(self, z, reuse=False):
        """Decode the latent ``z`` through the dataset-specific decoder.

        ``z`` is projected by a ReLU dense layer and reshaped to
        ``[-1, H, W, 64]``, where ``H`` and ``W`` are read from
        ``self.encoder_out.shape[1]`` and ``self.encoder_out.shape[2]``
        (assumes ``self.encoder_out`` is a rank-4 tensor set before this
        call -- TODO confirm against the caller). The result then goes
        through a per-dataset sequence of optional resize + convolution
        stages, ending in a 1x1 convolution with ``self.img_channels``
        filters and no activation. Every intermediate tensor is printed.

        Args:
            z: Latent tensor fed to the dense projection.
            reuse: Accepted for interface compatibility; this method does
                not use it.

        Returns:
            The decoded tensor. When ``self.dataset`` is neither ``'mnist'``
            nor ``'breakout'`` no stage runs and the reshaped dense
            projection is returned.
        """
        print('Decoder')
        first_conv_filters = 64
        decoder_input_size = self.encoder_out.shape[
            1] * self.encoder_out.shape[2] * first_conv_filters

        x = tf.layers.dense(z, decoder_input_size, activation=tf.nn.relu)
        print(x)
        x = tf.reshape(x, [
            -1, self.encoder_out.shape[1], self.encoder_out.shape[2],
            first_conv_filters
        ])
        print(x)

        # Each stage: (resize factor or None, scope name, number of filters,
        # kernel size, padding, activation). A None factor means the
        # convolution is applied without resizing first.
        relu = tf.nn.relu
        if self.dataset == 'mnist':
            stages = [
                (6, 'conv_up1', 64, (3, 3), 'VALID', relu),
                (3, 'conv_up2', 64, (3, 3), 'VALID', relu),
                (3, 'conv_up3', 32, (3, 3), 'VALID', relu),
                (None, 'conv_up4', self.img_channels, (1, 1), 'VALID', None),
            ]
        elif self.dataset == 'breakout':
            stages = [
                (4, 'conv_up1', 64, (7, 7), 'VALID', relu),
                (4, 'conv_up2', 64, (7, 7), 'SAME', relu),
                (None, 'conv_up3', 32, (5, 5), 'VALID', relu),
                (None, 'conv_up4', self.img_channels, (1, 1), 'VALID', None),
            ]
        else:
            stages = []

        for scale, layer_name, filters, kernel, padding, activation in stages:
            if scale is not None:
                x = tf.image.resize_images(
                    x, (x.shape[1] * scale, x.shape[2] * scale))
            x = conv2d(layer_name,
                       x,
                       num_filters=filters,
                       kernel_size=kernel,
                       padding=padding,
                       stride=(1, 1),
                       initializer=orthogonal_initializer(np.sqrt(2)),
                       activation=activation,
                       is_training=self.is_training)
            print(x)

        print()
        return x
    def __init__(self,
                 sess,
                 input_shape,
                 num_actions,
                 reuse=False,
                 is_training=True,
                 name='train'):
        """Build the policy and value heads on a shared convolutional trunk.

        Creates a ``uint8`` input placeholder, rescales it by ``1/255`` as
        ``float32``, and feeds it through three VALID-padded ReLU
        convolutions and a 512-unit ReLU dense layer. Two dense heads on top
        of that layer produce the policy logits and the value function.

        Args:
            sess: Forwarded to the base-class constructor.
            input_shape: Shape of the ``uint8`` input placeholder.
            num_actions: Output dimension of the policy-logits head.
            reuse: Forwarded to the base-class constructor and used as the
                ``reuse`` flag of the ``"policy"`` variable scope.
            is_training: Passed to every ``conv2d`` / ``dense`` layer.
            name: Prefix of the name scope holding the input placeholder.

        Attributes set:
            initial_state: An empty list.
            X_input: The input placeholder.
            policy_logits: Output of the ``policy_logits`` head.
            value_function: Output of the ``value_function`` head.
            value_s: Column 0 of ``value_function``.
            action_s: ``noise_and_argmax`` applied to ``policy_logits``.
        """
        super().__init__(sess, reuse)
        self.initial_state = []

        with tf.name_scope(name + "policy_input"):
            self.X_input = tf.placeholder(tf.uint8, input_shape)

        # Each entry: (scope name, number of filters, kernel size, stride).
        trunk_specs = (
            ('conv1', 32, (8, 8), (4, 4)),
            ('conv2', 64, (4, 4), (2, 2)),
            ('conv3', 64, (3, 3), (1, 1)),
        )

        with tf.variable_scope("policy", reuse=reuse):
            # Cast and rescale inside the scope so the ops are created here.
            features = tf.cast(self.X_input, tf.float32) / 255.
            for layer_name, filters, kernel, stride in trunk_specs:
                features = conv2d(
                    layer_name,
                    features,
                    num_filters=filters,
                    kernel_size=kernel,
                    padding='VALID',
                    stride=stride,
                    initializer=orthogonal_initializer(np.sqrt(2)),
                    activation=tf.nn.relu,
                    is_training=is_training)

            hidden = dense(
                'fc4',
                flatten(features),
                output_dim=512,
                initializer=orthogonal_initializer(np.sqrt(2)),
                activation=tf.nn.relu,
                is_training=is_training)

            # Both heads read the same hidden layer.
            self.policy_logits = dense(
                'policy_logits',
                hidden,
                output_dim=num_actions,
                initializer=orthogonal_initializer(np.sqrt(1.0)),
                is_training=is_training)

            self.value_function = dense(
                'value_function',
                hidden,
                output_dim=1,
                initializer=orthogonal_initializer(np.sqrt(1.0)),
                is_training=is_training)

            with tf.name_scope('value'):
                self.value_s = self.value_function[:, 0]

            with tf.name_scope('action'):
                self.action_s = noise_and_argmax(self.policy_logits)