Beispiel #1
0
  def __init__(self, flags, is_training=True):
    """Build the discriminator graph and, in training mode, its update ops.

    Args:
      flags: configuration object. This constructor reads feature_size,
        num_label, image_model, image_weight_decay, image_keep_prob, dataset
        and dis_learning_rate; the loss helpers it calls may read more.
      is_training: when False, only the placeholders, the logits and the two
        reward tensors are built (no saver, step counter or optimizers).
    """
    self.is_training = is_training

    # Per-example inputs; the leading None dimension is the batch size.
    feature_shape = (None, flags.feature_size)
    label_shape = (None, flags.num_label)
    self.image_ph = tf.placeholder(tf.float32, shape=feature_shape)
    self.hard_label_ph = tf.placeholder(tf.float32, shape=label_shape)

    # Sampled index pairs and their targets for the generator and the
    # teacher; the leading None is batch_size * sample_size.
    self.gen_sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
    self.gen_label_ph = tf.placeholder(tf.float32, shape=(None,))
    self.tch_sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
    self.tch_label_ph = tf.placeholder(tf.float32, shape=(None,))

    scope_name = 'dis'
    self.dis_scope = scope_name
    arg_scope_fn = nets_factory.arg_scopes_map[flags.image_model]
    with tf.variable_scope(scope_name):
      with slim.arg_scope(arg_scope_fn(weight_decay=flags.image_weight_decay)):
        dropped = slim.dropout(self.image_ph, flags.image_keep_prob,
                               is_training=is_training)
        # One unnormalized score per label (no activation).
        self.logits = slim.fully_connected(dropped, flags.num_label,
                                           activation_fn=None)

      self.gen_rewards = self.get_rewards(self.gen_sample_ph)
      self.tch_rewards = self.get_rewards(self.tch_sample_ph)

      if not is_training:
        return

      # Checkpoint only the trainable variables whose names start with the
      # discriminator scope.
      owned = [v for v in tf.trainable_variables()
               if v.name.startswith(scope_name)]
      for v in owned:
        print('%-50s added to DIS saver' % v.name)
      self.saver = tf.train.Saver({v.name: v for v in owned})

      step = tf.Variable(0, trainable=False)
      self.global_step = step
      self.learning_rate = utils.get_lr(
          flags, utils.get_tn_size(flags.dataset), step,
          flags.dis_learning_rate, scope_name)

      # Supervised pre-training objective (task losses + regularization).
      pre_terms = self.get_pre_losses()
      pre_terms.extend(self.get_regularization_losses())
      print('#pre_losses wt regularization=%d' % (len(pre_terms)))
      self.pre_loss = tf.add_n(pre_terms, name='%s_pre_loss' % scope_name)
      self.pre_update = utils.get_opt(flags, self.learning_rate).minimize(
          self.pre_loss, global_step=step)

      # Adversarial objective (GAN losses + regularization).
      gan_terms = self.get_gan_losses(flags)
      gan_terms.extend(self.get_regularization_losses())
      print('#gan_losses wt regularization=%d' % (len(gan_terms)))
      self.gan_loss = tf.add_n(gan_terms, name='%s_gan_loss' % scope_name)
      self.gan_update = utils.get_opt(flags, self.learning_rate).minimize(
          self.gan_loss, global_step=step)
Beispiel #2
0
import math
import os
import pickle
import time
from os import path

import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim

from flags import flags
from dis_model import DIS
from gen_model import GEN
from tch_model import TCH
import data_utils
from kdgan import utils

# Training-set size for the configured dataset. eval_interval is the number
# of full batches of flags.batch_size in one pass over that set.
tn_size = utils.get_tn_size(flags.dataset)
eval_interval = tn_size // flags.batch_size
print('#tn_size=%d' % (tn_size))

# Training-mode models: these create the variables (and, in training mode,
# the model constructors also build savers and update ops).
tn_dis = DIS(flags, is_training=True)
tn_gen = GEN(flags, is_training=True)
tn_tch = TCH(flags, is_training=True)

