import math
from collections import OrderedDict

import numpy as np
import tensorflow as tf

# NOTE (assumed): AttributeDict, batch_norm, conv2d, deconv2d, linear, lrelu,
# conv_out_size, kernel_sizer and huber_loss are project-local helper ops;
# import them from this project's ops/util module.


class BDCGAN_Semi(object):

    def __init__(self, x_dim, z_dim, dataset_size, batch_size=64, gf_dim=64, df_dim=64,
                 prior_std=1.0, J=1, M=1, num_classes=1, eta=2e-4, num_layers=4,
                 alpha=0.01, lr=0.0002, optimizer='adam', wasserstein=False, ml=False,
                 J_d=None):

        assert len(x_dim) == 3, "invalid image dims"
        c_dim = x_dim[2]
        self.is_grayscale = (c_dim == 1)
        self.optimizer = optimizer.lower()
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.K = num_classes

        self.x_dim = x_dim
        self.z_dim = z_dim
        self.gf_dim = gf_dim
        self.df_dim = df_dim
        self.c_dim = c_dim
        self.lr = lr

        # Bayes
        self.prior_std = prior_std
        self.num_gen = J
        self.num_disc = J_d if J_d is not None else 1
        self.num_mcmc = M
        self.eta = eta
        self.alpha = alpha
        # ML
        self.ml = ml
        if self.ml:
            assert self.num_gen == 1 and self.num_disc == 1 and self.num_mcmc == 1, \
                "invalid settings for ML training"

        self.noise_std = np.sqrt(2 * self.alpha * self.eta)

        def get_strides(num_layers, num_pool):
            interval = int(math.floor(num_layers / float(num_pool)))
            strides = np.array([1] * num_layers)
            strides[0:interval * num_pool:interval] = 2
            return strides

        self.num_pool = 4
        self.max_num_dfs = 512
        self.gen_strides = get_strides(num_layers, self.num_pool)
        self.disc_strides = self.gen_strides
        num_dfs = np.cumprod(np.array([self.df_dim] + list(self.disc_strides)))[:-1]
        num_dfs[num_dfs >= self.max_num_dfs] = self.max_num_dfs  # memory
        self.num_dfs = list(num_dfs)
        self.num_gfs = self.num_dfs[::-1]

        self.construct_from_hypers(gen_strides=self.gen_strides,
                                   disc_strides=self.disc_strides,
                                   num_gfs=self.num_gfs, num_dfs=self.num_dfs)

        self.build_bgan_graph()
        self.build_test_graph()

    def construct_from_hypers(self, gen_kernel_size=5, gen_strides=[2, 2, 2, 2],
                              disc_kernel_size=5, disc_strides=[2, 2, 2, 2],
                              num_dfs=None, num_gfs=None):

        self.d_batch_norm = AttributeDict([("d_bn%i" % dbn_i, batch_norm(name='d_bn%i' % dbn_i))
                                           for dbn_i in range(len(disc_strides))])
        self.sup_d_batch_norm = AttributeDict([("sd_bn%i" % dbn_i, batch_norm(name='sup_d_bn%i' % dbn_i))
                                               for dbn_i in range(5)])
        self.g_batch_norm = AttributeDict([("g_bn%i" % gbn_i, batch_norm(name='g_bn%i' % gbn_i))
                                           for gbn_i in range(len(gen_strides))])

        if num_dfs is None:
            num_dfs = [self.df_dim, self.df_dim * 2, self.df_dim * 4, self.df_dim * 8]
        if num_gfs is None:
            num_gfs = [self.gf_dim * 8, self.gf_dim * 4, self.gf_dim * 2, self.gf_dim]

        assert len(gen_strides) == len(num_gfs), "invalid hypers!"
        assert len(disc_strides) == len(num_dfs), "invalid hypers!"

        s_h, s_w = self.x_dim[0], self.x_dim[1]
        ks = gen_kernel_size
        self.gen_output_dims = OrderedDict()
        self.gen_weight_dims = OrderedDict()
        num_gfs = num_gfs + [self.c_dim]
        self.gen_kernel_sizes = [ks]
        for layer in range(len(gen_strides))[::-1]:
            self.gen_output_dims["g_h%i_out" % (layer + 1)] = (s_h, s_w)
            assert gen_strides[layer] <= 2, "invalid stride"
            assert ks % 2 == 1, "invalid kernel size"
            self.gen_weight_dims["g_h%i_W" % (layer + 1)] = (ks, ks, num_gfs[layer + 1], num_gfs[layer])
            self.gen_weight_dims["g_h%i_b" % (layer + 1)] = (num_gfs[layer + 1],)
            s_h, s_w = conv_out_size(s_h, gen_strides[layer]), conv_out_size(s_w, gen_strides[layer])
            ks = kernel_sizer(ks, gen_strides[layer])
            self.gen_kernel_sizes.append(ks)

        self.gen_weight_dims.update(OrderedDict([("g_h0_lin_W", (self.z_dim, num_gfs[0] * s_h * s_w)),
                                                 ("g_h0_lin_b", (num_gfs[0] * s_h * s_w,))]))
        self.gen_output_dims["g_h0_out"] = (s_h, s_w)

        self.disc_weight_dims = OrderedDict()
        s_h, s_w = self.x_dim[0], self.x_dim[1]
        num_dfs = [self.c_dim] + num_dfs
        ks = disc_kernel_size
        self.disc_kernel_sizes = [ks]
        for layer in range(len(disc_strides)):
            assert disc_strides[layer] <= 2, "invalid stride"
            assert ks % 2 == 1, "invalid kernel size"
            self.disc_weight_dims["d_h%i_W" % layer] = (ks, ks, num_dfs[layer], num_dfs[layer + 1])
            self.disc_weight_dims["d_h%i_b" % layer] = (num_dfs[layer + 1],)
            s_h, s_w = conv_out_size(s_h, disc_strides[layer]), conv_out_size(s_w, disc_strides[layer])
            ks = kernel_sizer(ks, disc_strides[layer])
            self.disc_kernel_sizes.append(ks)

        self.disc_weight_dims.update(OrderedDict([("d_h_end_lin_W", (num_dfs[-1] * s_h * s_w, num_dfs[-1])),
                                                  ("d_h_end_lin_b", (num_dfs[-1],)),
                                                  ("d_h_out_lin_W", (num_dfs[-1], self.K)),
                                                  ("d_h_out_lin_b", (self.K,))]))

        for k, v in self.gen_output_dims.items():
            print("%s: %s" % (k, v))
        print('****')
        for k, v in self.gen_weight_dims.items():
            print("%s: %s" % (k, v))
        print('****')
        for k, v in self.disc_weight_dims.items():
            print("%s: %s" % (k, v))

    def construct_nets(self):
        self.num_disc_layers = 5
        self.num_gen_layers = 5
        self.d_batch_norm = AttributeDict([("d_bn%i" % dbn_i, batch_norm(name='d_bn%i' % dbn_i))
                                           for dbn_i in range(self.num_disc_layers)])
        self.sup_d_batch_norm = AttributeDict([("sd_bn%i" % dbn_i, batch_norm(name='sup_d_bn%i' % dbn_i))
                                               for dbn_i in range(self.num_disc_layers)])
        self.g_batch_norm = AttributeDict([("g_bn%i" % gbn_i, batch_norm(name='g_bn%i' % gbn_i))
                                           for gbn_i in range(self.num_gen_layers)])

        s_h, s_w = self.x_dim[0], self.x_dim[1]
        s_h2, s_w2 = conv_out_size(s_h, 2), conv_out_size(s_w, 2)
        s_h4, s_w4 = conv_out_size(s_h2, 2), conv_out_size(s_w2, 2)
        s_h8, s_w8 = conv_out_size(s_h4, 2), conv_out_size(s_w4, 2)
        s_h16, s_w16 = conv_out_size(s_h8, 2), conv_out_size(s_w8, 2)

        self.gen_output_dims = OrderedDict([("g_h0_out", (s_h16, s_w16)),
                                            ("g_h1_out", (s_h8, s_w8)),
                                            ("g_h2_out", (s_h4, s_w4)),
                                            ("g_h3_out", (s_h2, s_w2)),
                                            ("g_h4_out", (s_h, s_w))])

        self.gen_weight_dims = OrderedDict([
            ("g_h0_lin_W", (self.z_dim, self.gf_dim * 8 * s_h16 * s_w16)),
            ("g_h0_lin_b", (self.gf_dim * 8 * s_h16 * s_w16,)),
            ("g_h1_W", (5, 5, self.gf_dim * 4, self.gf_dim * 8)), ("g_h1_b", (self.gf_dim * 4,)),
            ("g_h2_W", (5, 5, self.gf_dim * 2, self.gf_dim * 4)), ("g_h2_b", (self.gf_dim * 2,)),
            ("g_h3_W", (5, 5, self.gf_dim * 1, self.gf_dim * 2)), ("g_h3_b", (self.gf_dim * 1,)),
            ("g_h4_W", (5, 5, self.c_dim, self.gf_dim * 1)), ("g_h4_b", (self.c_dim,))
        ])

        self.disc_weight_dims = OrderedDict([
            ("d_h0_W", (5, 5, self.c_dim, self.df_dim)), ("d_h0_b", (self.df_dim,)),
            ("d_h1_W", (5, 5, self.df_dim, self.df_dim * 2)), ("d_h1_b", (self.df_dim * 2,)),
            ("d_h2_W", (5, 5, self.df_dim * 2, self.df_dim * 4)), ("d_h2_b", (self.df_dim * 4,)),
            ("d_h3_W", (5, 5, self.df_dim * 4, self.df_dim * 8)), ("d_h3_b", (self.df_dim * 8,)),
            ("d_h_end_lin_W", (self.df_dim * 8 * s_h16 * s_w16, self.df_dim * 4)),
            ("d_h_end_lin_b", (self.df_dim * 4,)),
            ("d_h_out_lin_W", (self.df_dim * 4, self.K)), ("d_h_out_lin_b", (self.K,))
        ])

    def _get_optimizer(self, lr):
        if self.optimizer == 'adam':
            return tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
        elif self.optimizer == 'sgd':
            return tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.5)
        else:
            raise ValueError("Optimizer must be either 'adam' or 'sgd'")

    def initialize_wgts(self, scope_str):
        if scope_str == "generator":
            weight_dims = self.gen_weight_dims
            numz = self.num_gen
        elif scope_str == "discriminator":
            weight_dims = self.disc_weight_dims
            numz = self.num_disc
        else:
            raise RuntimeError("invalid scope!")

        param_list = []
        with tf.variable_scope(scope_str) as scope:
            for zi in range(numz):
                for m in range(self.num_mcmc):
                    wgts_ = AttributeDict()
                    for name, shape in weight_dims.items():
                        wgts_[name] = tf.get_variable("%s_%04d_%04d" % (name, zi, m), shape,
                                                      initializer=tf.random_normal_initializer(stddev=0.02))
                    param_list.append(wgts_)
        return param_list

    def build_bgan_graph(self):
        self.inputs = tf.placeholder(tf.float32, [self.batch_size] + self.x_dim, name='real_images')
        self.labeled_inputs = tf.placeholder(tf.float32, [self.batch_size] + self.x_dim,
                                             name='real_images_w_labels')
        self.labels = tf.placeholder(tf.float32, [self.batch_size, self.K], name='real_targets')
        self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim, self.num_gen], name='z')
        self.z_sampler = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z_sampler')

        # initialize generator and discriminator weights
        self.gen_param_list = self.initialize_wgts("generator")
        self.disc_param_list = self.initialize_wgts("discriminator")

        ### build discriminative losses and optimizers
        # prep optimizer args
        self.d_semi_learning_rate = tf.placeholder(tf.float32, shape=[])

        # compile all discriminative weights
        t_vars = tf.trainable_variables()
        self.d_vars = []
        for di in range(self.num_disc):
            for m in range(self.num_mcmc):
                self.d_vars.append([var for var in t_vars
                                    if 'd_' in var.name and "_%04d_%04d" % (di, m) in var.name])

        ### build disc losses and optimizers
        self.d_losses, self.d_optims_semi, self.d_optims_semi_adam = [], [], []
        for di, disc_params in enumerate(self.disc_param_list):
            d_probs, d_logits, _ = self.discriminator(self.inputs, self.K, disc_params)
            d_loss_real = -tf.reduce_mean(tf.reduce_logsumexp(d_logits, 1)) + \
                tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits, 1)))

            d_loss_fakes = []
            for gi, gen_params in enumerate(self.gen_param_list):
                d_probs_, d_logits_, _ = self.discriminator(
                    self.generator(self.z[:, :, gi % self.num_gen], gen_params), self.K, disc_params)
                d_loss_fake_ = tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits_, 1)))
                d_loss_fakes.append(d_loss_fake_)

            d_sup_probs, d_sup_logits, _ = self.discriminator(self.labeled_inputs, self.K, disc_params)
            d_loss_sup = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=d_sup_logits, labels=self.labels))

            d_losses_semi = []
            for d_loss_fake_ in d_loss_fakes:
                d_loss_semi_ = d_loss_sup + d_loss_real * float(self.num_gen) + d_loss_fake_
                if not self.ml:
                    d_loss_semi_ += self.disc_prior(disc_params) + self.disc_noise(disc_params)
                d_losses_semi.append(tf.reshape(d_loss_semi_, [1]))

            d_loss_semi = tf.reduce_logsumexp(tf.concat(d_losses_semi, 0))
            self.d_losses.append(d_loss_semi)

            d_opt_semi = self._get_optimizer(self.d_semi_learning_rate)
            self.d_optims_semi.append(d_opt_semi.minimize(d_loss_semi, var_list=self.d_vars[di]))
            d_opt_semi_adam = tf.train.AdamOptimizer(learning_rate=self.d_semi_learning_rate, beta1=0.5)
            self.d_optims_semi_adam.append(d_opt_semi_adam.minimize(d_loss_semi, var_list=self.d_vars[di]))

        ### build generative losses and optimizers
        self.g_learning_rate = tf.placeholder(tf.float32, shape=[])
        self.g_vars = []
        for gi in range(self.num_gen):
            for m in range(self.num_mcmc):
                self.g_vars.append([var for var in t_vars
                                    if 'g_' in var.name and "_%04d_%04d" % (gi, m) in var.name])

        self.g_losses, self.g_optims_semi, self.g_optims_semi_adam = [], [], []
        for gi, gen_params in enumerate(self.gen_param_list):
            gi_losses = []
            for disc_params in self.disc_param_list:
                d_probs_, d_logits_, d_features_fake = self.discriminator(
                    self.generator(self.z[:, :, gi % self.num_gen], gen_params), self.K, disc_params)
                _, _, d_features_real = self.discriminator(self.inputs, self.K, disc_params)
                g_loss_ = -tf.reduce_mean(tf.reduce_logsumexp(d_logits_, 1)) + \
                    tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits_, 1)))  # not needed?!
                g_loss_ += tf.reduce_mean(huber_loss(d_features_real[-1], d_features_fake[-1]))
                if not self.ml:
                    g_loss_ += self.gen_prior(gen_params) + self.gen_noise(gen_params)
                gi_losses.append(tf.reshape(g_loss_, [1]))

            g_loss = tf.reduce_logsumexp(tf.concat(gi_losses, 0))
            self.g_losses.append(g_loss)
            g_opt = self._get_optimizer(self.g_learning_rate)
            self.g_optims_semi.append(g_opt.minimize(g_loss, var_list=self.g_vars[gi]))
            g_opt_adam = tf.train.AdamOptimizer(learning_rate=self.g_learning_rate, beta1=0.5)
            self.g_optims_semi_adam.append(g_opt_adam.minimize(g_loss, var_list=self.g_vars[gi]))

        ### build samplers
        self.gen_samplers = []
        for gi, gen_params in enumerate(self.gen_param_list):
            self.gen_samplers.append(self.generator(self.z_sampler, gen_params))

        ### build vanilla supervised loss
        self.lbls = tf.placeholder(tf.float32, [self.batch_size, self.K], name='real_sup_targets')
        self.S, self.S_logits = self.sup_discriminator(self.inputs, self.K)
        self.s_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.S_logits, labels=self.lbls))
        t_vars = tf.trainable_variables()
        self.sup_vars = [var for var in t_vars if 'sup_' in var.name]
        supervised_lr = 0.05 * self.lr
        s_opt = self._get_optimizer(supervised_lr)
        self.s_optim = s_opt.minimize(self.s_loss, var_list=self.sup_vars)
        s_opt_adam = tf.train.AdamOptimizer(learning_rate=supervised_lr, beta1=0.5)
        self.s_optim_adam = s_opt_adam.minimize(self.s_loss, var_list=self.sup_vars)

    def build_test_graph(self):
        self.test_inputs = tf.placeholder(tf.float32, [self.batch_size] + self.x_dim,
                                          name='real_test_images')
        self.test_d_probs, self.test_d_logits = [], []
        for disc_params in self.disc_param_list:
            test_d_probs_, test_d_logits_, _ = self.discriminator(self.test_inputs, self.K,
                                                                  disc_params, train=False)
            self.test_d_probs.append(test_d_probs_)
            self.test_d_logits.append(test_d_logits_)

        # build standard purely supervised losses and optimizers
        self.test_s_probs, self.test_s_logits = self.sup_discriminator(self.test_inputs, self.K, reuse=True)

    def sup_discriminator(self, image, K, reuse=False):
        # TODO collapse this into disc
        with tf.variable_scope("sup_discriminator") as scope:
            if reuse:
                scope.reuse_variables()
            h0 = lrelu(conv2d(image, self.df_dim, name='sup_h0_conv'))
            h1 = lrelu(self.sup_d_batch_norm.sd_bn1(conv2d(h0, self.df_dim * 2, name='sup_h1_conv')))
            h2 = lrelu(self.sup_d_batch_norm.sd_bn2(conv2d(h1, self.df_dim * 4, name='sup_h2_conv')))
            h3 = lrelu(self.sup_d_batch_norm.sd_bn3(conv2d(h2, self.df_dim * 8, name='sup_h3_conv')))
            h4 = linear(tf.reshape(h3, [self.batch_size, -1]), K, 'sup_h3_lin')
            return tf.nn.softmax(h4), h4

    def discriminator(self, image, K, disc_params, train=True):
        with tf.variable_scope("discriminator") as scope:
            h = image
            for layer in range(len(self.disc_strides)):
                if layer == 0:
                    h = lrelu(conv2d(h, self.disc_weight_dims["d_h%i_W" % layer][-1],
                                     name='d_h%i_conv' % layer,
                                     k_h=self.disc_kernel_sizes[layer], k_w=self.disc_kernel_sizes[layer],
                                     d_h=self.disc_strides[layer], d_w=self.disc_strides[layer],
                                     w=disc_params["d_h%i_W" % layer], biases=disc_params["d_h%i_b" % layer]))
                else:
                    h = lrelu(self.d_batch_norm["d_bn%i" % layer](
                        conv2d(h, self.disc_weight_dims["d_h%i_W" % layer][-1],
                               name='d_h%i_conv' % layer,
                               k_h=self.disc_kernel_sizes[layer], k_w=self.disc_kernel_sizes[layer],
                               d_h=self.disc_strides[layer], d_w=self.disc_strides[layer],
                               w=disc_params["d_h%i_W" % layer], biases=disc_params["d_h%i_b" % layer]),
                        train=train))

            h_end = lrelu(linear(tf.reshape(h, [self.batch_size, -1]), self.df_dim * 4, "d_h_end_lin",
                                 matrix=disc_params.d_h_end_lin_W, bias=disc_params.d_h_end_lin_b))  # for feature norm
            h_out = linear(h_end, K, 'd_h_out_lin',
                           matrix=disc_params.d_h_out_lin_W, bias=disc_params.d_h_out_lin_b)
            return tf.nn.softmax(h_out), h_out, [h_end]

    def generator(self, z, gen_params):
        with tf.variable_scope("generator") as scope:
            h = linear(z, self.gen_weight_dims["g_h0_lin_W"][-1], 'g_h0_lin',
                       matrix=gen_params.g_h0_lin_W, bias=gen_params.g_h0_lin_b)
            h = tf.nn.relu(self.g_batch_norm.g_bn0(h))
            h = tf.reshape(h, [self.batch_size, self.gen_output_dims["g_h0_out"][0],
                               self.gen_output_dims["g_h0_out"][1], -1])

            for layer in range(1, len(self.gen_strides) + 1):
                out_shape = [self.batch_size,
                             self.gen_output_dims["g_h%i_out" % layer][0],
                             self.gen_output_dims["g_h%i_out" % layer][1],
                             self.gen_weight_dims["g_h%i_W" % layer][-2]]
                h = deconv2d(h, out_shape,
                             k_h=self.gen_kernel_sizes[layer - 1], k_w=self.gen_kernel_sizes[layer - 1],
                             d_h=self.gen_strides[layer - 1], d_w=self.gen_strides[layer - 1],
                             name='g_h%i' % layer,
                             w=gen_params["g_h%i_W" % layer], biases=gen_params["g_h%i_b" % layer])
                if layer < len(self.gen_strides):
                    h = tf.nn.relu(self.g_batch_norm["g_bn%i" % layer](h))

            return tf.nn.tanh(h)

    def gen_prior(self, gen_params):
        with tf.variable_scope("generator") as scope:
            prior_loss = 0.0
            for var in gen_params.values():
                nn = tf.divide(var, self.prior_std)
                prior_loss += tf.reduce_mean(tf.multiply(nn, nn))
        prior_loss /= self.dataset_size
        return prior_loss

    def gen_noise(self, gen_params):
        with tf.variable_scope("generator") as scope:
            noise_loss = 0.0
            for name, var in gen_params.items():
                # tf.contrib.distributions.Normal(mu=..., sigma=...) is deprecated; use tf.distributions.Normal
                noise_ = tf.distributions.Normal(loc=0., scale=self.noise_std * tf.ones(var.get_shape()))
                noise_loss += tf.reduce_sum(var * noise_.sample())
        noise_loss /= self.dataset_size
        return noise_loss

    def disc_prior(self, disc_params):
        with tf.variable_scope("discriminator") as scope:
            prior_loss = 0.0
            for var in disc_params.values():
                nn = tf.divide(var, self.prior_std)
                prior_loss += tf.reduce_mean(tf.multiply(nn, nn))
        prior_loss /= self.dataset_size
        return prior_loss

    def disc_noise(self, disc_params):
        with tf.variable_scope("discriminator") as scope:
            noise_loss = 0.0
            for var in disc_params.values():
                noise_ = tf.distributions.Normal(loc=0., scale=self.noise_std * tf.ones(var.get_shape()))
                noise_loss += tf.reduce_sum(var * noise_.sample())
        noise_loss /= self.dataset_size
        return noise_loss
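
# Example (sketch): one semi-supervised training step with BDCGAN_Semi.
# This is not part of the original code. It only uses the placeholders and
# fetches defined in build_bgan_graph above; the numpy batch arguments, the
# uniform prior on z, and the function itself are hypothetical and shown for
# illustration of how the graph is intended to be fed.
def _example_train_step_sketch(sess, dcgan, unlabeled_batch, labeled_batch, label_batch, lr=2e-4):
    """Sketch of a single Adam step for all discriminator and generator copies."""
    batch_z = np.random.uniform(-1, 1, [dcgan.batch_size, dcgan.z_dim, dcgan.num_gen])
    feed = {dcgan.inputs: unlabeled_batch,        # unlabeled real images
            dcgan.labeled_inputs: labeled_batch,  # labeled real images
            dcgan.labels: label_batch,            # one-hot targets, shape (batch_size, K)
            dcgan.z: batch_z,
            dcgan.d_semi_learning_rate: lr,
            dcgan.g_learning_rate: lr}
    # each list holds one loss / one optimizer op per weight sample
    d_losses, _ = sess.run([dcgan.d_losses, dcgan.d_optims_semi_adam], feed_dict=feed)
    g_losses, _ = sess.run([dcgan.g_losses, dcgan.g_optims_semi_adam], feed_dict=feed)
    return d_losses, g_losses
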
class BDCGAN_Semi_3d(object):

    def __init__(self, x_dim, z_dim, dataset_size, batch_size=64, gf_dim=64, df_dim=64,
                 prior_std=1.0, J=1, M=1, num_classes=1, eta=1, num_layers=4,
                 alpha=0.01, lr=0.0002, optimizer='adam', wasserstein=False, ml=False,
                 J_d=None):  # eta=2e-4,

        print("ml = ", ml)
        self.optimizer = optimizer.lower()
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.K = num_classes

        self.x_dim = x_dim
        self.z_dim = z_dim  # generated sample's dim
        self.gf_dim = gf_dim  # base number of generator filters
        self.df_dim = df_dim  # base number of discriminator filters (default 64)
        self.c_dim = x_dim[3]  # x_dim = [x, y, z, c]
        self.is_grayscale = (self.c_dim == 1)
        self.lr = lr

        # Bayes
        self.prior_std = prior_std
        self.num_gen = J  # number of generator weight samples (J)
        self.num_disc = J_d if J_d is not None else 1
        self.num_mcmc = M
        self.eta = eta  # not required in variational inference and MC dropout
        self.alpha = alpha  # not required in variational inference and MC dropout
        # ML
        self.ml = ml
        if self.ml:
            assert self.num_gen == 1 and self.num_disc == 1 and self.num_mcmc == 1, \
                "invalid settings for ML training"

        self.noise_std = 10  # original: np.sqrt(2 * self.alpha * self.eta)

        def get_strides(num_layers, num_pool):
            interval = int(math.floor(num_layers / float(num_pool)))
            strides = np.array([1] * num_layers)
            strides[0:interval * num_pool:interval] = 2
            return strides

        self.num_pool = 4
        self.max_num_dfs = 1024  # default: 512
        self.gen_strides = get_strides(num_layers, self.num_pool)
        self.disc_strides = self.gen_strides
        num_dfs = np.cumprod(np.array([self.df_dim] + list(self.disc_strides)))[:-1]
        num_dfs[num_dfs >= self.max_num_dfs] = self.max_num_dfs  # memory
        self.num_dfs = list(num_dfs)
        self.num_gfs = self.num_dfs[::-1]

        self.construct_from_hypers(gen_strides=self.gen_strides, disc_strides=self.disc_strides,
                                   num_gfs=self.num_gfs, num_dfs=self.num_dfs)

        self.build_bgan_graph()
        self.build_test_graph()

    def construct_from_hypers(self, gen_kernel_size=5, gen_strides=[2, 2, 2, 2],
                              disc_kernel_size=5, disc_strides=[2, 2, 2, 2],
                              num_dfs=None, num_gfs=None):

        self.d_batch_norm = AttributeDict([("d_bn%i" % dbn_i, batch_norm(name='d_bn%i' % dbn_i))
                                           for dbn_i in range(len(disc_strides))])
        self.sup_d_batch_norm = AttributeDict([("sd_bn%i" % dbn_i, batch_norm(name='sup_d_bn%i' % dbn_i))
                                               for dbn_i in range(5)])
        self.g_batch_norm = AttributeDict([("g_bn%i" % gbn_i, batch_norm(name='g_bn%i' % gbn_i))
                                           for gbn_i in range(len(gen_strides))])

        if num_dfs is None:
            num_dfs = [self.df_dim, self.df_dim * 2, self.df_dim * 4, self.df_dim * 8]
        if num_gfs is None:
            num_gfs = [self.gf_dim * 8, self.gf_dim * 4, self.gf_dim * 2, self.gf_dim]

        assert len(gen_strides) == len(num_gfs), "invalid hypers!"
        assert len(disc_strides) == len(num_dfs), "invalid hypers!"

        s_h, s_w = self.x_dim[0], self.x_dim[1]
        ks = gen_kernel_size
        self.gen_output_dims = OrderedDict()
        self.gen_weight_dims = OrderedDict()
        num_gfs = num_gfs + [self.c_dim]
        self.gen_kernel_sizes = [ks]
        for layer in range(len(gen_strides))[::-1]:
            self.gen_output_dims["g_h%i_out" % (layer + 1)] = (s_h, s_w)
            assert gen_strides[layer] <= 2, "invalid stride"
            assert ks % 2 == 1, "invalid kernel size"
            self.gen_weight_dims["g_h%i_W" % (layer + 1)] = (ks, ks, num_gfs[layer + 1], num_gfs[layer])
            self.gen_weight_dims["g_h%i_b" % (layer + 1)] = (num_gfs[layer + 1],)
            s_h, s_w = conv_out_size(s_h, gen_strides[layer]), conv_out_size(s_w, gen_strides[layer])
            ks = kernel_sizer(ks, gen_strides[layer])
            self.gen_kernel_sizes.append(ks)

        self.gen_weight_dims.update(OrderedDict([("g_h0_lin_W", (self.z_dim, num_gfs[0] * s_h * s_w)),
                                                 ("g_h0_lin_b", (num_gfs[0] * s_h * s_w,))]))
        self.gen_output_dims["g_h0_out"] = (s_h, s_w)

        self.disc_weight_dims = OrderedDict()
        s_h, s_w = self.x_dim[0], self.x_dim[1]
        num_dfs = [self.c_dim] + num_dfs
        ks = disc_kernel_size
        self.disc_kernel_sizes = [ks]
        for layer in range(len(disc_strides)):
            assert disc_strides[layer] <= 2, "invalid stride"
            assert ks % 2 == 1, "invalid kernel size"
            self.disc_weight_dims["d_h%i_W" % layer] = (ks, ks, num_dfs[layer], num_dfs[layer + 1])
            self.disc_weight_dims["d_h%i_b" % layer] = (num_dfs[layer + 1],)
            s_h, s_w = conv_out_size(s_h, disc_strides[layer]), conv_out_size(s_w, disc_strides[layer])
            ks = kernel_sizer(ks, disc_strides[layer])
            self.disc_kernel_sizes.append(ks)

        self.disc_weight_dims.update(OrderedDict([("d_h_end_lin_W", (num_dfs[-1] * s_h * s_w, num_dfs[-1])),
                                                  ("d_h_end_lin_b", (num_dfs[-1],)),
                                                  ("d_h_out_lin_W", (num_dfs[-1], self.K)),
                                                  ("d_h_out_lin_b", (self.K,))]))

        for k, v in self.gen_output_dims.items():
            print("gen_output_dims - %s: %s" % (k, v))
        print('****')
        for k, v in self.gen_weight_dims.items():
            print("gen_weight_dims - %s: %s" % (k, v))
        print('****')
        for k, v in self.disc_weight_dims.items():
            print("disc_weight_dims - %s: %s" % (k, v))

    def construct_nets(self):
        self.num_disc_layers = 5
        self.num_gen_layers = 5
        self.d_batch_norm = AttributeDict([("d_bn%i" % dbn_i, batch_norm(name='d_bn%i' % dbn_i))
                                           for dbn_i in range(self.num_disc_layers)])
        self.sup_d_batch_norm = AttributeDict([("sd_bn%i" % dbn_i, batch_norm(name='sup_d_bn%i' % dbn_i))
                                               for dbn_i in range(self.num_disc_layers)])
        self.g_batch_norm = AttributeDict([("g_bn%i" % gbn_i, batch_norm(name='g_bn%i' % gbn_i))
                                           for gbn_i in range(self.num_gen_layers)])

        s_h, s_w = self.x_dim[0], self.x_dim[1]
        s_h2, s_w2 = conv_out_size(s_h, 2), conv_out_size(s_w, 2)
        s_h4, s_w4 = conv_out_size(s_h2, 2), conv_out_size(s_w2, 2)
        s_h8, s_w8 = conv_out_size(s_h4, 2), conv_out_size(s_w4, 2)
        s_h16, s_w16 = conv_out_size(s_h8, 2), conv_out_size(s_w8, 2)

        self.gen_output_dims = OrderedDict([("g_h0_out", (s_h16, s_w16)),
                                            ("g_h1_out", (s_h8, s_w8)),
                                            ("g_h2_out", (s_h4, s_w4)),
                                            ("g_h3_out", (s_h2, s_w2)),
                                            ("g_h4_out", (s_h, s_w))])

        self.gen_weight_dims = OrderedDict([
            ("g_h0_lin_W", (self.z_dim, self.gf_dim * 8 * s_h16 * s_w16)),
            ("g_h0_lin_b", (self.gf_dim * 8 * s_h16 * s_w16,)),
            ("g_h1_W", (5, 5, self.gf_dim * 4, self.gf_dim * 8)), ("g_h1_b", (self.gf_dim * 4,)),
            ("g_h2_W", (5, 5, self.gf_dim * 2, self.gf_dim * 4)), ("g_h2_b", (self.gf_dim * 2,)),
            ("g_h3_W", (5, 5, self.gf_dim * 1, self.gf_dim * 2)), ("g_h3_b", (self.gf_dim * 1,)),
            ("g_h4_W", (5, 5, self.c_dim, self.gf_dim * 1)), ("g_h4_b", (self.c_dim,))
        ])

        self.disc_weight_dims = OrderedDict([
            ("d_h0_W", (5, 5, self.c_dim, self.df_dim)), ("d_h0_b", (self.df_dim,)),
            ("d_h1_W", (5, 5, self.df_dim, self.df_dim * 2)), ("d_h1_b", (self.df_dim * 2,)),
            ("d_h2_W", (5, 5, self.df_dim * 2, self.df_dim * 4)), ("d_h2_b", (self.df_dim * 4,)),
            ("d_h3_W", (5, 5, self.df_dim * 4, self.df_dim * 8)), ("d_h3_b", (self.df_dim * 8,)),
            ("d_h_end_lin_W", (self.df_dim * 8 * s_h16 * s_w16, self.df_dim * 4)),
            ("d_h_end_lin_b", (self.df_dim * 4,)),
            ("d_h_out_lin_W", (self.df_dim * 4, self.K)), ("d_h_out_lin_b", (self.K,))
        ])

    def _get_optimizer(self, lr):
        if self.optimizer == 'adam':
            return tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
        elif self.optimizer == 'sgd':
            return tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.5)
        else:
            raise ValueError("Optimizer must be either 'adam' or 'sgd'")

    def initialize_wgts(self, scope_str):
        if scope_str == "generator":
            weight_dims = self.gen_weight_dims
            numz = self.num_gen
        elif scope_str == "discriminator":
            weight_dims = self.disc_weight_dims
            numz = self.num_disc
        else:
            raise RuntimeError("invalid scope!")

        param_list = []
        with tf.variable_scope(scope_str) as scope:
            # iterates numz (num_gen / num_disc) x num_mcmc times (= 20 in this setup)
            for zi in range(numz):  # numz: num_gen / num_disc
                for m in range(self.num_mcmc):
                    wgts_ = AttributeDict()
                    for name, shape in weight_dims.items():
                        wgts_[name] = tf.get_variable("%s_%04d_%04d" % (name, zi, m), shape,
                                                      initializer=tf.random_normal_initializer(stddev=0.02))
                    param_list.append(wgts_)
        return param_list

    def build_bgan_graph(self):
        # unsupervised images from the data distribution
        self.inputs = tf.placeholder(tf.float32, [self.batch_size] + self.x_dim, name='real_images')
        # for the discriminator: images from the labeled (supervised) batch
        self.labeled_inputs = tf.placeholder(tf.float32, [self.batch_size] + self.x_dim,
                                             name='real_images_w_labels')
        self.labels = tf.placeholder(tf.float32, [self.batch_size, self.K], name='real_targets')
        # for the generator
        self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim, self.num_gen], name='z')  # [64, 100, 10]
        self.z_sampler = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z_sampler')

        # initialize generator and discriminator weights
        self.gen_param_list = self.initialize_wgts("generator")  # num_gen * num_mcmc entries
        self.disc_param_list = self.initialize_wgts("discriminator")  # num_disc * num_mcmc entries

        ############################ build discriminative losses and optimizers ##########################################
        self.d_semi_learning_rate = tf.placeholder(tf.float32, shape=[])
        t_vars = tf.trainable_variables()
        # compile all discriminative weights (one list of trainable variables per weight sample)
        self.d_vars = []
        for di in range(self.num_disc):
            for m in range(self.num_mcmc):
                self.d_vars.append([var for var in t_vars
                                    if 'd_' in var.name and "_%04d_%04d" % (di, m) in var.name])

        self.d_losses, self.d_optims_semi, self.d_optims_semi_adam = [], [], []
        ### self.d_optims_semi uses the user-specified optimizer
        for di, disc_params in enumerate(self.disc_param_list):
            # with len(disc_param_list) > 1, the first discriminator call creates the variables
            # (reuse=False); subsequent calls must reuse them

            # Part I: real ####################
            # d_probs = softmax(d_logits), d_logits = linear(previous layer)
            d_probs_real, d_logits_real, _ = self.discriminator(self.inputs, self.K, disc_params,
                                                                reuse=tf.AUTO_REUSE)
            # JT-0228: d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_probs_real)))
            d_loss_real = -tf.reduce_mean(tf.reduce_logsumexp(d_logits_real, 1)) \
                + tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits_real, 1)))

            # Part II: fake ####################
            d_loss_fakes = []
            for gi, gen_params in enumerate(self.gen_param_list):  # iterate num_gen * num_mcmc times
                d_probs_fake, d_logits_fake, _ = self.discriminator(
                    self.generator(self.z[:, :, gi % self.num_gen], gen_params), self.K,
                    disc_params, reuse=True)
                # JT-0228: d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_probs_fake)))
                d_loss_fake = tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits_fake, 1)))
                d_loss_fakes.append(d_loss_fake)

            # Part III: sup ####################
            d_sup_probs, d_sup_logits, _ = self.discriminator(self.labeled_inputs, self.K,
                                                              disc_params, reuse=tf.AUTO_REUSE)
            d_loss_sup = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(logits=d_sup_logits, labels=self.labels))

            ################### total loss for semi-supervised discriminator ######################
            d_losses_semi = []
            for d_loss_fake_ in d_loss_fakes:
                d_loss_semi_ = d_loss_sup + d_loss_real * float(self.num_gen) + d_loss_fake_
                if not self.ml:
                    # Bayes term: log p(theta_d | alpha_d) plus the injected noise term
                    d_loss_semi_ += self.disc_prior(disc_params) + self.disc_noise(disc_params)
                d_losses_semi.append(tf.reshape(d_loss_semi_, [1]))

            d_loss_semi = tf.reduce_logsumexp(tf.concat(d_losses_semi, 0))
            self.d_losses.append(d_loss_semi)

            ################### total optimizer for semi-supervised discriminator ######################
            # after 5000 iterations: switch to the user-specified optimizer
            d_opt_semi = self._get_optimizer(self.d_semi_learning_rate)
            self.d_optims_semi.append(d_opt_semi.minimize(d_loss_semi, var_list=self.d_vars[di]))
            # default iterations: Adam
            d_opt_semi_adam = tf.train.AdamOptimizer(learning_rate=self.d_semi_learning_rate, beta1=0.5)
            self.d_optims_semi_adam.append(d_opt_semi_adam.minimize(d_loss_semi, var_list=self.d_vars[di]))

        ############################ build generator losses and optimizers ##########################################
        self.g_learning_rate = tf.placeholder(tf.float32, shape=[])
        self.g_vars = []
        for gi in range(self.num_gen):
            for m in range(self.num_mcmc):
                self.g_vars.append([var for var in t_vars
                                    if 'g_' in var.name and "_%04d_%04d" % (gi, m) in var.name])

        self.g_losses, self.g_optims_semi, self.g_optims_semi_adam = [], [], []
        for gi, gen_params in enumerate(self.gen_param_list):
            gi_losses = []
            for disc_params in self.disc_param_list:
                d_probs_fake, d_logits_fake, d_features_fake = self.discriminator(
                    self.generator(self.z[:, :, gi % self.num_gen], gen_params), self.K,
                    disc_params, reuse=tf.AUTO_REUSE)
                _, _, d_features_real = self.discriminator(self.inputs, self.K, disc_params,
                                                           reuse=tf.AUTO_REUSE)
                # JT-0228: g_loss_ = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_probs_fake)))
                g_loss_ = -tf.reduce_mean(tf.reduce_logsumexp(d_logits_fake, 1)) + \
                    tf.reduce_mean(tf.nn.softplus(tf.reduce_logsumexp(d_logits_fake, 1)))
                # Huber loss is a variation of the squared loss that is more robust to noise
                g_loss_ += tf.reduce_mean(huber_loss(d_features_real[-1], d_features_fake[-1]))
                if not self.ml:
                    # add the prior_loss + noise_loss
                    g_loss_ += self.gen_prior(gen_params) + self.gen_noise(gen_params)
                gi_losses.append(tf.reshape(g_loss_, [1]))

            g_loss = tf.reduce_logsumexp(tf.concat(gi_losses, 0))
            self.g_losses.append(g_loss)

            ################### total optimizer for semi-supervised generator ######################
            g_opt = self._get_optimizer(self.g_learning_rate)
            self.g_optims_semi.append(g_opt.minimize(g_loss, var_list=self.g_vars[gi]))
            g_opt_adam = tf.train.AdamOptimizer(learning_rate=self.g_learning_rate, beta1=0.5)
            self.g_optims_semi_adam.append(g_opt_adam.minimize(g_loss, var_list=self.g_vars[gi]))

        ### build samplers
        self.gen_samplers = []
        for gi, gen_params in enumerate(self.gen_param_list):
            self.gen_samplers.append(self.generator(self.z_sampler, gen_params))

        ### build vanilla supervised loss
        # placeholder for the supervised targets; real values are fed at run time
        self.lbls = tf.placeholder(tf.float32, [self.batch_size, self.K], name='real_sup_targets')
        self.S, self.S_logits = self.sup_discriminator(self.inputs, self.K)
        self.s_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.S_logits, labels=self.lbls))

        ################### optimizer for the purely supervised discriminator ######################
        t_vars = tf.trainable_variables()
        self.sup_vars = [var for var in t_vars if 'sup_' in var.name]
        supervised_lr = 0.05 * self.lr
        s_opt = self._get_optimizer(supervised_lr)
        self.s_optim = s_opt.minimize(self.s_loss, var_list=self.sup_vars)
        # NOTE: plain Adam here, not the SGHMC sampler described in the paper
        s_opt_adam = tf.train.AdamOptimizer(learning_rate=supervised_lr, beta1=0.5)
        self.s_optim_adam = s_opt_adam.minimize(self.s_loss, var_list=self.sup_vars)

    def build_test_graph(self):
        self.test_inputs = tf.placeholder(tf.float32, [self.batch_size] + self.x_dim,
                                          name='real_test_images')
        self.test_d_probs, self.test_d_logits = [], []  # self.test_d_probs: num_disc x (batch_size, K), e.g. 2 x (64, 10)
        for disc_params in self.disc_param_list:  # no generator, just the discriminator
            test_d_probs_, test_d_logits_, _ = self.discriminator(self.test_inputs, self.K, disc_params,
                                                                  train=False, reuse=True)
            self.test_d_probs.append(test_d_probs_)  # test_d_probs_.shape = (batch_size, K)
            self.test_d_logits.append(test_d_logits_)

        # build standard purely supervised losses and optimizers
        self.test_s_probs, self.test_s_logits = self.sup_discriminator(self.test_inputs, self.K)

    def sup_discriminator(self, image, K):
        # TODO collapse this into disc
        with tf.variable_scope("sup_discriminator", reuse=tf.AUTO_REUSE) as scope:
            h0 = lrelu(conv2d(image, self.df_dim, name='sup_h0_conv'))
            h1 = lrelu(self.sup_d_batch_norm.sd_bn1(conv2d(h0, self.df_dim * 2, name='sup_h1_conv')))
            h2 = lrelu(self.sup_d_batch_norm.sd_bn2(conv2d(h1, self.df_dim * 4, name='sup_h2_conv')))
            h3 = lrelu(self.sup_d_batch_norm.sd_bn3(conv2d(h2, self.df_dim * 8, name='sup_h3_conv')))
            h4 = linear(tf.reshape(h3, [self.batch_size, -1]), K, 'sup_h3_lin')
            return tf.nn.softmax(h4), h4

    def discriminator(self, image, K, disc_params, train=True, reuse=False):
        with tf.variable_scope("discriminator", reuse=reuse) as scope:  # or reuse=tf.AUTO_REUSE
            h = image
            for layer in range(len(self.disc_strides)):
                if layer == 0:
                    h = lrelu(conv2d(h, self.disc_weight_dims["d_h%i_W" % layer][-1],
                                     name='d_h%i_conv' % layer,
                                     k_h=self.disc_kernel_sizes[layer], k_w=self.disc_kernel_sizes[layer],
                                     d_h=self.disc_strides[layer], d_w=self.disc_strides[layer],
                                     w=disc_params["d_h%i_W" % layer], biases=disc_params["d_h%i_b" % layer]))
                else:
                    # conv - bn - lrelu
                    h = lrelu(self.d_batch_norm["d_bn%i" % layer](
                        conv2d(h, self.disc_weight_dims["d_h%i_W" % layer][-1],
                               name='d_h%i_conv' % layer,
                               k_h=self.disc_kernel_sizes[layer], k_w=self.disc_kernel_sizes[layer],
                               d_h=self.disc_strides[layer], d_w=self.disc_strides[layer],
                               w=disc_params["d_h%i_W" % layer], biases=disc_params["d_h%i_b" % layer]),
                        train=train))

            h_end = lrelu(linear(tf.reshape(h, [self.batch_size, -1]), self.df_dim * 4, "d_h_end_lin",
                                 matrix=disc_params.d_h_end_lin_W, bias=disc_params.d_h_end_lin_b))  # for feature norm
            h_out = linear(h_end, K, 'd_h_out_lin',
                           matrix=disc_params.d_h_out_lin_W, bias=disc_params.d_h_out_lin_b)
            return tf.nn.softmax(h_out), h_out, [h_end]

    def generator(self, z, gen_params):
        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE) as scope:
            h = linear(z, self.gen_weight_dims["g_h0_lin_W"][-1], 'g_h0_lin',
                       matrix=gen_params.g_h0_lin_W, bias=gen_params.g_h0_lin_b)
            h = tf.nn.relu(self.g_batch_norm.g_bn0(h))
            h = tf.reshape(h, [self.batch_size, self.gen_output_dims["g_h0_out"][0],
                               self.gen_output_dims["g_h0_out"][1], -1])

            for layer in range(1, len(self.gen_strides) + 1):
                out_shape = [self.batch_size,
                             self.gen_output_dims["g_h%i_out" % layer][0],
                             self.gen_output_dims["g_h%i_out" % layer][1],
                             self.gen_weight_dims["g_h%i_W" % layer][-2]]
                h = deconv2d(h, out_shape,
                             k_h=self.gen_kernel_sizes[layer - 1], k_w=self.gen_kernel_sizes[layer - 1],
                             d_h=self.gen_strides[layer - 1], d_w=self.gen_strides[layer - 1],
                             name='g_h%i' % layer,
                             w=gen_params["g_h%i_W" % layer], biases=gen_params["g_h%i_b" % layer])
                if layer < len(self.gen_strides):
                    h = tf.nn.relu(self.g_batch_norm["g_bn%i" % layer](h))

            return tf.nn.tanh(h)

    def gen_prior(self, gen_params):
        with tf.variable_scope("generator") as scope:
            prior_loss = 0.0
            for var in gen_params.values():
                nn = tf.divide(var, self.prior_std)
                prior_loss += tf.reduce_mean(tf.multiply(nn, nn))
        prior_loss /= self.dataset_size
        return prior_loss

    def gen_noise(self, gen_params):
        # noise_: Gaussian distribution
        with tf.variable_scope("generator") as scope:
            noise_loss = 0.0
            for name, var in gen_params.items():
                # tf.contrib.distributions.Normal(mu=..., sigma=...) is deprecated; use tf.distributions.Normal
                noise_ = tf.distributions.Normal(loc=0., scale=self.noise_std * tf.ones(var.get_shape()))
                noise_loss += tf.reduce_sum(var * noise_.sample())
        noise_loss /= self.dataset_size
        return noise_loss

    def disc_prior(self, disc_params):
        with tf.variable_scope("discriminator") as scope:
            prior_loss = 0.0
            for var in disc_params.values():
                # print("var_disc_prior shape = ", var.get_shape(), var)
                # e.g. (5, 5, 3, 96) <tf.Variable 'discriminator/d_h0_W_0000_0000:0' shape=(5, 5, 3, 96) dtype=float32_ref>
                nn = tf.divide(var, self.prior_std)
                prior_loss += tf.reduce_mean(tf.multiply(nn, nn))
        prior_loss /= self.dataset_size
        return prior_loss

    def disc_noise(self, disc_params):
        with tf.variable_scope("discriminator") as scope:
            noise_loss = 0.0
            for var in disc_params.values():
                noise_ = tf.distributions.Normal(loc=0., scale=self.noise_std * tf.ones(var.get_shape()))
                noise_loss += tf.reduce_sum(var * noise_.sample())
        noise_loss /= self.dataset_size
        return noise_loss
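
# Example (sketch): Bayesian test-time prediction with BDCGAN_Semi_3d.
# This is not part of the original code. build_test_graph only exposes
# self.test_d_probs (one softmax output per discriminator weight sample);
# the Monte Carlo averaging below is one standard way to combine those
# outputs into a single predictive distribution. The function and its
# arguments are hypothetical.
def _example_predict_sketch(sess, dcgan, test_batch):
    """Average class probabilities over the num_disc * num_mcmc discriminator samples."""
    probs_per_sample = sess.run(dcgan.test_d_probs,
                                feed_dict={dcgan.test_inputs: test_batch})
    # probs_per_sample: list of (batch_size, K) arrays, one per weight sample
    mean_probs = np.mean(np.stack(probs_per_sample, axis=0), axis=0)
    return np.argmax(mean_probs, axis=1), mean_probs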