Example #1
    def __init__(self, flags, is_training=True):
        self.is_training = is_training

        # None = batch_size
        self.image_ph = tf.placeholder(tf.float32,
                                       shape=(None, flags.feature_size))
        self.text_ph = tf.placeholder(tf.int64, shape=(None, None))
        self.hard_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, flags.num_label))

        # None = batch_size * sample_size
        self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
        self.reward_ph = tf.placeholder(tf.float32, shape=(None, ))

        self.tch_scope = tch_scope = 'tch'
        model_scope = nets_factory.arg_scopes_map[flags.image_model]
        vocab_size = utils.get_vocab_size(flags.dataset)
        with tf.variable_scope(tch_scope) as scope:
            with slim.arg_scope(
                    model_scope(weight_decay=flags.image_weight_decay)):
                iembed = self.image_ph
                iembed = slim.dropout(iembed,
                                      flags.image_keep_prob,
                                      is_training=is_training)

            with slim.arg_scope([slim.fully_connected],
                                weights_regularizer=slim.l2_regularizer(
                                    flags.text_weight_decay)):
                wembed = slim.variable(
                    'wembed',
                    shape=[vocab_size, flags.embedding_size],
                    initializer=tf.random_uniform_initializer(-0.1, 0.1))
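                # look up one embedding per tag token, then average over the
                # token dimension (axis=-2) to get a fixed-size bag-of-words
                # text embedding per example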
                tembed = tf.nn.embedding_lookup(wembed, self.text_ph)
                tembed = tf.reduce_mean(tembed, axis=-2)

            with slim.arg_scope([slim.fully_connected],
                                weights_regularizer=slim.l2_regularizer(
                                    flags.tch_weight_decay),
                                biases_initializer=tf.zeros_initializer()):
                # fuse image and text embeddings into a single feature vector
                cembed = tf.concat([iembed, tembed], 1)
                self.logits = slim.fully_connected(cembed,
                                                   flags.num_label,
                                                   activation_fn=None)

            self.labels = tf.nn.softmax(self.logits)

            if not is_training:
                return

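            # keep only teacher-scope variables so the saver ignores any
            # other models built in the same graph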
            save_dict = {}
            for variable in tf.trainable_variables():
                if not variable.name.startswith(tch_scope):
                    continue
                print('%-50s added to TCH saver' % variable.name)
                save_dict[variable.name] = variable
            self.saver = tf.train.Saver(save_dict)

            self.global_step = global_step = tf.Variable(0, trainable=False)
            tn_size = utils.get_tn_size(flags.dataset)
            learning_rate = flags.tch_learning_rate
            self.learning_rate = utils.get_lr(flags, tn_size, global_step,
                                              learning_rate, tch_scope)

            # pre train: extend with the regularization losses *before*
            # tf.add_n, otherwise they never make it into pre_loss
            pre_losses = self.get_pre_losses()
            pre_losses.extend(self.get_regularization_losses())
            print('#pre_losses with regularization=%d' % len(pre_losses))
            self.pre_loss = tf.add_n(pre_losses,
                                     name='%s_pre_loss' % tch_scope)
            pre_optimizer = utils.get_opt(flags, self.learning_rate)
            self.pre_update = pre_optimizer.minimize(self.pre_loss,
                                                     global_step=global_step)

            # kdgan train
            kdgan_losses = self.get_kdgan_losses(flags)
            self.kdgan_loss = tf.add_n(kdgan_losses,
                                       name='%s_kdgan_loss' % tch_scope)
            kdgan_optimizer = utils.get_opt(flags, self.learning_rate)
            self.kdgan_update = kdgan_optimizer.minimize(
                self.kdgan_loss, global_step=global_step)
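A minimal usage sketch for the constructor above. It assumes the snippet imports `tensorflow as tf` and `tensorflow.contrib.slim as slim` plus the project-local `nets_factory` and `utils` modules, and that the constructor lives in a class (called `TCH` here purely for illustration) with a populated `flags` object; the batch size and random inputs are placeholders, not values from the listing:

import numpy as np
import tensorflow as tf

tch = TCH(flags, is_training=True)  # TCH is a hypothetical wrapper class
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed_dict = {
        tch.image_ph: np.random.rand(32, flags.feature_size),
        tch.text_ph: np.random.randint(0, 100, size=(32, 5)),
        tch.hard_label_ph: np.random.randint(
            0, 2, size=(32, flags.num_label)).astype(np.float32),
    }
    # one supervised pre-training step
    _, loss = sess.run([tch.pre_update, tch.pre_loss], feed_dict=feed_dict)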
Example #2
    def __init__(self, flags, is_training=True):
        self.is_training = is_training

        # None = batch_size
        self.image_ph = tf.placeholder(tf.float32,
                                       shape=(None, flags.feature_size))
        self.text_ph = tf.placeholder(tf.int64, shape=(None, None))
        self.hard_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, flags.num_label))

        # None = batch_size * sample_size
        self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
        self.reward_ph = tf.placeholder(tf.float32, shape=(None, ))

        tch_scope = 'tch'
        vocab_size = utils.get_vocab_size(flags.dataset)
        model_scope = nets_factory.arg_scopes_map[flags.model_name]
        with tf.variable_scope(tch_scope) as scope:
            """
      with slim.arg_scope([slim.fully_connected],
          weights_regularizer=slim.l2_regularizer(flags.tch_weight_decay)):
        word_embedding = slim.variable('word_embedding',
            shape=[vocab_size, flags.embedding_size],
            # regularizer=slim.l2_regularizer(flags.tch_weight_decay),
            initializer=tf.random_uniform_initializer(-0.1, 0.1))
        # word_embedding = tf.get_variable('word_embedding', initializer=initializer)
        text_embedding = tf.nn.embedding_lookup(word_embedding, self.text_ph)
        text_embedding = tf.reduce_mean(text_embedding, axis=-2)
      """
            with slim.arg_scope(
                    model_scope(weight_decay=flags.gen_weight_decay)):
                net = self.image_ph
                net = slim.dropout(net,
                                   flags.dropout_keep_prob,
                                   is_training=is_training)

                self.logits = slim.fully_connected(net,
                                                   flags.num_label,
                                                   activation_fn=None)

        self.labels = tf.nn.softmax(self.logits)

        if not is_training:
            return

        save_dict = {}
        for variable in tf.trainable_variables():
            if not variable.name.startswith(tch_scope):
                continue
            print('%-50s added to TCH saver' % variable.name)
            save_dict[variable.name] = variable
        self.saver = tf.train.Saver(save_dict)

        global_step = tf.Variable(0, trainable=False)
        train_data_size = utils.get_train_data_size(flags.dataset)
        self.learning_rate = utils.get_lr(flags, global_step, train_data_size,
                                          flags.learning_rate,
                                          flags.learning_rate_decay_factor,
                                          flags.num_epochs_per_decay,
                                          tch_scope)

        # pre train
        pre_losses = []
        pre_losses.append(
            tf.losses.sigmoid_cross_entropy(self.hard_label_ph, self.logits))
        pre_losses.extend(tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES))
        self.pre_loss = tf.add_n(pre_losses, name='%s_pre_loss' % tch_scope)
        pre_optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.pre_update = pre_optimizer.minimize(self.pre_loss,
                                                 global_step=global_step)

        # kdgan train
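        # each row of sample_ph is a (batch_index, label_index) pair, so
        # gather_nd extracts one scalar logit per sampled label; reward_ph
        # (presumably the KDGAN discriminator reward) is then used as the
        # sigmoid target for those sampled logits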
        sample_logits = tf.gather_nd(self.logits, self.sample_ph)
        kdgan_losses = [
            tf.losses.sigmoid_cross_entropy(self.reward_ph, sample_logits)
        ]
        self.kdgan_loss = tf.add_n(kdgan_losses,
                                   name='%s_kdgan_loss' % tch_scope)
        kdgan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.kdgan_update = kdgan_optimizer.minimize(self.kdgan_loss,
                                                     global_step=global_step)
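The kdgan update consumes sampled (batch_index, label_index) pairs and matching rewards. A hypothetical numpy sketch of how those feeds might be assembled (the actual sampling and reward computation happen outside this constructor; `tch`, `images`, `sess`, and `flags` are assumed from a surrounding training loop):

import numpy as np

batch_size, sample_size = 32, 2
# one (batch_index, label_index) pair per sampled label
rows = np.repeat(np.arange(batch_size), sample_size)
labels = np.random.randint(0, flags.num_label, size=batch_size * sample_size)
sample = np.stack([rows, labels], axis=1).astype(np.int32)  # shape (64, 2)
reward = np.random.rand(batch_size * sample_size).astype(np.float32)

sess.run(tch.kdgan_update, feed_dict={
    tch.image_ph: images,  # (32, feature_size) image features
    tch.sample_ph: sample,
    tch.reward_ph: reward,
})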
Example #3
    def __init__(self, flags, is_training=True):
        self.is_training = is_training

        self.text_ph = tf.placeholder(tf.int64, shape=(None, None))
        self.label_ph = tf.placeholder(tf.float32,
                                       shape=(None, config.num_label))

        tch_scope = 'teacher'
        vocab_size = utils.get_vocab_size(flags.dataset)
        with tf.variable_scope(tch_scope) as scope:
            with slim.arg_scope([slim.fully_connected],
                                weights_regularizer=slim.l2_regularizer(
                                    flags.tch_weight_decay)):
                word_embedding = slim.variable(
                    'word_embedding',
                    shape=[vocab_size, flags.embedding_size],
                    initializer=tf.random_uniform_initializer(-0.1, 0.1))
                text_embedding = tf.nn.embedding_lookup(
                    word_embedding, self.text_ph)
                text_embedding = tf.reduce_mean(text_embedding, axis=-2)
                self.logits = slim.fully_connected(text_embedding,
                                                   config.num_label,
                                                   activation_fn=None)

        self.labels = tf.nn.softmax(self.logits)

        if not is_training:
            return

        save_dict = {}
        for variable in tf.trainable_variables():
            if not variable.name.startswith(tch_scope):
                continue
            print('%s added to TCH saver' % variable.name)
            save_dict[variable.name] = variable
        self.saver = tf.train.Saver(save_dict)

        train_data_size = utils.get_tn_size(flags.dataset)
        # get_global_step() returns None when no global step exists yet,
        # which would break exponential_decay and minimize below
        global_step = tf.train.get_or_create_global_step()
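        # decay interval in steps: steps per epoch (train size / batch size)
        # times the number of epochs between decays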
        decay_steps = int(train_data_size / config.train_batch_size *
                          flags.num_epochs_per_decay)
        self.learning_rate = tf.train.exponential_decay(
            flags.init_learning_rate,
            global_step,
            decay_steps,
            flags.learning_rate_decay_factor,
            staircase=True,
            name='exponential_decay_learning_rate')

        loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=self.label_ph,
                                                    logits=self.logits))
        losses = [loss]
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        losses.extend(regularization_losses)
        total_loss = tf.add_n(losses, name='total_loss')

        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.minimize(total_loss, global_step=global_step)

        tf.summary.scalar('total_loss', total_loss)
        self.summary_op = tf.summary.merge_all()
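A minimal sketch of one training step with summaries for the teacher above, again assuming a hypothetical wrapper class (`Teacher`) and illustrative random inputs:

import numpy as np
import tensorflow as tf

teacher = Teacher(flags, is_training=True)  # hypothetical wrapper class
writer = tf.summary.FileWriter('/tmp/tch_logs')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed_dict = {
        teacher.text_ph: np.random.randint(0, 100, size=(32, 5)),
        teacher.label_ph: np.random.randint(
            0, 2, size=(32, config.num_label)).astype(np.float32),
    }
    _, summary = sess.run([teacher.train_op, teacher.summary_op],
                          feed_dict=feed_dict)
    writer.add_summary(summary, global_step=0)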