# Put the root variable scope into reuse mode so the evaluation-mode models
# built below share the variables created by the training-mode models.
scope = tf.get_variable_scope()
scope.reuse_variables()
vd_dis = DIS(flags, is_training=False)
vd_gen = GEN(flags, is_training=False)
vd_tch = TCH(flags, is_training=False)
for variable in tf.trainable_variables():
  num_params = 1
  for dim in variable.shape:
Beispiel #3
0
  def __init__(self, flags, is_training=True):
    """Build the discriminator graph and, in training mode, its update ops.

    Args:
      flags: configuration object (feature_size, model_name,
        dis_weight_decay, dropout_keep_prob, dataset, init_learning_rate,
        num_epochs_per_decay and learning_rate_decay_factor are read here).
      is_training: when False only the placeholders, logits and rewards are
        built; no saver or optimizer ops are created.
    """
    self.is_training = is_training

    # None = batch_size
    self.image_ph = tf.placeholder(tf.float32, shape=(None, flags.feature_size))
    # None = batch_size * (num_positive + num_negative)
    self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
    # None = batch_size * (num_positive + num_negative)
    self.label_ph = tf.placeholder(tf.float32, shape=(None,))
    # None = batch_size
    self.pre_label_ph = tf.placeholder(tf.float32, shape=(None, config.num_label))

    dis_scope = 'discriminator'
    model_scope = nets_factory.arg_scopes_map[flags.model_name]
    with tf.variable_scope(dis_scope):
      with slim.arg_scope(model_scope(weight_decay=flags.dis_weight_decay)):
        net = self.image_ph
        net = slim.dropout(net, flags.dropout_keep_prob,
            is_training=is_training)
        net = slim.fully_connected(net, config.num_label,
            activation_fn=None)
        self.logits = net

    # Logits of the sampled (row, label) index pairs; used by the GAN loss.
    sample_logits = tf.gather_nd(self.logits, self.sample_ph)

    # Reward of each sampled pair = sigmoid probability of that label.
    reward_logits = tf.sigmoid(self.logits)
    self.rewards = tf.gather_nd(reward_logits, self.sample_ph)

    if not is_training:
      return

    save_dict = {}
    for variable in tf.trainable_variables():
      if not variable.name.startswith(dis_scope):
        continue
      print('%s added to DIS saver' % variable.name)
      save_dict[variable.name] = variable
    self.saver = tf.train.Saver(save_dict)

    train_data_size = utils.get_tn_size(flags.dataset)
    # NOTE(review): assumes a global step already exists in the graph;
    # exponential_decay raises if get_global_step() returns None.
    global_step = tf.train.get_global_step()
    decay_steps = int(train_data_size / config.train_batch_size * flags.num_epochs_per_decay)
    self.learning_rate = tf.train.exponential_decay(flags.init_learning_rate,
        global_step, decay_steps, flags.learning_rate_decay_factor,
        staircase=True, name='exponential_decay_learning_rate')

    # REGULARIZATION_LOSSES is graph-wide. Keep only this model's terms:
    # otherwise other models' weight decay is added to these losses, and
    # minimize() (which defaults to all trainable variables) would update
    # those models' weights on every discriminator step.
    regularization_losses = tf.get_collection(
        tf.GraphKeys.REGULARIZATION_LOSSES, scope=dis_scope)

    # pretrain discriminator
    losses = [tf.losses.sigmoid_cross_entropy(self.pre_label_ph, self.logits)]
    losses.extend(regularization_losses)
    self.pre_loss = tf.add_n(losses, name='dis_pre_loss')
    optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.pre_train_op = optimizer.minimize(self.pre_loss, global_step=global_step)

    # adversarial training on the sampled pairs
    losses = [tf.losses.sigmoid_cross_entropy(self.label_ph, sample_logits)]
    losses.extend(regularization_losses)
    self.gan_loss = tf.add_n(losses, name='dis_gan_loss')
    optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.train_op = optimizer.minimize(self.gan_loss, global_step=global_step)
