Example 1
# Snippet from a larger module: utils, config, nets_factory, and GEN are
# defined elsewhere in the source project.
import tensorflow as tf

tf.app.flags.DEFINE_string('dis_model_ckpt', None, 'discriminator checkpoint path')
tf.app.flags.DEFINE_string('gen_figure_data', None, 'path to save figure data')
# gen model
tf.app.flags.DEFINE_float('kd_lamda', 0.3, 'weight of the distillation loss')
tf.app.flags.DEFINE_float('gen_weight_decay', 0.001, 'l2 coefficient')
tf.app.flags.DEFINE_float('temperature', 3.0, 'distillation temperature')
tf.app.flags.DEFINE_string('gen_model_ckpt', None, 'generator checkpoint path')
tf.app.flags.DEFINE_integer('num_gen_epoch', 5, 'number of generator training epochs')
# tch model
tf.app.flags.DEFINE_float('tch_weight_decay', 0.00001, 'l2 coefficient')
tf.app.flags.DEFINE_integer('embedding_size', 10, 'word embedding dimension')
tf.app.flags.DEFINE_string('tch_model_ckpt', None, 'teacher checkpoint path')
tf.app.flags.DEFINE_integer('num_tch_epoch', 5, 'number of teacher training epochs')
flags = tf.app.flags.FLAGS
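# These definitions become command-line options when the module is executed; a
# hypothetical invocation (the script name and flags such as --dataset are
# assumptions, defined elsewhere in the project):
#   python train_kdgan.py --dataset=mnist --batch_size=32 --num_gen_epoch=5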

train_data_size = utils.get_train_data_size(flags.dataset)
valid_data_size = utils.get_valid_data_size(flags.dataset)
num_batch_t = int(flags.num_epoch * train_data_size / flags.batch_size)
num_batch_v = int(valid_data_size / config.valid_batch_size)
eval_interval = int(train_data_size / flags.batch_size)
print('tn:\t#batch=%d\nvd:\t#batch=%d\neval:\t#interval=%d' %
      (num_batch_t, num_batch_v, eval_interval))


def main(_):
    # build the training graph, then reuse its variables for the eval graph
    gen_t = GEN(flags, is_training=True)
    scope = tf.get_variable_scope()
    scope.reuse_variables()
    gen_v = GEN(flags, is_training=False)

    tf.summary.scalar(gen_t.learning_rate.name, gen_t.learning_rate)
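Scripts built on tf.app.flags are normally launched with tf.app.run(), which parses the command-line flags and then invokes main; a minimal sketch of the entry point:

if __name__ == '__main__':
    tf.app.run()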
Example 2
class TCH(object):
    # teacher model (class name assumed from the 'tch' scope and 'TCH saver' log)
    def __init__(self, flags, is_training=True):
        self.is_training = is_training

        # None = batch_size
        self.image_ph = tf.placeholder(tf.float32,
                                       shape=(None, flags.feature_size))
        self.text_ph = tf.placeholder(tf.int64, shape=(None, None))
        self.hard_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, flags.num_label))
        #self.soft_label_ph = tf.placeholder(tf.float32, shape=(None, flags.num_label))

        # None = batch_size * sample_size
        self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
        self.reward_ph = tf.placeholder(tf.float32, shape=(None, ))

        tch_scope = 'tch'
        vocab_size = utils.get_vocab_size(flags.dataset)
        model_scope = nets_factory.arg_scopes_map[flags.model_name]
        # initializer = tf.random_uniform([vocab_size, flags.embedding_size], -0.1, 0.1)
        with tf.variable_scope(tch_scope) as scope:
            """
      with slim.arg_scope([slim.fully_connected],
          weights_regularizer=slim.l2_regularizer(flags.tch_weight_decay)):
        word_embedding = slim.variable('word_embedding',
            shape=[vocab_size, flags.embedding_size],
            # regularizer=slim.l2_regularizer(flags.tch_weight_decay),
            initializer=tf.random_uniform_initializer(-0.1, 0.1))
        # word_embedding = tf.get_variable('word_embedding', initializer=initializer)
        text_embedding = tf.nn.embedding_lookup(word_embedding, self.text_ph)
        text_embedding = tf.reduce_mean(text_embedding, axis=-2)
      """
            # the teacher applies its own l2 coefficient
            with slim.arg_scope(
                    model_scope(weight_decay=flags.tch_weight_decay)):
                net = self.image_ph
                net = slim.dropout(net,
                                   flags.dropout_keep_prob,
                                   is_training=is_training)

                # combined_logits = tf.concat([net, text_embedding], 1)
                self.logits = slim.fully_connected(net,
                                                   flags.num_label,
                                                   activation_fn=None)

        self.labels = tf.nn.softmax(self.logits)

        if not is_training:
            return

        save_dict = {}
        for variable in tf.trainable_variables():
            if not variable.name.startswith(tch_scope):
                continue
            print('%-50s added to TCH saver' % variable.name)
            save_dict[variable.name] = variable
        self.saver = tf.train.Saver(save_dict)

        global_step = tf.Variable(0, trainable=False)
        train_data_size = utils.get_train_data_size(flags.dataset)
        self.learning_rate = utils.get_lr(flags, global_step, train_data_size,
                                          flags.learning_rate,
                                          flags.learning_rate_decay_factor,
                                          flags.num_epochs_per_decay,
                                          tch_scope)

        # pre train
        pre_losses = []
        pre_losses.append(
            tf.losses.sigmoid_cross_entropy(self.hard_label_ph, self.logits))
        pre_losses.extend(tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES))
        self.pre_loss = tf.add_n(pre_losses, name='%s_pre_loss' % tch_scope)
        pre_optimizer = tf.train.AdamOptimizer(self.learning_rate)
        #pre_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.pre_update = pre_optimizer.minimize(self.pre_loss,
                                                 global_step=global_step)

        # kdgan train: update the sampled (row, label) logits toward the
        # externally supplied reward via a sigmoid cross-entropy
        sample_logits = tf.gather_nd(self.logits, self.sample_ph)
        kdgan_losses = [
            tf.losses.sigmoid_cross_entropy(self.reward_ph, sample_logits)
        ]
        self.kdgan_loss = tf.add_n(kdgan_losses,
                                   name='%s_kdgan_loss' % tch_scope)
        kdgan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.kdgan_update = kdgan_optimizer.minimize(self.kdgan_loss,
                                                     global_step=global_step)
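sample_ph holds (row, label) index pairs, so tf.gather_nd extracts one logit per sampled pair. A minimal standalone sketch with made-up logits (independent of the class above):

import tensorflow as tf

# two examples x three labels of fake logits
logits = tf.constant([[0.1, 0.5, 2.0],
                      [1.5, -0.3, 0.2]])
# pick label 2 of example 0 and label 1 of example 1
sample = tf.constant([[0, 2], [1, 1]], dtype=tf.int32)
picked = tf.gather_nd(logits, sample)

with tf.Session() as sess:
    print(sess.run(picked))  # -> [2.0, -0.3]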
Example 3
class GEN(object):
    # generator model, as instantiated in Example 1
    def __init__(self, flags, is_training=True):
        self.is_training = is_training

        # None = batch_size
        self.image_ph = tf.placeholder(tf.float32,
                                       shape=(None, flags.feature_size))
        self.hard_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, flags.num_label))
        self.soft_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, flags.num_label))

        # None = batch_size * sample_size
        self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
        self.reward_ph = tf.placeholder(tf.float32, shape=(None, ))

        gen_scope = 'gen'
        model_scope = nets_factory.arg_scopes_map[flags.model_name]
        with tf.variable_scope(gen_scope) as scope:
            with slim.arg_scope(
                    model_scope(weight_decay=flags.gen_weight_decay)):
                net = self.image_ph
                net = slim.dropout(net,
                                   flags.dropout_keep_prob,
                                   is_training=is_training)
                net = slim.fully_connected(net,
                                           flags.num_label,
                                           activation_fn=None)
                self.logits = net

        self.labels = tf.nn.softmax(self.logits)

        if not is_training:
            return

        save_dict = {}
        for variable in tf.trainable_variables():
            if not variable.name.startswith(gen_scope):
                continue
            print('%-50s added to GEN saver' % variable.name)
            save_dict[variable.name] = variable
        self.saver = tf.train.Saver(save_dict)

        global_step = tf.Variable(0, trainable=False)
        train_data_size = utils.get_train_data_size(flags.dataset)
        self.learning_rate = utils.get_lr(flags, global_step, train_data_size,
                                          flags.learning_rate,
                                          flags.learning_rate_decay_factor,
                                          flags.num_epochs_per_decay,
                                          gen_scope)

        # pre train
        pre_losses = []
        pre_losses.append(
            tf.losses.sigmoid_cross_entropy(self.hard_label_ph, self.logits))
        pre_losses.extend(tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES))
        self.pre_loss = tf.add_n(pre_losses, name='%s_pre_loss' % gen_scope)
        pre_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.pre_update = pre_optimizer.minimize(self.pre_loss,
                                                 global_step=global_step)

        # kd train (get_kd_losses, like get_gan_losses below, is defined
        # elsewhere in the class)
        kd_losses = self.get_kd_losses(flags)
        self.kd_loss = tf.add_n(kd_losses, name='%s_kd_loss' % gen_scope)
        kd_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.kd_update = kd_optimizer.minimize(self.kd_loss,
                                               global_step=global_step)

        # gan train
        gan_losses = self.get_gan_losses(flags)
        gan_losses.extend(tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES))
        self.gan_loss = tf.add_n(gan_losses, name='%s_gan_loss' % gen_scope)
        gan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.gan_update = gan_optimizer.minimize(self.gan_loss,
                                                 global_step=global_step)

        # kdgan train
        kdgan_losses = self.get_kd_losses(flags) + self.get_gan_losses(flags)
        self.kdgan_loss = tf.add_n(kdgan_losses,
                                   name='%s_kdgan_loss' % gen_scope)
        kdgan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.kdgan_update = kdgan_optimizer.minimize(self.kdgan_loss,
                                                     global_step=global_step)
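A minimal sketch of driving the pretraining ops above, assuming the flags and num_batch_t from Example 1 and a hypothetical batch_iter yielding numpy (image, label) batches:

gen_t = GEN(flags, is_training=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(num_batch_t):
        image_np, label_np = next(batch_iter)  # hypothetical data iterator
        _, loss = sess.run([gen_t.pre_update, gen_t.pre_loss],
                           feed_dict={gen_t.image_ph: image_np,
                                      gen_t.hard_label_ph: label_np})
    gen_t.saver.save(sess, flags.gen_model_ckpt)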