tf.app.flags.DEFINE_string('dis_model_ckpt', None, 'path to the discriminator checkpoint')
tf.app.flags.DEFINE_string('gen_figure_data', None, 'file to save figure data to')
# gen model
tf.app.flags.DEFINE_float('kd_lamda', 0.3, 'weight of the distillation loss')
tf.app.flags.DEFINE_float('gen_weight_decay', 0.001, 'l2 coefficient')
tf.app.flags.DEFINE_float('temperature', 3.0, 'distillation temperature')
tf.app.flags.DEFINE_string('gen_model_ckpt', None, 'path to the generator checkpoint')
tf.app.flags.DEFINE_integer('num_gen_epoch', 5, 'number of generator training epochs')
# tch model
tf.app.flags.DEFINE_float('tch_weight_decay', 0.00001, 'l2 coefficient')
tf.app.flags.DEFINE_integer('embedding_size', 10, 'word embedding dimensionality')
tf.app.flags.DEFINE_string('tch_model_ckpt', None, 'path to the teacher checkpoint')
tf.app.flags.DEFINE_integer('num_tch_epoch', 5, 'number of teacher training epochs')
flags = tf.app.flags.FLAGS

# batch counts for training/validation and how often to evaluate
train_data_size = utils.get_train_data_size(flags.dataset)
valid_data_size = utils.get_valid_data_size(flags.dataset)
num_batch_t = int(flags.num_epoch * train_data_size / flags.batch_size)
num_batch_v = int(valid_data_size / config.valid_batch_size)
eval_interval = int(train_data_size / flags.batch_size)
print('tn:\t#batch=%d\nvd:\t#batch=%d\neval:\t#interval=%d'
    % (num_batch_t, num_batch_v, eval_interval))

def main(_):
  # build the training graph, then reuse its variables for the eval graph
  gen_t = GEN(flags, is_training=True)
  scope = tf.get_variable_scope()
  scope.reuse_variables()
  gen_v = GEN(flags, is_training=False)
  tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
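# A minimal sketch of a pretraining loop over the graph built in main(),
# assuming a hypothetical utils.generate_batch(flags) helper that yields
# (image, hard-label) numpy batches; illustrative only, not the repo's loop.
def pretrain_gen_sketch(gen_t, gen_v):
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for batch_t in range(num_batch_t):
      image_np, label_np = utils.generate_batch(flags)  # hypothetical helper
      feed_dict = {gen_t.image_ph: image_np, gen_t.hard_label_ph: label_np}
      sess.run(gen_t.pre_update, feed_dict=feed_dict)
      # periodically run the variable-sharing eval graph on validation data
      if (batch_t + 1) % eval_interval == 0:
        image_va, _ = utils.generate_batch(flags, split='valid')  # hypothetical
        preds = sess.run(gen_v.labels, feed_dict={gen_v.image_ph: image_va})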
class TCH(object):
  def __init__(self, flags, is_training=True):
    self.is_training = is_training

    # None = batch_size
    self.image_ph = tf.placeholder(tf.float32, shape=(None, flags.feature_size))
    self.text_ph = tf.placeholder(tf.int64, shape=(None, None))
    self.hard_label_ph = tf.placeholder(tf.float32, shape=(None, flags.num_label))
    # self.soft_label_ph = tf.placeholder(tf.float32, shape=(None, flags.num_label))
    # None = batch_size * sample_size
    self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
    self.reward_ph = tf.placeholder(tf.float32, shape=(None,))

    tch_scope = 'tch'
    vocab_size = utils.get_vocab_size(flags.dataset)
    model_scope = nets_factory.arg_scopes_map[flags.model_name]
    with tf.variable_scope(tch_scope) as scope:
      # The text pathway below is currently disabled: word embeddings would be
      # looked up for text_ph, averaged over the sequence, and concatenated
      # with the image features before the classifier.
      # with slim.arg_scope([slim.fully_connected],
      #     weights_regularizer=slim.l2_regularizer(flags.tch_weight_decay)):
      #   word_embedding = slim.variable('word_embedding',
      #       shape=[vocab_size, flags.embedding_size],
      #       initializer=tf.random_uniform_initializer(-0.1, 0.1))
      #   text_embedding = tf.nn.embedding_lookup(word_embedding, self.text_ph)
      #   text_embedding = tf.reduce_mean(text_embedding, axis=-2)
      # the teacher uses its own l2 coefficient
      with slim.arg_scope(model_scope(weight_decay=flags.tch_weight_decay)):
        net = self.image_ph
        net = slim.dropout(net, flags.dropout_keep_prob, is_training=is_training)
        # combined_logits = tf.concat([net, text_embedding], 1)
        self.logits = slim.fully_connected(net, flags.num_label, activation_fn=None)
    self.labels = tf.nn.softmax(self.logits)

    if not is_training:
      return

    # collect teacher variables for saving/restoring
    save_dict = {}
    for variable in tf.trainable_variables():
      if not variable.name.startswith(tch_scope):
        continue
      print('%-50s added to TCH saver' % variable.name)
      save_dict[variable.name] = variable
    self.saver = tf.train.Saver(save_dict)

    global_step = tf.Variable(0, trainable=False)
    train_data_size = utils.get_train_data_size(flags.dataset)
    self.learning_rate = utils.get_lr(flags,
        global_step,
        train_data_size,
        flags.learning_rate,
        flags.learning_rate_decay_factor,
        flags.num_epochs_per_decay,
        tch_scope)

    # pretraining: hard-label cross entropy plus l2 regularization
    pre_losses = [tf.losses.sigmoid_cross_entropy(self.hard_label_ph, self.logits)]
    pre_losses.extend(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    self.pre_loss = tf.add_n(pre_losses, name='%s_pre_loss' % tch_scope)
    pre_optimizer = tf.train.AdamOptimizer(self.learning_rate)
    # pre_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.pre_update = pre_optimizer.minimize(self.pre_loss, global_step=global_step)

    # kdgan training: reward-weighted loss on the sampled label indices
    sample_logits = tf.gather_nd(self.logits, self.sample_ph)
    kdgan_losses = [tf.losses.sigmoid_cross_entropy(self.reward_ph, sample_logits)]
    self.kdgan_loss = tf.add_n(kdgan_losses, name='%s_kdgan_loss' % tch_scope)
    kdgan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.kdgan_update = kdgan_optimizer.minimize(self.kdgan_loss,
        global_step=global_step)
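# Usage sketch: after pretraining, the teacher's outputs supply GEN's
# soft_label_ph. One plausible wiring (matching the kd-loss sketch after GEN
# below) feeds the raw teacher logits and applies the temperature inside the
# kd loss; only image_ph is needed while the text pathway above is disabled.
def fetch_soft_labels(sess, tch, image_np):
  # tch.logits are the pre-softmax scores; tch.labels would give probabilities
  return sess.run(tch.logits, feed_dict={tch.image_ph: image_np})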
class GEN(object):
  def __init__(self, flags, is_training=True):
    self.is_training = is_training

    # None = batch_size
    self.image_ph = tf.placeholder(tf.float32, shape=(None, flags.feature_size))
    self.hard_label_ph = tf.placeholder(tf.float32, shape=(None, flags.num_label))
    self.soft_label_ph = tf.placeholder(tf.float32, shape=(None, flags.num_label))
    # None = batch_size * sample_size
    self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
    self.reward_ph = tf.placeholder(tf.float32, shape=(None,))

    gen_scope = 'gen'
    model_scope = nets_factory.arg_scopes_map[flags.model_name]
    with tf.variable_scope(gen_scope) as scope:
      with slim.arg_scope(model_scope(weight_decay=flags.gen_weight_decay)):
        net = self.image_ph
        net = slim.dropout(net, flags.dropout_keep_prob, is_training=is_training)
        net = slim.fully_connected(net, flags.num_label, activation_fn=None)
        self.logits = net
    self.labels = tf.nn.softmax(self.logits)

    if not is_training:
      return

    # collect generator variables for saving/restoring
    save_dict = {}
    for variable in tf.trainable_variables():
      if not variable.name.startswith(gen_scope):
        continue
      print('%-50s added to GEN saver' % variable.name)
      save_dict[variable.name] = variable
    self.saver = tf.train.Saver(save_dict)

    global_step = tf.Variable(0, trainable=False)
    train_data_size = utils.get_train_data_size(flags.dataset)
    self.learning_rate = utils.get_lr(flags,
        global_step,
        train_data_size,
        flags.learning_rate,
        flags.learning_rate_decay_factor,
        flags.num_epochs_per_decay,
        gen_scope)

    # pretraining: hard-label cross entropy plus l2 regularization
    pre_losses = [tf.losses.sigmoid_cross_entropy(self.hard_label_ph, self.logits)]
    pre_losses.extend(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    self.pre_loss = tf.add_n(pre_losses, name='%s_pre_loss' % gen_scope)
    pre_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.pre_update = pre_optimizer.minimize(self.pre_loss, global_step=global_step)

    # kd training: distillation loss only
    kd_losses = self.get_kd_losses(flags)
    self.kd_loss = tf.add_n(kd_losses, name='%s_kd_loss' % gen_scope)
    kd_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.kd_update = kd_optimizer.minimize(self.kd_loss, global_step=global_step)

    # gan training: adversarial loss plus l2 regularization
    gan_losses = self.get_gan_losses(flags)
    gan_losses.extend(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    self.gan_loss = tf.add_n(gan_losses, name='%s_gan_loss' % gen_scope)
    gan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.gan_update = gan_optimizer.minimize(self.gan_loss, global_step=global_step)

    # kdgan training: distillation and adversarial losses combined
    kdgan_losses = self.get_kd_losses(flags) + self.get_gan_losses(flags)
    self.kdgan_loss = tf.add_n(kdgan_losses, name='%s_kdgan_loss' % gen_scope)
    kdgan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.kdgan_update = kdgan_optimizer.minimize(self.kdgan_loss,
        global_step=global_step)
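  # get_kd_losses and get_gan_losses are referenced above but not defined in
  # this excerpt. Minimal sketches under two assumptions: the kd loss is the
  # standard temperature-scaled soft-target loss (Hinton et al.), with
  # soft_label_ph holding the teacher's logits and flags.kd_lamda weighting
  # it; the gan loss mirrors TCH's kdgan loss on the sampled label indices.
  # The repo's actual definitions may differ.
  def get_kd_losses(self, flags):
    # soften the teacher logits and match them at the same temperature
    soft_labels = tf.nn.softmax(self.soft_label_ph / flags.temperature)
    kd_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=soft_labels, logits=self.logits / flags.temperature)
    return [flags.kd_lamda * tf.reduce_mean(kd_loss)]

  def get_gan_losses(self, flags):
    # reward-weighted loss on the labels sampled from the generator
    sample_logits = tf.gather_nd(self.logits, self.sample_ph)
    return [tf.losses.sigmoid_cross_entropy(self.reward_ph, sample_logits)]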