Beispiel #4
0
    def __init__(self, flags, is_training=True):
        """Build the generator graph and, in training mode, its update ops.

        The generator is a dropout + linear layer over the image features.
        In training mode it also gets a saver, its own step counter, a
        learning-rate tensor and four update ops: ``pre_update``,
        ``kd_update``, ``gan_update`` and ``kdgan_update``. The last one
        clips each gradient tensor to ``config.max_norm`` before applying it.

        Args:
            flags: configuration object (feature_size, num_label,
                image_model, image_weight_decay, image_keep_prob, dataset and
                gen_learning_rate are read here; the loss helpers may read
                more).
            is_training: when False only the placeholders, logits and the
                softmax ``labels`` tensor are built.
        """
        self.is_training = is_training

        n_feat, n_label = flags.feature_size, flags.num_label
        # Leading None = batch_size.
        self.image_ph = tf.placeholder(tf.float32, shape=(None, n_feat))
        self.hard_label_ph = tf.placeholder(tf.float32, shape=(None, n_label))
        self.soft_logit_ph = tf.placeholder(tf.float32, shape=(None, n_label))

        # Leading None = batch_size * sample_size.
        self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
        self.reward_ph = tf.placeholder(tf.float32, shape=(None, ))

        scope_name = 'gen'
        self.gen_scope = scope_name
        image_arg_scope = nets_factory.arg_scopes_map[flags.image_model]
        with tf.variable_scope(scope_name):
            with slim.arg_scope(
                    image_arg_scope(weight_decay=flags.image_weight_decay)):
                dropped = slim.dropout(self.image_ph,
                                       flags.image_keep_prob,
                                       is_training=is_training)
                self.logits = slim.fully_connected(dropped,
                                                   n_label,
                                                   activation_fn=None)

            self.labels = tf.nn.softmax(self.logits)

            if not is_training:
                return

            # Trainable variables owned by this scope: checkpointed by the
            # saver and used as the var_list of the clipped kdgan update.
            owned = [v for v in tf.trainable_variables()
                     if v.name.startswith(scope_name)]
            for v in owned:
                print('%-50s added to GEN saver' % v.name)
            self.saver = tf.train.Saver({v.name: v for v in owned})

            step = tf.Variable(0, trainable=False)
            self.global_step = step
            self.learning_rate = utils.get_lr(
                flags, utils.get_tn_size(flags.dataset), step,
                flags.gen_learning_rate, scope_name)

            # Supervised pre-training: task losses + regularization.
            pre_terms = self.get_pre_losses()
            print('#pre_losses wo regularization=%d' % (len(pre_terms)))
            pre_terms.extend(self.get_regularization_losses())
            print('#pre_losses wt regularization=%d' % (len(pre_terms)))
            self.pre_loss = tf.add_n(pre_terms,
                                     name='%s_pre_loss' % scope_name)
            self.pre_update = utils.get_opt(
                flags, self.learning_rate).minimize(self.pre_loss,
                                                    global_step=step)

            # Knowledge distillation: no regularization terms are added.
            kd_terms = self.get_kd_losses(flags)
            print('#kd_losses wo regularization=%d' % (len(kd_terms)))
            self.kd_loss = tf.add_n(kd_terms, name='%s_kd_loss' % scope_name)
            self.kd_update = utils.get_opt(
                flags, self.learning_rate).minimize(self.kd_loss,
                                                    global_step=step)

            # Adversarial training: GAN losses + regularization.
            gan_terms = self.get_gan_losses()
            print('#gan_losses wo regularization=%d' % (len(gan_terms)))
            gan_terms.extend(self.get_regularization_losses())
            print('#gan_losses wt regularization=%d' % (len(gan_terms)))
            self.gan_loss = tf.add_n(gan_terms,
                                     name='%s_gan_loss' % scope_name)
            self.gan_update = utils.get_opt(
                flags, self.learning_rate).minimize(self.gan_loss,
                                                    global_step=step)

            # Combined KD + GAN objective with per-tensor gradient clipping.
            kdgan_terms = self.get_kdgan_losses(flags)
            print('#kdgan_losses wo regularization=%d' % (len(kdgan_terms)))
            kdgan_terms.extend(self.get_regularization_losses())
            print('#kdgan_losses wt regularization=%d' % (len(kdgan_terms)))
            self.kdgan_loss = tf.add_n(kdgan_terms,
                                       name='%s_kdgan_loss' % scope_name)
            kdgan_opt = utils.get_opt(flags, self.learning_rate)
            grads_and_vars = kdgan_opt.compute_gradients(self.kdgan_loss,
                                                         owned)
            clipped = [(tf.clip_by_norm(grad, config.max_norm), var)
                       for grad, var in grads_and_vars]
            self.kdgan_update = kdgan_opt.apply_gradients(clipped,
                                                          global_step=step)
