class BDCGAN_Semi(object):
    def __init__(self,
                 x_dim,
                 z_dim,
                 dataset_size,
                 batch_size=64,
                 gf_dim=64,
                 df_dim=64,
                 prior_std=1.0,
                 J=1,
                 M=1,
                 num_classes=1,
                 eta=2e-4,
                 num_layers=4,
                 alpha=0.01,
                 lr=0.0002,
                 optimizer='adam',
                 wasserstein=False,
                 ml=False,
                 J_d=None):

        assert len(x_dim) == 3, "invalid image dims"
        c_dim = x_dim[2]
        self.is_grayscale = (c_dim == 1)
        self.optimizer = optimizer.lower()
        self.dataset_size = dataset_size
        self.batch_size = batch_size

        self.K = num_classes
        self.x_dim = x_dim
        self.z_dim = z_dim

        self.gf_dim = gf_dim
        self.df_dim = df_dim
        self.c_dim = c_dim
        self.lr = lr

        # Bayes
        self.prior_std = prior_std
        self.num_gen = J
        self.num_disc = J_d if J_d is not None else 1
        self.num_mcmc = M
        self.eta = eta
        self.alpha = alpha
        # ML
        self.ml = ml
        if self.ml:
            assert self.num_gen == 1 and self.num_disc == 1 and self.num_mcmc == 1, "invalid settings for ML training"

        self.noise_std = np.sqrt(2 * self.alpha * self.eta)
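        # With the defaults alpha=0.01 and eta=2e-4 this is sqrt(2 * 0.01 * 2e-4) = 2e-3.
        # gen_noise()/disc_noise() below add sum(var * N(0, noise_std)) / dataset_size to
        # the loss, so the gradient w.r.t. each weight picks up zero-mean Gaussian noise
        # (the noise-injection trick used to approximate SGHMC/SGLD updates with a
        # standard optimizer).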

        def get_strides(num_layers, num_pool):
            interval = int(math.floor(num_layers / float(num_pool)))
            strides = np.array([1] * num_layers)
            strides[0:interval * num_pool:interval] = 2
            return strides

        self.num_pool = 4
        self.max_num_dfs = 512
        self.gen_strides = get_strides(num_layers, self.num_pool)
        self.disc_strides = self.gen_strides
        num_dfs = np.cumprod(np.array([self.df_dim] +
                                      list(self.disc_strides)))[:-1]
        num_dfs[num_dfs >= self.max_num_dfs] = self.max_num_dfs  # memory
        self.num_dfs = list(num_dfs)
        self.num_gfs = self.num_dfs[::-1]
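        # Worked example with the defaults num_layers=4, df_dim=64:
        #   gen_strides = disc_strides = [2, 2, 2, 2]
        #   num_dfs = cumprod([64, 2, 2, 2, 2])[:-1] = [64, 128, 256, 512] (then capped at max_num_dfs=512)
        #   num_gfs = [512, 256, 128, 64]
        # i.e. the discriminator doubles its filter count at every stride-2 layer and the
        # generator mirrors that schedule in reverse.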

        self.construct_from_hypers(gen_strides=self.gen_strides,
                                   disc_strides=self.disc_strides,
                                   num_gfs=self.num_gfs,
                                   num_dfs=self.num_dfs)

        self.build_bgan_graph()
        self.build_test_graph()

    def construct_from_hypers(self,
                              gen_kernel_size=5,
                              gen_strides=[2, 2, 2, 2],
                              disc_kernel_size=5,
                              disc_strides=[2, 2, 2, 2],
                              num_dfs=None,
                              num_gfs=None):

        self.d_batch_norm = AttributeDict([
            ("d_bn%i" % dbn_i, batch_norm(name='d_bn%i' % dbn_i))
            for dbn_i in range(len(disc_strides))
        ])
        self.sup_d_batch_norm = AttributeDict([
            ("sd_bn%i" % dbn_i, batch_norm(name='sup_d_bn%i' % dbn_i))
            for dbn_i in range(5)
        ])
        self.g_batch_norm = AttributeDict([
            ("g_bn%i" % gbn_i, batch_norm(name='g_bn%i' % gbn_i))
            for gbn_i in range(len(gen_strides))
        ])

        if num_dfs is None:
            num_dfs = [
                self.df_dim, self.df_dim * 2, self.df_dim * 4, self.df_dim * 8
            ]

        if num_gfs is None:
            num_gfs = [
                self.gf_dim * 8, self.gf_dim * 4, self.gf_dim * 2, self.gf_dim
            ]

        assert len(gen_strides) == len(num_gfs), "invalid hypers!"
        assert len(disc_strides) == len(num_dfs), "invalid hypers!"

        s_h, s_w = self.x_dim[0], self.x_dim[1]
        ks = gen_kernel_size
        self.gen_output_dims = OrderedDict()
        self.gen_weight_dims = OrderedDict()
        num_gfs = num_gfs + [self.c_dim]
        self.gen_kernel_sizes = [ks]
        for layer in range(len(gen_strides))[::-1]:
            self.gen_output_dims["g_h%i_out" % (layer + 1)] = (s_h, s_w)
            assert gen_strides[layer] <= 2, "invalid stride"
            assert ks % 2 == 1, "invalid kernel size"
            self.gen_weight_dims["g_h%i_W" %
                                 (layer + 1)] = (ks, ks, num_gfs[layer + 1],
                                                 num_gfs[layer])
            self.gen_weight_dims["g_h%i_b" % (layer + 1)] = (num_gfs[layer +
                                                                     1], )
            s_h, s_w = conv_out_size(s_h, gen_strides[layer]), conv_out_size(
                s_w, gen_strides[layer])
            ks = kernel_sizer(ks, gen_strides[layer])
            self.gen_kernel_sizes.append(ks)

        self.gen_weight_dims.update(
            OrderedDict([("g_h0_lin_W", (self.z_dim, num_gfs[0] * s_h * s_w)),
                         ("g_h0_lin_b", (num_gfs[0] * s_h * s_w, ))]))
        self.gen_output_dims["g_h0_out"] = (s_h, s_w)

        self.disc_weight_dims = OrderedDict()
        s_h, s_w = self.x_dim[0], self.x_dim[1]
        num_dfs = [self.c_dim] + num_dfs
        ks = disc_kernel_size
        self.disc_kernel_sizes = [ks]
        for layer in range(len(disc_strides)):
            assert disc_strides[layer] <= 2, "invalid stride"
            assert ks % 2 == 1, "invalid kernel size"
            self.disc_weight_dims["d_h%i_W" % layer] = (ks, ks, num_dfs[layer],
                                                        num_dfs[layer + 1])
            self.disc_weight_dims["d_h%i_b" % layer] = (num_dfs[layer + 1], )
            s_h, s_w = conv_out_size(s_h, disc_strides[layer]), conv_out_size(
                s_w, disc_strides[layer])
            ks = kernel_sizer(ks, disc_strides[layer])
            self.disc_kernel_sizes.append(ks)

        self.disc_weight_dims.update(
            OrderedDict([("d_h_end_lin_W", (num_dfs[-1] * s_h * s_w,
                                            num_dfs[-1])),
                         ("d_h_end_lin_b", (num_dfs[-1], )),
                         ("d_h_out_lin_W", (num_dfs[-1], self.K)),
                         ("d_h_out_lin_b", (self.K, ))]))
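        # Example of the shape bookkeeping above for x_dim = (32, 32, 3) with the default
        # strides (assuming conv_out_size halves the spatial size for stride 2, i.e. SAME
        # padding): 32 -> 16 -> 8 -> 4 -> 2.  The generator's g_h0_lin output is then
        # reshaped to (2, 2, num_gfs[0]) and d_h_end_lin_W has shape
        # (num_dfs[-1] * 2 * 2, num_dfs[-1]).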

        for k, v in self.gen_output_dims.items():
            print("%s: %s" % (k, v))
        print('****')
        for k, v in self.gen_weight_dims.items():
            print("%s: %s" % (k, v))
        print('****')
        for k, v in self.disc_weight_dims.items():
            print("%s: %s" % (k, v))

    def construct_nets(self):

        self.num_disc_layers = 5
        self.num_gen_layers = 5
        self.d_batch_norm = AttributeDict([
            ("d_bn%i" % dbn_i, batch_norm(name='d_bn%i' % dbn_i))
            for dbn_i in range(self.num_disc_layers)
        ])
        self.sup_d_batch_norm = AttributeDict([
            ("sd_bn%i" % dbn_i, batch_norm(name='sup_d_bn%i' % dbn_i))
            for dbn_i in range(self.num_disc_layers)
        ])
        self.g_batch_norm = AttributeDict([
            ("g_bn%i" % gbn_i, batch_norm(name='g_bn%i' % gbn_i))
            for gbn_i in range(self.num_gen_layers)
        ])

        s_h, s_w = self.x_dim[0], self.x_dim[1]
        s_h2, s_w2 = conv_out_size(s_h, 2), conv_out_size(s_w, 2)
        s_h4, s_w4 = conv_out_size(s_h2, 2), conv_out_size(s_w2, 2)
        s_h8, s_w8 = conv_out_size(s_h4, 2), conv_out_size(s_w4, 2)
        s_h16, s_w16 = conv_out_size(s_h8, 2), conv_out_size(s_w8, 2)

        self.gen_output_dims = OrderedDict([("g_h0_out", (s_h16, s_w16)),
                                            ("g_h1_out", (s_h8, s_w8)),
                                            ("g_h2_out", (s_h4, s_w4)),
                                            ("g_h3_out", (s_h2, s_w2)),
                                            ("g_h4_out", (s_h, s_w))])

        self.gen_weight_dims = OrderedDict([
            ("g_h0_lin_W", (self.z_dim, self.gf_dim * 8 * s_h16 * s_w16)),
            ("g_h0_lin_b", (self.gf_dim * 8 * s_h16 * s_w16, )),
            ("g_h1_W", (5, 5, self.gf_dim * 4, self.gf_dim * 8)),
            ("g_h1_b", (self.gf_dim * 4, )),
            ("g_h2_W", (5, 5, self.gf_dim * 2, self.gf_dim * 4)),
            ("g_h2_b", (self.gf_dim * 2, )),
            ("g_h3_W", (5, 5, self.gf_dim * 1, self.gf_dim * 2)),
            ("g_h3_b", (self.gf_dim * 1, )),
            ("g_h4_W", (5, 5, self.c_dim, self.gf_dim * 1)),
            ("g_h4_b", (self.c_dim, ))
        ])

        self.disc_weight_dims = OrderedDict([
            ("d_h0_W", (5, 5, self.c_dim, self.df_dim)),
            ("d_h0_b", (self.df_dim, )),
            ("d_h1_W", (5, 5, self.df_dim, self.df_dim * 2)),
            ("d_h1_b", (self.df_dim * 2, )),
            ("d_h2_W", (5, 5, self.df_dim * 2, self.df_dim * 4)),
            ("d_h2_b", (self.df_dim * 4, )),
            ("d_h3_W", (5, 5, self.df_dim * 4, self.df_dim * 8)),
            ("d_h3_b", (self.df_dim * 8, )),
            ("d_h_end_lin_W", (self.df_dim * 8 * s_h16 * s_w16,
                               self.df_dim * 4)),
            ("d_h_end_lin_b", (self.df_dim * 4, )),
            ("d_h_out_lin_W", (self.df_dim * 4, self.K)),
            ("d_h_out_lin_b", (self.K, ))
        ])

    def _get_optimizer(self, lr):
        if self.optimizer == 'adam':
            return tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
        elif self.optimizer == 'sgd':
            return tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.5)
        else:
            raise ValueError("Optimizer must be either 'adam' or 'sgd'")

    def initialize_wgts(self, scope_str):

        if scope_str == "generator":
            weight_dims = self.gen_weight_dims
            numz = self.num_gen
        elif scope_str == "discriminator":
            weight_dims = self.disc_weight_dims
            numz = self.num_disc
        else:
            raise RuntimeError("invalid scope!")

        param_list = []
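        # One AttributeDict of tf.Variables is created per (z index, MCMC sample) pair;
        # the "_%04d_%04d" name suffix lets build_bgan_graph partition the trainable
        # variables per generator/discriminator copy later on.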
        with tf.variable_scope(scope_str) as scope:
            for zi in range(numz):
                for m in range(self.num_mcmc):
                    wgts_ = AttributeDict()
                    for name, shape in weight_dims.items():
                        wgts_[name] = tf.get_variable(
                            "%s_%04d_%04d" % (name, zi, m),
                            shape,
                            initializer=tf.random_normal_initializer(
                                stddev=0.02))
                    param_list.append(wgts_)
            return param_list

    def build_bgan_graph(self):

        self.inputs = tf.placeholder(tf.float32,
                                     [self.batch_size] + self.x_dim,
                                     name='real_images')

        self.labeled_inputs = tf.placeholder(tf.float32,
                                             [self.batch_size] + self.x_dim,
                                             name='real_images_w_labels')

        self.labels = tf.placeholder(tf.float32, [self.batch_size, self.K],
                                     name='real_targets')

        self.z = tf.placeholder(tf.float32,
                                [self.batch_size, self.z_dim, self.num_gen],
                                name='z')
        self.z_sampler = tf.placeholder(tf.float32,
                                        [self.batch_size, self.z_dim],
                                        name='z_sampler')

        # initialize generator weights
        self.gen_param_list = self.initialize_wgts("generator")
        self.disc_param_list = self.initialize_wgts("discriminator")
        ### build discriminative losses and optimizers
        # prep optimizer args
        self.d_semi_learning_rate = tf.placeholder(tf.float32, shape=[])

        # compile all discriminative weights
        t_vars = tf.trainable_variables()
        self.d_vars = []
        for di in range(self.num_disc):
            for m in range(self.num_mcmc):
                self.d_vars.append([
                    var for var in t_vars
                    if 'd_' in var.name and "_%04d_%04d" % (di, m) in var.name
                ])

        ### build disc losses and optimizers
        self.d_losses, self.d_optims_semi, self.d_optims_semi_adam = [], [], []
        for di, disc_params in enumerate(self.disc_param_list):

            d_probs, d_logits, _ = self.discriminator(self.inputs, self.K,
                                                      disc_params)

            d_loss_real = -tf.reduce_mean(tf.reduce_logsumexp(d_logits, 1)) +\
            tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits, 1)))
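            # With Z(x) = sum_k exp(logit_k), the two terms above combine to
            # softplus(-logsumexp(logits)) = -log(Z(x) / (1 + Z(x))), i.e. -log D(x)
            # in the K-class semi-supervised GAN formulation where the implicit
            # (K+1)-th class is "fake" (the Improved-GAN trick of Salimans et al.).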

            d_loss_fakes = []
            for gi, gen_params in enumerate(self.gen_param_list):
                d_probs_, d_logits_, _ = self.discriminator(
                    self.generator(self.z[:, :, gi % self.num_gen],
                                   gen_params), self.K, disc_params)
                d_loss_fake_ = tf.reduce_mean(
                    tf.nn.softplus(tf.reduce_logsumexp(d_logits_, 1)))
                d_loss_fakes.append(d_loss_fake_)

            d_sup_probs, d_sup_logits, _ = self.discriminator(
                self.labeled_inputs, self.K, disc_params)
            d_loss_sup = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=d_sup_logits,
                                                        labels=self.labels))
            d_losses_semi = []
            for d_loss_fake_ in d_loss_fakes:
                d_loss_semi_ = d_loss_sup + d_loss_real * float(
                    self.num_gen) + d_loss_fake_
                if not self.ml:
                    d_loss_semi_ += self.disc_prior(
                        disc_params) + self.disc_noise(disc_params)
                d_losses_semi.append(tf.reshape(d_loss_semi_, [1]))

            d_loss_semi = tf.reduce_logsumexp(tf.concat(d_losses_semi, 0))
            self.d_losses.append(d_loss_semi)
            d_opt_semi = self._get_optimizer(self.d_semi_learning_rate)
            self.d_optims_semi.append(
                d_opt_semi.minimize(d_loss_semi, var_list=self.d_vars[di]))
            d_opt_semi_adam = tf.train.AdamOptimizer(
                learning_rate=self.d_semi_learning_rate, beta1=0.5)
            self.d_optims_semi_adam.append(
                d_opt_semi_adam.minimize(d_loss_semi,
                                         var_list=self.d_vars[di]))

        ### build generative losses and optimizers
        self.g_learning_rate = tf.placeholder(tf.float32, shape=[])
        self.g_vars = []
        for gi in range(self.num_gen):
            for m in range(self.num_mcmc):
                self.g_vars.append([
                    var for var in t_vars
                    if 'g_' in var.name and "_%04d_%04d" % (gi, m) in var.name
                ])

        self.g_losses, self.g_optims_semi, self.g_optims_semi_adam = [], [], []
        for gi, gen_params in enumerate(self.gen_param_list):

            gi_losses = []
            for disc_params in self.disc_param_list:
                d_probs_, d_logits_, d_features_fake = self.discriminator(
                    self.generator(self.z[:, :, gi % self.num_gen],
                                   gen_params), self.K, disc_params)
                _, _, d_features_real = self.discriminator(
                    self.inputs, self.K, disc_params)
                g_loss_ = -tf.reduce_mean(tf.reduce_logsumexp(d_logits_, 1)) +\
                tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits_, 1))) # not needed?!
                g_loss_ += tf.reduce_mean(
                    huber_loss(d_features_real[-1], d_features_fake[-1]))
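                # Feature matching: the generator is also pushed to match the
                # discriminator's penultimate-layer features on real vs. fake batches,
                # using the Huber loss rather than a plain squared error.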
                if not self.ml:
                    g_loss_ += self.gen_prior(gen_params) + self.gen_noise(
                        gen_params)
                gi_losses.append(tf.reshape(g_loss_, [1]))

            g_loss = tf.reduce_logsumexp(tf.concat(gi_losses, 0))
            self.g_losses.append(g_loss)
            g_opt = self._get_optimizer(self.g_learning_rate)
            self.g_optims_semi.append(
                g_opt.minimize(g_loss, var_list=self.g_vars[gi]))
            g_opt_adam = tf.train.AdamOptimizer(
                learning_rate=self.g_learning_rate, beta1=0.5)
            self.g_optims_semi_adam.append(
                g_opt_adam.minimize(g_loss, var_list=self.g_vars[gi]))

        ### build samplers
        self.gen_samplers = []
        for gi, gen_params in enumerate(self.gen_param_list):
            self.gen_samplers.append(self.generator(self.z_sampler,
                                                    gen_params))

        ### build vanilla supervised loss
        self.lbls = tf.placeholder(tf.float32, [self.batch_size, self.K],
                                   name='real_sup_targets')

        self.S, self.S_logits = self.sup_discriminator(self.inputs, self.K)
        self.s_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.S_logits,
                                                    labels=self.lbls))
        t_vars = tf.trainable_variables()
        self.sup_vars = [var for var in t_vars if 'sup_' in var.name]
        supervised_lr = 0.05 * self.lr
        s_opt = self._get_optimizer(supervised_lr)
        self.s_optim = s_opt.minimize(self.s_loss, var_list=self.sup_vars)
        s_opt_adam = tf.train.AdamOptimizer(learning_rate=supervised_lr,
                                            beta1=0.5)
        self.s_optim_adam = s_opt_adam.minimize(self.s_loss,
                                                var_list=self.sup_vars)

    def build_test_graph(self):

        self.test_inputs = tf.placeholder(tf.float32,
                                          [self.batch_size] + self.x_dim,
                                          name='real_test_images')

        self.test_d_probs, self.test_d_logits = [], []
        for disc_params in self.disc_param_list:
            test_d_probs_, test_d_logits_, _ = self.discriminator(
                self.test_inputs, self.K, disc_params, train=False)
            self.test_d_probs.append(test_d_probs_)
            self.test_d_logits.append(test_d_logits_)

        # build standard purely supervised losses and optimizers
        self.test_s_probs, self.test_s_logits = self.sup_discriminator(
            self.test_inputs, self.K, reuse=True)

    def sup_discriminator(self, image, K, reuse=False):
        # TODO collapse this into disc
        with tf.variable_scope("sup_discriminator") as scope:
            if reuse:
                scope.reuse_variables()

            h0 = lrelu(conv2d(image, self.df_dim, name='sup_h0_conv'))
            h1 = lrelu(
                self.sup_d_batch_norm.sd_bn1(
                    conv2d(h0, self.df_dim * 2, name='sup_h1_conv')))
            h2 = lrelu(
                self.sup_d_batch_norm.sd_bn2(
                    conv2d(h1, self.df_dim * 4, name='sup_h2_conv')))
            h3 = lrelu(
                self.sup_d_batch_norm.sd_bn3(
                    conv2d(h2, self.df_dim * 8, name='sup_h3_conv')))
            h4 = linear(tf.reshape(h3, [self.batch_size, -1]), K, 'sup_h3_lin')
            return tf.nn.softmax(h4), h4

    def discriminator(self, image, K, disc_params, train=True):

        with tf.variable_scope("discriminator") as scope:

            h = image
            for layer in range(len(self.disc_strides)):
                if layer == 0:
                    h = lrelu(
                        conv2d(h,
                               self.disc_weight_dims["d_h%i_W" % layer][-1],
                               name='d_h%i_conv' % layer,
                               k_h=self.disc_kernel_sizes[layer],
                               k_w=self.disc_kernel_sizes[layer],
                               d_h=self.disc_strides[layer],
                               d_w=self.disc_strides[layer],
                               w=disc_params["d_h%i_W" % layer],
                               biases=disc_params["d_h%i_b" % layer]))
                else:
                    h = lrelu(self.d_batch_norm["d_bn%i" % layer](conv2d(
                        h,
                        self.disc_weight_dims["d_h%i_W" % layer][-1],
                        name='d_h%i_conv' % layer,
                        k_h=self.disc_kernel_sizes[layer],
                        k_w=self.disc_kernel_sizes[layer],
                        d_h=self.disc_strides[layer],
                        d_w=self.disc_strides[layer],
                        w=disc_params["d_h%i_W" % layer],
                        biases=disc_params["d_h%i_b" % layer]),
                                                                  train=train))

            h_end = lrelu(
                linear(tf.reshape(h, [self.batch_size, -1]),
                       self.df_dim * 4,
                       "d_h_end_lin",
                       matrix=disc_params.d_h_end_lin_W,
                       bias=disc_params.d_h_end_lin_b))  # penultimate features (used for feature matching)
            h_out = linear(h_end,
                           K,
                           'd_h_out_lin',
                           matrix=disc_params.d_h_out_lin_W,
                           bias=disc_params.d_h_out_lin_b)

            return tf.nn.softmax(h_out), h_out, [h_end]

    def generator(self, z, gen_params):

        with tf.variable_scope("generator") as scope:

            h = linear(z,
                       self.gen_weight_dims["g_h0_lin_W"][-1],
                       'g_h0_lin',
                       matrix=gen_params.g_h0_lin_W,
                       bias=gen_params.g_h0_lin_b)
            h = tf.nn.relu(self.g_batch_norm.g_bn0(h))

            h = tf.reshape(h, [
                self.batch_size, self.gen_output_dims["g_h0_out"][0],
                self.gen_output_dims["g_h0_out"][1], -1
            ])

            for layer in range(1, len(self.gen_strides) + 1):

                out_shape = [
                    self.batch_size,
                    self.gen_output_dims["g_h%i_out" % layer][0],
                    self.gen_output_dims["g_h%i_out" % layer][1],
                    self.gen_weight_dims["g_h%i_W" % layer][-2]
                ]

                h = deconv2d(h,
                             out_shape,
                             k_h=self.gen_kernel_sizes[layer - 1],
                             k_w=self.gen_kernel_sizes[layer - 1],
                             d_h=self.gen_strides[layer - 1],
                             d_w=self.gen_strides[layer - 1],
                             name='g_h%i' % layer,
                             w=gen_params["g_h%i_W" % layer],
                             biases=gen_params["g_h%i_b" % layer])
                if layer < len(self.gen_strides):
                    h = tf.nn.relu(self.g_batch_norm["g_bn%i" % layer](h))

            return tf.nn.tanh(h)

    def gen_prior(self, gen_params):
        with tf.variable_scope("generator") as scope:
            prior_loss = 0.0
            for var in gen_params.values():
                nn = tf.divide(var, self.prior_std)
                prior_loss += tf.reduce_mean(tf.multiply(nn, nn))

        prior_loss /= self.dataset_size

        return prior_loss

    def gen_noise(self, gen_params):
        with tf.variable_scope("generator") as scope:
            noise_loss = 0.0
            for name, var in gen_params.items():
                noise_ = tf.contrib.distributions.Normal(
                    mu=0., sigma=self.noise_std * tf.ones(var.get_shape()))
                noise_loss += tf.reduce_sum(var * noise_.sample())
        noise_loss /= self.dataset_size
        return noise_loss

    def disc_prior(self, disc_params):
        with tf.variable_scope("discriminator") as scope:
            prior_loss = 0.0
            for var in disc_params.values():
                nn = tf.divide(var, self.prior_std)
                prior_loss += tf.reduce_mean(tf.multiply(nn, nn))

        prior_loss /= self.dataset_size

        return prior_loss

    def disc_noise(self, disc_params):
        with tf.variable_scope("discriminator") as scope:
            noise_loss = 0.0
            for var in disc_params.values():
                noise_ = tf.contrib.distributions.Normal(
                    mu=0., sigma=self.noise_std * tf.ones(var.get_shape()))
                noise_loss += tf.reduce_sum(var * noise_.sample())
        noise_loss /= self.dataset_size
        return noise_loss
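
# ---------------------------------------------------------------------------
# A minimal usage sketch of how BDCGAN_Semi might be driven under TensorFlow 1.x.
# It only shows graph construction and one Adam step for the discriminators and
# generators; the NumPy arrays below are random stand-ins for real image/label
# batches, and the hyperparameters (x_dim, z_dim, J, num_classes, learning rates)
# are illustrative values, not ones prescribed by the class itself.
if __name__ == "__main__":
    import numpy as np
    import tensorflow as tf

    gan = BDCGAN_Semi(x_dim=[32, 32, 3], z_dim=100, dataset_size=50000,
                      batch_size=64, J=2, M=1, num_classes=10)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # random stand-ins for an unlabeled batch, a labeled batch, and noise
        images = np.random.uniform(-1, 1, size=[64, 32, 32, 3]).astype(np.float32)
        sup_images = np.random.uniform(-1, 1, size=[64, 32, 32, 3]).astype(np.float32)
        sup_labels = np.eye(10)[np.random.randint(0, 10, size=64)].astype(np.float32)
        z = np.random.uniform(-1, 1, size=[64, 100, gan.num_gen]).astype(np.float32)

        # one Adam step for every discriminator copy ...
        sess.run(gan.d_optims_semi_adam,
                 feed_dict={gan.inputs: images,
                            gan.labeled_inputs: sup_images,
                            gan.labels: sup_labels,
                            gan.z: z,
                            gan.d_semi_learning_rate: 2e-4})
        # ... and for every generator copy (inputs are needed for feature matching)
        sess.run(gan.g_optims_semi_adam,
                 feed_dict={gan.inputs: images,
                            gan.z: z,
                            gan.g_learning_rate: 2e-4})
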
Example #2

class BDCGAN_Semi_3d(object):

    def __init__(self, x_dim, z_dim, dataset_size, batch_size=64, gf_dim=64, df_dim=64,
                 prior_std=1.0, J=1, M=1, num_classes=1, eta=1, num_layers=4,
                 alpha=0.01, lr=0.0002, optimizer='adam', wasserstein=False,
                 ml=False, J_d=None):  # eta=2e-4,

        print("ml = ", ml)

        self.optimizer = optimizer.lower()
        self.dataset_size = dataset_size
        self.batch_size = batch_size

        self.K = num_classes
        self.x_dim = x_dim
        self.z_dim = z_dim    # dimensionality of the latent noise vector z

        self.gf_dim = gf_dim  # base number of generator filters
        self.df_dim = df_dim  # base number of discriminator filters
        self.c_dim = x_dim[3] # x_dim = [x, y, z, c]
        self.is_grayscale = (self.c_dim == 1)
        self.lr = lr

        # Bayes
        self.prior_std = prior_std
        self.num_gen = J     # J: number of generators (posterior weight samples) maintained
        self.num_disc = J_d if J_d is not None else 1
        self.num_mcmc = M
        self.eta = eta       # not required in variational inference and MC dropout
        self.alpha = alpha   # not required in variational inference and MC dropout

        # ML
        self.ml = ml
        if self.ml:
            assert self.num_gen == 1 and self.num_disc == 1 and self.num_mcmc == 1, "invalid settings for ML training"

        self.noise_std = 10  # overrides the SGHMC-derived value np.sqrt(2 * self.alpha * self.eta)


        def get_strides(num_layers, num_pool):
            interval = int(math.floor(num_layers / float(num_pool)))
            strides = np.array([1] * num_layers)
            strides[0:interval * num_pool:interval] = 2
            return strides

        self.num_pool = 4
        self.max_num_dfs = 1024   # default - 512
        self.gen_strides = get_strides(num_layers, self.num_pool)
        self.disc_strides = self.gen_strides
        num_dfs = np.cumprod(np.array([self.df_dim] + list(self.disc_strides)))[:-1]
        num_dfs[num_dfs >= self.max_num_dfs] = self.max_num_dfs  # memory
        self.num_dfs = list(num_dfs)
        self.num_gfs = self.num_dfs[::-1]

        self.construct_from_hypers(gen_strides=self.gen_strides, disc_strides=self.disc_strides,
                                   num_gfs=self.num_gfs, num_dfs=self.num_dfs)

        self.build_bgan_graph()
        self.build_test_graph()

    def construct_from_hypers(self, gen_kernel_size=5, gen_strides=[2, 2, 2, 2],
                              disc_kernel_size=5, disc_strides=[2, 2, 2, 2],
                              num_dfs=None, num_gfs=None):

        self.d_batch_norm = AttributeDict(
            [("d_bn%i" % dbn_i, batch_norm(name='d_bn%i' % dbn_i)) for dbn_i in range(len(disc_strides))])
        self.sup_d_batch_norm = AttributeDict(
            [("sd_bn%i" % dbn_i, batch_norm(name='sup_d_bn%i' % dbn_i)) for dbn_i in range(5)])
        self.g_batch_norm = AttributeDict(
            [("g_bn%i" % gbn_i, batch_norm(name='g_bn%i' % gbn_i)) for gbn_i in range(len(gen_strides))])

        if num_dfs is None:
            num_dfs = [self.df_dim, self.df_dim * 2, self.df_dim * 4, self.df_dim * 8]

        if num_gfs is None:
            num_gfs = [self.gf_dim * 8, self.gf_dim * 4, self.gf_dim * 2, self.gf_dim]

        assert len(gen_strides) == len(num_gfs), "invalid hypers!"
        assert len(disc_strides) == len(num_dfs), "invalid hypers!"

        s_h, s_w = self.x_dim[0], self.x_dim[1]
        ks = gen_kernel_size
        self.gen_output_dims = OrderedDict()
        self.gen_weight_dims = OrderedDict()
        num_gfs = num_gfs + [self.c_dim]
        self.gen_kernel_sizes = [ks]

        for layer in range(len(gen_strides))[::-1]:
            self.gen_output_dims["g_h%i_out" % (layer + 1)] = (s_h, s_w)
            assert gen_strides[layer] <= 2, "invalid stride"
            assert ks % 2 == 1, "invalid kernel size"

            self.gen_weight_dims["g_h%i_W" % (layer + 1)] = (ks, ks, num_gfs[layer + 1], num_gfs[layer])
            self.gen_weight_dims["g_h%i_b" % (layer + 1)] = (num_gfs[layer + 1],)
            s_h, s_w = conv_out_size(s_h, gen_strides[layer]), conv_out_size(s_w, gen_strides[layer])
            ks = kernel_sizer(ks, gen_strides[layer])
            self.gen_kernel_sizes.append(ks)

        self.gen_weight_dims.update(OrderedDict([("g_h0_lin_W", (self.z_dim, num_gfs[0] * s_h * s_w)),
                                                 ("g_h0_lin_b", (num_gfs[0] * s_h * s_w,))]))
        self.gen_output_dims["g_h0_out"] = (s_h, s_w)

        self.disc_weight_dims = OrderedDict()
        s_h, s_w = self.x_dim[0], self.x_dim[1]
        num_dfs = [self.c_dim] + num_dfs
        ks = disc_kernel_size
        self.disc_kernel_sizes = [ks]

        for layer in range(len(disc_strides)):
            assert disc_strides[layer] <= 2, "invalid stride"
            assert ks % 2 == 1, "invalid kernel size"

            self.disc_weight_dims["d_h%i_W" % layer] = (ks, ks, num_dfs[layer], num_dfs[layer + 1])
            self.disc_weight_dims["d_h%i_b" % layer] = (num_dfs[layer + 1],)
            s_h, s_w = conv_out_size(s_h, disc_strides[layer]), conv_out_size(s_w, disc_strides[layer])
            ks = kernel_sizer(ks, disc_strides[layer])
            self.disc_kernel_sizes.append(ks)

        self.disc_weight_dims.update(OrderedDict([("d_h_end_lin_W", (num_dfs[-1] * s_h * s_w, num_dfs[-1])),
                                                  ("d_h_end_lin_b", (num_dfs[-1],)),
                                                  ("d_h_out_lin_W", (num_dfs[-1], self.K)),
                                                  ("d_h_out_lin_b", (self.K,))]))

        for k, v in self.gen_output_dims.items():
            print("gen_output_dims - %s: %s" % (k, v))
        print('****')
        for k, v in self.gen_weight_dims.items():
            print("gen_weight_dims - %s: %s" % (k, v))
        print('****')
        for k, v in self.disc_weight_dims.items():
            print("disc_weight_dims - %s: %s" % (k, v))

    def construct_nets(self):

        self.num_disc_layers = 5
        self.num_gen_layers = 5

        self.d_batch_norm = AttributeDict(
            [("d_bn%i" % dbn_i, batch_norm(name='d_bn%i' % dbn_i)) for dbn_i in range(self.num_disc_layers)])
        self.sup_d_batch_norm = AttributeDict(
            [("sd_bn%i" % dbn_i, batch_norm(name='sup_d_bn%i' % dbn_i)) for dbn_i in range(self.num_disc_layers)])
        self.g_batch_norm = AttributeDict(
            [("g_bn%i" % gbn_i, batch_norm(name='g_bn%i' % gbn_i)) for gbn_i in range(self.num_gen_layers)])

        s_h, s_w = self.x_dim[0], self.x_dim[1]
        s_h2, s_w2 = conv_out_size(s_h, 2), conv_out_size(s_w, 2)
        s_h4, s_w4 = conv_out_size(s_h2, 2), conv_out_size(s_w2, 2)
        s_h8, s_w8 = conv_out_size(s_h4, 2), conv_out_size(s_w4, 2)
        s_h16, s_w16 = conv_out_size(s_h8, 2), conv_out_size(s_w8, 2)

        self.gen_output_dims = OrderedDict([("g_h0_out", (s_h16, s_w16)),
                                            ("g_h1_out", (s_h8, s_w8)),
                                            ("g_h2_out", (s_h4, s_w4)),
                                            ("g_h3_out", (s_h2, s_w2)),
                                            ("g_h4_out", (s_h, s_w))])

        self.gen_weight_dims = OrderedDict([("g_h0_lin_W", (self.z_dim, self.gf_dim * 8 * s_h16 * s_w16)),
                                            ("g_h0_lin_b", (self.gf_dim * 8 * s_h16 * s_w16,)),
                                            ("g_h1_W", (5, 5, self.gf_dim * 4, self.gf_dim * 8)),
                                            ("g_h1_b", (self.gf_dim * 4,)),
                                            ("g_h2_W", (5, 5, self.gf_dim * 2, self.gf_dim * 4)),
                                            ("g_h2_b", (self.gf_dim * 2,)),
                                            ("g_h3_W", (5, 5, self.gf_dim * 1, self.gf_dim * 2)),
                                            ("g_h3_b", (self.gf_dim * 1,)),
                                            ("g_h4_W", (5, 5, self.c_dim, self.gf_dim * 1)),
                                            ("g_h4_b", (self.c_dim,))])

        self.disc_weight_dims = OrderedDict([("d_h0_W", (5, 5, self.c_dim, self.df_dim)),
                                             ("d_h0_b", (self.df_dim,)),
                                             ("d_h1_W", (5, 5, self.df_dim, self.df_dim * 2)),
                                             ("d_h1_b", (self.df_dim * 2,)),
                                             ("d_h2_W", (5, 5, self.df_dim * 2, self.df_dim * 4)),
                                             ("d_h2_b", (self.df_dim * 4,)),
                                             ("d_h3_W", (5, 5, self.df_dim * 4, self.df_dim * 8)),
                                             ("d_h3_b", (self.df_dim * 8,)),
                                             ("d_h_end_lin_W", (self.df_dim * 8 * s_h16 * s_w16, self.df_dim * 4)),
                                             ("d_h_end_lin_b", (self.df_dim * 4,)),
                                             ("d_h_out_lin_W", (self.df_dim * 4, self.K)),
                                             ("d_h_out_lin_b", (self.K,))])

    def _get_optimizer(self, lr):

        if self.optimizer == 'adam':
            return tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
        elif self.optimizer == 'sgd':
            return tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.5)
        else:
            raise ValueError("Optimizer must be either 'adam' or 'sgd'")

    def initialize_wgts(self, scope_str):

        if scope_str == "generator":
            weight_dims = self.gen_weight_dims
            numz = self.num_gen
        elif scope_str == "discriminator":
            weight_dims = self.disc_weight_dims
            numz = self.num_disc
        else:
            raise RuntimeError("invalid scope!")

        param_list = []
        with tf.variable_scope(scope_str) as scope:  # creates numz * num_mcmc independent weight sets
            for zi in range(numz):  # numz: num_gen / num_disc
                for m in range(self.num_mcmc):
                    wgts_ = AttributeDict()
                    for name, shape in weight_dims.items():
                        wgts_[name] = tf.get_variable("%s_%04d_%04d" % (name, zi, m), shape,
                                                      initializer=tf.random_normal_initializer(stddev=0.02))
                    param_list.append(wgts_)

            return param_list

    def build_bgan_graph(self):

        # unsupervised images from data distribution
        self.inputs = tf.placeholder(tf.float32, [self.batch_size] + self.x_dim, name='real_images')

        # for the discriminator: supervised (labeled) batch
        self.labeled_inputs = tf.placeholder(tf.float32, [self.batch_size] + self.x_dim, name='real_images_w_labels')
        self.labels = tf.placeholder(tf.float32, [self.batch_size, self.K], name='real_targets')

        # for generator
        self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim, self.num_gen], name='z')  # [64, 100, 10]
        self.z_sampler = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z_sampler')

        # initialize generator weights
        self.gen_param_list = self.initialize_wgts("generator")  # num_gen * num_mcmc - list
        self.disc_param_list = self.initialize_wgts("discriminator")  # num_disc * num_mcmc

        ############################ build discriminative losses and optimizers ##########################################

        self.d_semi_learning_rate = tf.placeholder(tf.float32, shape=[])

        t_vars = tf.trainable_variables()  # all trainable variables; the discriminative weights are filtered per copy below
        self.d_vars = []
        for di in range(self.num_disc):
            for m in range(self.num_mcmc):
                self.d_vars.append([var for var in t_vars if 'd_' in var.name and "_%04d_%04d" % (di, m) in var.name])

        self.d_losses, self.d_optims_semi, self.d_optims_semi_adam = [], [], []  # d_optims_semi uses the user-specified optimizer; d_optims_semi_adam always uses Adam

        for di, disc_params in enumerate(self.disc_param_list):  # with len(disc_param_list) > 1, only the first discriminator call may create the scope's variables (reuse=False); later calls must reuse them

            # Part I: real ####################
            # d_probs = softmax(d_logits), d_logits = linear(pre-layer)
            d_probs_real, d_logits_real, _ = self.discriminator(self.inputs, self.K, disc_params, reuse=tf.AUTO_REUSE)

            # JT-0228: d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_probs_real)))
            d_loss_real = - tf.reduce_mean(tf.reduce_logsumexp(d_logits_real, 1)) \
                          + tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits_real, 1)))

            # Part II: fake ####################
            d_loss_fakes = []
            for gi, gen_params in enumerate(self.gen_param_list):  # iterate num_gen * num_mcmc times

                d_probs_fake, d_logits_fake, _ = self.discriminator(
                    self.generator(self.z[:, :, gi % self.num_gen], gen_params), self.K, disc_params, reuse=True)

                # JT-0228: d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_probs_fake)))
                d_loss_fake = tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits_fake, 1)))
                d_loss_fakes.append(d_loss_fake)

            # Part III: sup ####################
            d_sup_probs, d_sup_logits, _ = self.discriminator(self.labeled_inputs, self.K, disc_params, reuse=tf.AUTO_REUSE)

            d_loss_sup = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(logits=d_sup_logits, labels=self.labels))
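            # _v2 differs from softmax_cross_entropy_with_logits only in also
            # backpropagating into the labels; with placeholder labels the two are
            # equivalent here.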

            ################### total loss for semi-supervised discriminator ######################

            d_losses_semi = []
            for d_loss_fake_ in d_loss_fakes:
                d_loss_semi_ = d_loss_sup + d_loss_real * float(self.num_gen) + d_loss_fake_

                if not self.ml:
                    # Bayes term: log p(theta_d | alpha_d)

                    d_loss_semi_ += self.disc_prior(disc_params) + self.disc_noise(disc_params)  # 12

                d_losses_semi.append(tf.reshape(d_loss_semi_, [1]))

            d_loss_semi = tf.reduce_logsumexp(tf.concat(d_losses_semi, 0))
            self.d_losses.append(d_loss_semi)

            ################### total optimizer for semi-supervised discriminator ######################

            # used after the first 5000 iterations: the user-specified optimizer (self.optimizer)
            d_opt_semi = self._get_optimizer(self.d_semi_learning_rate)
            self.d_optims_semi.append(d_opt_semi.minimize(d_loss_semi, var_list=self.d_vars[di]))

            # used for the default (early) iterations: plain Adam
            d_opt_semi_adam = tf.train.AdamOptimizer(learning_rate=self.d_semi_learning_rate, beta1=0.5)
            self.d_optims_semi_adam.append(d_opt_semi_adam.minimize(d_loss_semi, var_list=self.d_vars[di]))

        ############################ build generator losses and optimizers ##########################################

        self.g_learning_rate = tf.placeholder(tf.float32, shape=[])
        self.g_vars = []
        for gi in range(self.num_gen):
            for m in range(self.num_mcmc):
                self.g_vars.append([var for var in t_vars if 'g_' in var.name and "_%04d_%04d" % (gi, m) in var.name])

        self.g_losses, self.g_optims_semi, self.g_optims_semi_adam = [], [], []

        for gi, gen_params in enumerate(self.gen_param_list):

            gi_losses = []
            for disc_params in self.disc_param_list:

                d_probs_fake, d_logits_fake, d_features_fake = self.discriminator(self.generator(self.z[:, :, gi % self.num_gen], gen_params), self.K, disc_params, reuse=tf.AUTO_REUSE)
                _, _, d_features_real = self.discriminator(self.inputs, self.K, disc_params, reuse=tf.AUTO_REUSE)

                # JT-0228: g_loss_ = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_probs_fake)))
                g_loss_ = -tf.reduce_mean(tf.reduce_logsumexp(d_logits_fake, 1)) + tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits_fake, 1)))
                g_loss_ += tf.reduce_mean(huber_loss(d_features_real[-1], d_features_fake[-1]))  # feature matching; Huber loss acts like squared error near zero but is more robust to outliers

                if not self.ml:

                    # return the prior_loss + noise_loss
                    g_loss_ += self.gen_prior(gen_params) + self.gen_noise(gen_params)  # 10

                gi_losses.append(tf.reshape(g_loss_, [1]))

            g_loss = tf.reduce_logsumexp(tf.concat(gi_losses, 0))
            self.g_losses.append(g_loss)

            ################### total optimizer for semi-supervised generator ######################

            g_opt = self._get_optimizer(self.g_learning_rate)
            self.g_optims_semi.append(g_opt.minimize(g_loss, var_list=self.g_vars[gi]))

            g_opt_adam = tf.train.AdamOptimizer(learning_rate=self.g_learning_rate, beta1=0.5)
            self.g_optims_semi_adam.append(g_opt_adam.minimize(g_loss, var_list=self.g_vars[gi]))

        self.gen_samplers = []  ### build samplers
        for gi, gen_params in enumerate(self.gen_param_list):
            self.gen_samplers.append(self.generator(self.z_sampler, gen_params))

        ### build vanilla supervised loss
        self.lbls = tf.placeholder(tf.float32, [self.batch_size, self.K], name='real_sup_targets')  # placeholder for supervised targets; real label values are fed at run time
        self.S, self.S_logits = self.sup_discriminator(self.inputs, self.K)
        self.s_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.S_logits, labels=self.lbls))

        ################### optimizer for the purely supervised baseline classifier ######################

        t_vars = tf.trainable_variables()
        self.sup_vars = [var for var in t_vars if 'sup_' in var.name]
        supervised_lr = 0.05 * self.lr

        s_opt = self._get_optimizer(supervised_lr)
        self.s_optim = s_opt.minimize(self.s_loss, var_list=self.sup_vars)

        s_opt_adam = tf.train.AdamOptimizer(learning_rate=supervised_lr, beta1=0.5)  # plain Adam; not the SGHMC sampler referred to in the original work
        self.s_optim_adam = s_opt_adam.minimize(self.s_loss, var_list=self.sup_vars)

    def build_test_graph(self):

        self.test_inputs = tf.placeholder(tf.float32, [self.batch_size] + self.x_dim, name='real_test_images')

        self.test_d_probs, self.test_d_logits = [], []  # self.test_d_probs : 2 x (64, 10)
        for disc_params in self.disc_param_list:  # no generator, just discriminator

            test_d_probs_, test_d_logits_, _ = self.discriminator(self.test_inputs, self.K, disc_params, train=False, reuse=True)

            self.test_d_probs.append(test_d_probs_)  # test_d_probs_.shape = (64, 10)
            self.test_d_logits.append(test_d_logits_)

        # build standard purely supervised losses and optimizers
        self.test_s_probs, self.test_s_logits = self.sup_discriminator(self.test_inputs, self.K)

    def sup_discriminator(self, image, K):

        # TODO collapse this into disc
        with tf.variable_scope("sup_discriminator", reuse=tf.AUTO_REUSE) as scope:
            h0 = lrelu(conv2d(image, self.df_dim, name='sup_h0_conv'))
            h1 = lrelu(self.sup_d_batch_norm.sd_bn1(conv2d(h0, self.df_dim * 2, name='sup_h1_conv')))
            h2 = lrelu(self.sup_d_batch_norm.sd_bn2(conv2d(h1, self.df_dim * 4, name='sup_h2_conv')))
            h3 = lrelu(self.sup_d_batch_norm.sd_bn3(conv2d(h2, self.df_dim * 8, name='sup_h3_conv')))
            h4 = linear(tf.reshape(h3, [self.batch_size, -1]), K, 'sup_h3_lin')
            return tf.nn.softmax(h4), h4

    def discriminator(self, image, K, disc_params, train=True, reuse=False):

        with tf.variable_scope("discriminator", reuse=reuse) as scope:  # reuse=tf.AUTO_REUSE

            h = image
            for layer in range(len(self.disc_strides)):
                if layer == 0:
                    h = lrelu(conv2d(h, self.disc_weight_dims["d_h%i_W" % layer][-1], name='d_h%i_conv' % layer,
                                     k_h=self.disc_kernel_sizes[layer], k_w=self.disc_kernel_sizes[layer],
                                     d_h=self.disc_strides[layer], d_w=self.disc_strides[layer],
                                     w=disc_params["d_h%i_W" % layer], biases=disc_params["d_h%i_b" % layer]))

                # conv - bn - relu
                else:
                    h = lrelu(self.d_batch_norm["d_bn%i" % layer](
                        conv2d(h, self.disc_weight_dims["d_h%i_W" % layer][-1], name='d_h%i_conv' % layer,
                               k_h=self.disc_kernel_sizes[layer], k_w=self.disc_kernel_sizes[layer],
                               d_h=self.disc_strides[layer], d_w=self.disc_strides[layer],
                               w=disc_params["d_h%i_W" % layer], biases=disc_params["d_h%i_b" % layer]), train=train))

            h_end = lrelu(linear(tf.reshape(h, [self.batch_size, -1]), self.df_dim * 4, "d_h_end_lin",
                                 matrix=disc_params.d_h_end_lin_W, bias=disc_params.d_h_end_lin_b))  # penultimate features (used for feature matching)
            h_out = linear(h_end, K, 'd_h_out_lin',
                           matrix=disc_params.d_h_out_lin_W, bias=disc_params.d_h_out_lin_b)

            return tf.nn.softmax(h_out), h_out, [h_end]

    def generator(self, z, gen_params):

        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE) as scope:

            h = linear(z, self.gen_weight_dims["g_h0_lin_W"][-1], 'g_h0_lin',
                       matrix=gen_params.g_h0_lin_W, bias=gen_params.g_h0_lin_b)

            h = tf.nn.relu(self.g_batch_norm.g_bn0(h))

            h = tf.reshape(h, [self.batch_size, self.gen_output_dims["g_h0_out"][0],
                               self.gen_output_dims["g_h0_out"][1], -1])

            for layer in range(1, len(self.gen_strides) + 1):

                out_shape = [self.batch_size, self.gen_output_dims["g_h%i_out" % layer][0],
                             self.gen_output_dims["g_h%i_out" % layer][1], self.gen_weight_dims["g_h%i_W" % layer][-2]]

                h = deconv2d(h,
                             out_shape,
                             k_h=self.gen_kernel_sizes[layer - 1], k_w=self.gen_kernel_sizes[layer - 1],
                             d_h=self.gen_strides[layer - 1], d_w=self.gen_strides[layer - 1],
                             name='g_h%i' % layer,
                             w=gen_params["g_h%i_W" % layer], biases=gen_params["g_h%i_b" % layer])
                if layer < len(self.gen_strides):
                    h = tf.nn.relu(self.g_batch_norm["g_bn%i" % layer](h))

            return tf.nn.tanh(h)

    def gen_prior(self, gen_params):

        with tf.variable_scope("generator") as scope:
            prior_loss = 0.0
            for var in gen_params.values():
                nn = tf.divide(var, self.prior_std)
                prior_loss += tf.reduce_mean(tf.multiply(nn, nn))

        prior_loss /= self.dataset_size
        return prior_loss

    def gen_noise(self, gen_params):  # noise_ : gaussian distribution
        with tf.variable_scope("generator") as scope:
            noise_loss = 0.0
            for name, var in gen_params.items():  # .iteritems():

                noise_ = tf.distributions.Normal(loc=0., scale=self.noise_std * tf.ones(var.get_shape()))  # tf.contrib.distributions.Normal(mu=0., sigma=self.noise_std*tf.ones(var.get_shape()))
                noise_loss += tf.reduce_sum(var * noise_.sample())

        noise_loss /= self.dataset_size
        return noise_loss

    def disc_prior(self, disc_params):

        with tf.variable_scope("discriminator") as scope:
            prior_loss = 0.0
            for var in disc_params.values():

                # print("var_disc_prior shape = ", var.get_shape(), var)
                # (5, 5, 3, 96) <tf.Variable 'discriminator/d_h0_W_0000_0000:0' shape=(5, 5, 3, 96) dtype=float32_ref>

                nn = tf.divide(var, self.prior_std)
                prior_loss += tf.reduce_mean(tf.multiply(nn, nn))

        prior_loss /= self.dataset_size
        return prior_loss

    def disc_noise(self, disc_params):

        with tf.variable_scope("discriminator") as scope:
            noise_loss = 0.0
            for var in disc_params.values():
                noise_ = tf.distributions.Normal(loc=0., scale=self.noise_std * tf.ones(var.get_shape()))  # tf.contrib.distributions.Normal(mu=0., sigma=self.noise_std*tf.ones(var.get_shape()))
                noise_loss += tf.reduce_sum(var * noise_.sample())

        noise_loss /= self.dataset_size
        return noise_loss