Beispiel #5
0
    def __init__(self, flags, is_training=True):
        """Build the image+text teacher graph and, in training mode, its ops.

        Image features (after dropout) are concatenated with the mean of the
        word embeddings of ``text_ph``. A linear layer then produces one
        logit per label.

        Args:
            flags: configuration object (feature_size, num_label,
                image_model, image_weight_decay, image_keep_prob, dataset,
                embedding_size, text_weight_decay, tch_weight_decay and
                tch_learning_rate are read here; the loss helpers may read
                more).
            is_training: when False only the placeholders, logits and the
                softmax ``labels`` tensor are built.
        """
        self.is_training = is_training

        # None = batch_size
        self.image_ph = tf.placeholder(tf.float32,
                                       shape=(None, flags.feature_size))
        self.text_ph = tf.placeholder(tf.int64, shape=(None, None))
        self.hard_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, flags.num_label))

        # None = batch_size * sample_size
        self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
        self.reward_ph = tf.placeholder(tf.float32, shape=(None, ))

        self.tch_scope = tch_scope = 'tch'
        model_scope = nets_factory.arg_scopes_map[flags.image_model]
        vocab_size = utils.get_vocab_size(flags.dataset)
        with tf.variable_scope(tch_scope):
            # NOTE(review): this arg_scope only wraps slim.dropout; confirm
            # image_weight_decay is meant to have any effect here.
            with slim.arg_scope(
                    model_scope(weight_decay=flags.image_weight_decay)):
                iembed = self.image_ph
                iembed = slim.dropout(iembed,
                                      flags.image_keep_prob,
                                      is_training=is_training)

            # NOTE(review): this arg_scope targets slim.fully_connected, but
            # no fully_connected call happens inside it, so the
            # text_weight_decay regularizer is never attached (the embedding
            # below is created without a regularizer).
            with slim.arg_scope([slim.fully_connected],
                                weights_regularizer=slim.l2_regularizer(
                                    flags.text_weight_decay)):
                wembed = slim.variable(
                    'wembed',
                    shape=[vocab_size, flags.embedding_size],
                    initializer=tf.random_uniform_initializer(-0.1, 0.1))
                # Mean of the looked-up word embeddings over axis -2.
                tembed = tf.nn.embedding_lookup(wembed, self.text_ph)
                tembed = tf.reduce_mean(tembed, axis=-2)

            with slim.arg_scope([slim.fully_connected],
                                weights_regularizer=slim.l2_regularizer(
                                    flags.tch_weight_decay),
                                biases_initializer=tf.zeros_initializer()):
                cembed = tf.concat([iembed, tembed], 1)
                self.logits = slim.fully_connected(cembed,
                                                   flags.num_label,
                                                   activation_fn=None)

            self.labels = tf.nn.softmax(self.logits)

            if not is_training:
                return

            save_dict = {}
            for variable in tf.trainable_variables():
                if not variable.name.startswith(tch_scope):
                    continue
                print('%-50s added to TCH saver' % variable.name)
                save_dict[variable.name] = variable
            self.saver = tf.train.Saver(save_dict)

            self.global_step = global_step = tf.Variable(0, trainable=False)
            tn_size = utils.get_tn_size(flags.dataset)
            learning_rate = flags.tch_learning_rate
            self.learning_rate = utils.get_lr(flags, tn_size, global_step,
                                              learning_rate, tch_scope)

            # pre train: the regularization terms must be appended BEFORE
            # the losses are summed; otherwise pre_loss silently omits them.
            pre_losses = self.get_pre_losses()
            pre_losses.extend(self.get_regularization_losses())
            print('#pre_losses wt regularization=%d' % (len(pre_losses)))
            self.pre_loss = tf.add_n(pre_losses,
                                     name='%s_pre_loss' % tch_scope)
            pre_optimizer = utils.get_opt(flags, self.learning_rate)
            self.pre_update = pre_optimizer.minimize(self.pre_loss,
                                                     global_step=global_step)

            # kdgan train (no regularization terms are added here)
            kdgan_losses = self.get_kdgan_losses(flags)
            self.kdgan_loss = tf.add_n(kdgan_losses,
                                       name='%s_kdgan_loss' % tch_scope)
            kdgan_optimizer = utils.get_opt(flags, self.learning_rate)
            self.kdgan_update = kdgan_optimizer.minimize(
                self.kdgan_loss, global_step=global_step)
Beispiel #6
0
    def __init__(self, flags, is_training=True):
        """Build the generator graph and, in training mode, its update ops.

        In training mode three update ops are built: ``train_op``
        (supervised pre-training), ``kd_train_op`` (knowledge distillation)
        and ``gan_train_op`` (adversarial). The KD and GAN placeholders are
        created only in training mode.

        Args:
            flags: configuration object (feature_size, model_name,
                gen_weight_decay, dropout_keep_prob, dataset,
                init_learning_rate, num_epochs_per_decay,
                learning_rate_decay_factor, temperature and beta are read
                here).
            is_training: when False only the input placeholders, logits and
                softmax ``labels`` are built.
        """
        self.is_training = is_training

        self.image_ph = tf.placeholder(tf.float32,
                                       shape=(None, flags.feature_size))
        self.label_ph = tf.placeholder(tf.float32,
                                       shape=(None, config.num_label))

        gen_scope = 'generator'
        model_scope = nets_factory.arg_scopes_map[flags.model_name]
        with tf.variable_scope(gen_scope):
            with slim.arg_scope(
                    model_scope(weight_decay=flags.gen_weight_decay)):
                net = self.image_ph
                net = slim.dropout(net,
                                   flags.dropout_keep_prob,
                                   is_training=is_training)
                net = slim.fully_connected(net,
                                           config.num_label,
                                           activation_fn=None)
                self.logits = net

        self.labels = tf.nn.softmax(self.logits)

        if not is_training:
            return

        save_dict = {}
        for variable in tf.trainable_variables():
            if not variable.name.startswith(gen_scope):
                continue
            print('%s added to GEN saver' % variable.name)
            save_dict[variable.name] = variable
        self.saver = tf.train.Saver(save_dict)

        train_data_size = utils.get_tn_size(flags.dataset)
        # NOTE(review): assumes a global step already exists in the graph;
        # exponential_decay raises if get_global_step() returns None.
        global_step = tf.train.get_global_step()
        decay_steps = int(train_data_size / config.train_batch_size *
                          flags.num_epochs_per_decay)
        self.learning_rate = tf.train.exponential_decay(
            flags.init_learning_rate,
            global_step,
            decay_steps,
            flags.learning_rate_decay_factor,
            staircase=True,
            name='exponential_decay_learning_rate')

        # pretrain generator
        losses = []
        losses.append(
            tf.losses.sigmoid_cross_entropy(self.label_ph, self.logits))
        # REGULARIZATION_LOSSES is graph-wide. Keep only this model's terms:
        # otherwise other models' weight decay is added here, and minimize()
        # (which defaults to all trainable variables) would update their
        # weights during generator pre-training.
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES, scope=gen_scope)
        losses.extend(regularization_losses)
        self.pre_loss = tf.add_n(losses, name='pre_loss')
        optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.train_op = optimizer.minimize(self.pre_loss,
                                           global_step=global_step)

        # knowledge distillation
        self.hard_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, config.num_label))
        self.soft_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, config.num_label))
        hard_loss = tf.losses.sigmoid_cross_entropy(self.hard_label_ph,
                                                    self.logits)
        # NOTE(review): the temperature is applied to the teacher logits only,
        # not to self.logits; confirm this is intended.
        soft_loss = tf.nn.l2_loss(
            tf.nn.softmax(self.logits) -
            tf.nn.softmax(self.soft_label_ph / flags.temperature))
        kd_loss = (1.0 - flags.beta) * hard_loss + flags.beta * soft_loss
        kd_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.kd_train_op = kd_optimizer.minimize(kd_loss,
                                                 global_step=global_step)

        # generative adversarial network: rewards are used as the
        # sigmoid cross-entropy targets of the sampled logits
        self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
        self.reward_ph = tf.placeholder(tf.float32, shape=(None, ))
        sample_logits = tf.gather_nd(self.logits, self.sample_ph)
        gan_loss = tf.losses.sigmoid_cross_entropy(self.reward_ph,
                                                   sample_logits)
        gan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.gan_train_op = gan_optimizer.minimize(gan_loss,
                                                   global_step=global_step)
Beispiel #7
0
    def __init__(self, flags, is_training=True):
        """Build the generator graph and, in training mode, its update ops.

        Args:
            flags: configuration object (feature_size, num_label,
                image_model, image_weight_decay, image_keep_prob, dataset and
                gen_learning_rate are read here; the loss helpers may read
                more).
            is_training: when False only the placeholders, logits and the
                softmax ``labels`` tensor are built.
        """
        self.is_training = is_training

        # None = batch_size
        self.image_ph = tf.placeholder(tf.float32,
                                       shape=(None, flags.feature_size))
        self.hard_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, flags.num_label))
        self.soft_label_ph = tf.placeholder(tf.float32,
                                            shape=(None, flags.num_label))

        # None = batch_size * sample_size
        self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
        self.reward_ph = tf.placeholder(tf.float32, shape=(None, ))

        self.gen_scope = gen_scope = 'gen'
        model_scope = nets_factory.arg_scopes_map[flags.image_model]
        with tf.variable_scope(gen_scope):
            with slim.arg_scope(
                    model_scope(weight_decay=flags.image_weight_decay)):
                net = self.image_ph
                net = slim.dropout(net,
                                   flags.image_keep_prob,
                                   is_training=is_training)
                net = slim.fully_connected(net,
                                           flags.num_label,
                                           activation_fn=None)
                self.logits = net

        self.labels = tf.nn.softmax(self.logits)

        if not is_training:
            return

        save_dict = {}
        for variable in tf.trainable_variables():
            if not variable.name.startswith(gen_scope):
                continue
            print('%-50s added to GEN saver' % variable.name)
            save_dict[variable.name] = variable
        self.saver = tf.train.Saver(save_dict)

        self.global_step = global_step = tf.Variable(0, trainable=False)
        tn_size = utils.get_tn_size(flags.dataset)
        self.learning_rate = utils.get_lr(flags, tn_size, global_step,
                                          flags.gen_learning_rate, gen_scope)

        # REGULARIZATION_LOSSES is graph-wide. Keep only this model's terms:
        # otherwise e.g. the discriminator's weight decay (if it was built
        # first) is added here, and minimize() (which defaults to all
        # trainable variables) would update the discriminator's weights on
        # generator steps.
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES, scope=gen_scope)

        # pre train
        pre_losses = []
        pre_losses.append(
            tf.losses.sigmoid_cross_entropy(self.hard_label_ph, self.logits))
        pre_losses.extend(regularization_losses)
        self.pre_loss = tf.add_n(pre_losses, name='%s_pre_loss' % gen_scope)
        pre_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.pre_update = pre_optimizer.minimize(self.pre_loss,
                                                 global_step=global_step)

        # kd train (no regularization terms)
        kd_losses = self.get_kd_losses(flags)
        self.kd_loss = tf.add_n(kd_losses, name='%s_kd_loss' % gen_scope)
        kd_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.kd_update = kd_optimizer.minimize(self.kd_loss,
                                               global_step=global_step)

        # gan train
        gan_losses = self.get_gan_losses(flags)
        gan_losses.extend(regularization_losses)
        self.gan_loss = tf.add_n(gan_losses, name='%s_gan_loss' % gen_scope)
        gan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.gan_update = gan_optimizer.minimize(self.gan_loss,
                                                 global_step=global_step)

        # kdgan train: KD + GAN losses (no regularization terms)
        kdgan_losses = self.get_kd_losses(flags) + self.get_gan_losses(flags)
        self.kdgan_loss = tf.add_n(kdgan_losses,
                                   name='%s_kdgan_loss' % gen_scope)
        kdgan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.kdgan_update = kdgan_optimizer.minimize(self.kdgan_loss,
                                                     global_step=global_step)
Beispiel #8
0
  def __init__(self, flags, is_training=True):
    """Build the text-only teacher graph and, in training mode, its ops.

    Args:
      flags: configuration object (num_label, dataset, embedding_size,
        text_weight_decay and tch_learning_rate are read here).
      is_training: when False only the placeholders, logits and the softmax
        ``labels`` tensor are built.
    """
    self.is_training = is_training

    # None = batch_size (text_ph: word ids per example, averaged below)
    self.text_ph = tf.placeholder(tf.int64, shape=(None, None))
    self.hard_label_ph = tf.placeholder(tf.float32, shape=(None, flags.num_label))

    # None = batch_size * sample_size
    self.sample_ph = tf.placeholder(tf.int32, shape=(None, 2))
    self.reward_ph = tf.placeholder(tf.float32, shape=(None,))

    tch_scope = 'tch'
    vocab_size = utils.get_vocab_size(flags.dataset)
    with tf.variable_scope(tch_scope):
      # The l2 regularizer applies to the fully_connected weights only; the
      # word-embedding matrix is created without a regularizer.
      with slim.arg_scope([slim.fully_connected],
          weights_regularizer=slim.l2_regularizer(flags.text_weight_decay)):
        word_embedding = slim.variable('word_embedding',
            shape=[vocab_size, flags.embedding_size],
            initializer=tf.random_uniform_initializer(-0.1, 0.1))
        # Mean of the looked-up word embeddings over axis -2.
        text_embedding = tf.nn.embedding_lookup(word_embedding, self.text_ph)
        text_embedding = tf.reduce_mean(text_embedding, axis=-2)
        self.logits = slim.fully_connected(text_embedding, flags.num_label,
                  activation_fn=None)

    self.labels = tf.nn.softmax(self.logits)

    if not is_training:
      return

    save_dict = {}
    for variable in tf.trainable_variables():
      if not variable.name.startswith(tch_scope):
        continue
      print('%-50s added to TCH saver' % variable.name)
      save_dict[variable.name] = variable
    self.saver = tf.train.Saver(save_dict)

    self.global_step = global_step = tf.Variable(0, trainable=False)
    tn_size = utils.get_tn_size(flags.dataset)
    self.learning_rate = utils.get_lr(flags, 
        tn_size,
        global_step,
        flags.tch_learning_rate,
        tch_scope)

    # pre train
    pre_losses = []
    pre_losses.append(tf.losses.sigmoid_cross_entropy(self.hard_label_ph, self.logits))
    # REGULARIZATION_LOSSES is graph-wide. Keep only this model's terms:
    # otherwise other models' weight decay is added here, and minimize()
    # (which defaults to all trainable variables) would update their weights
    # during teacher pre-training.
    pre_losses.extend(tf.get_collection(
        tf.GraphKeys.REGULARIZATION_LOSSES, scope=tch_scope))
    self.pre_loss = tf.add_n(pre_losses, name='%s_pre_loss' % tch_scope)
    pre_optimizer = tf.train.AdamOptimizer(self.learning_rate)
    self.pre_update = pre_optimizer.minimize(self.pre_loss, global_step=global_step)

    # kdgan train: rewards are the sigmoid cross-entropy targets of the
    # sampled logits (no regularization terms)
    sample_logits = tf.gather_nd(self.logits, self.sample_ph)
    kdgan_losses = [tf.losses.sigmoid_cross_entropy(self.reward_ph, sample_logits)]
    self.kdgan_loss = tf.add_n(kdgan_losses, name='%s_kdgan_loss' % tch_scope)
    kdgan_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.kdgan_update = kdgan_optimizer.minimize(self.kdgan_loss, global_step=global_step)
Beispiel #9
0
    def __init__(self, flags, is_training=True):
        """Build the text-only teacher graph and, in training mode, its ops.

        Args:
            flags: configuration object (dataset, embedding_size,
                tch_weight_decay, init_learning_rate, num_epochs_per_decay
                and learning_rate_decay_factor are read here).
            is_training: when False only the placeholders, logits and the
                softmax ``labels`` tensor are built.
        """
        self.is_training = is_training

        self.text_ph = tf.placeholder(tf.int64, shape=(None, None))
        self.label_ph = tf.placeholder(tf.float32,
                                       shape=(None, config.num_label))

        tch_scope = 'teacher'
        vocab_size = utils.get_vocab_size(flags.dataset)
        with tf.variable_scope(tch_scope):
            # The l2 regularizer applies to the fully_connected weights only;
            # the word-embedding matrix is created without a regularizer.
            with slim.arg_scope([slim.fully_connected],
                                weights_regularizer=slim.l2_regularizer(
                                    flags.tch_weight_decay)):
                word_embedding = slim.variable(
                    'word_embedding',
                    shape=[vocab_size, flags.embedding_size],
                    initializer=tf.random_uniform_initializer(-0.1, 0.1))
                # Mean of the looked-up word embeddings over axis -2.
                text_embedding = tf.nn.embedding_lookup(
                    word_embedding, self.text_ph)
                text_embedding = tf.reduce_mean(text_embedding, axis=-2)
                self.logits = slim.fully_connected(text_embedding,
                                                   config.num_label,
                                                   activation_fn=None)

        self.labels = tf.nn.softmax(self.logits)

        if not is_training:
            return

        save_dict = {}
        for variable in tf.trainable_variables():
            if not variable.name.startswith(tch_scope):
                continue
            print('%s added to TCH saver' % variable.name)
            save_dict[variable.name] = variable
        self.saver = tf.train.Saver(save_dict)

        train_data_size = utils.get_tn_size(flags.dataset)
        # NOTE(review): assumes a global step already exists in the graph;
        # exponential_decay raises if get_global_step() returns None.
        global_step = tf.train.get_global_step()
        decay_steps = int(train_data_size / config.train_batch_size *
                          flags.num_epochs_per_decay)
        self.learning_rate = tf.train.exponential_decay(
            flags.init_learning_rate,
            global_step,
            decay_steps,
            flags.learning_rate_decay_factor,
            staircase=True,
            name='exponential_decay_learning_rate')

        loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=self.label_ph,
                                                    logits=self.logits))
        losses = [loss]
        # REGULARIZATION_LOSSES is graph-wide. Keep only this model's terms:
        # otherwise other models' weight decay is added here, and minimize()
        # (which defaults to all trainable variables) would update their
        # weights on teacher steps.
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES, scope=tch_scope)
        losses.extend(regularization_losses)
        total_loss = tf.add_n(losses, name='total_loss')

        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.minimize(total_loss, global_step=global_step)

        tf.summary.scalar('total_loss', total_loss)
        self.summary_op = tf.summary.merge_all()