Code example #1
File: train.py  Project: zqhfpjlswsqy/google-research
def train():
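    # placeholders for the student (VGG) classifier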
    batch_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
    batch_labels = tf.placeholder(tf.int32, [None])
    is_training_ph = tf.placeholder(tf.bool)
    lr_ph = tf.placeholder(tf.float32)

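    # placeholders for the RL teacher (policy) network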
    states_rl = tf.placeholder(tf.float32, [None, 2])
    actions_rl = tf.placeholder(tf.int32, [None])
    values_rl = tf.placeholder(tf.float32, [None])
    is_training_rl = tf.placeholder(tf.bool)
    lr_rl = tf.placeholder(tf.float32)

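    # placeholders for the critic (state-value baseline) network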
    states_c = tf.placeholder(tf.float32, [None, 7])
    values_c = tf.placeholder(tf.float32, [None])
    is_training_c = tf.placeholder(tf.bool)
    lr_c = tf.placeholder(tf.float32)

    phs = {
        'batch_images': batch_images,
        'batch_labels': batch_labels,
        'is_training_ph': is_training_ph,
        'lr_ph': lr_ph
    }

    phrl = {
        'states_rl': states_rl,
        'actions_rl': actions_rl,
        'values_rl': values_rl,
        'is_training_rl': is_training_rl,
        'lr_rl': lr_rl
    }

    phc = {
        'states_c': states_c,
        'values_c': values_c,
        'is_training_c': is_training_c,
        'lr_c': lr_c
    }
    with tf.Session() as sess:
        vgg_loss, vgg_acc, vgg_ce, vgg_prob, vgg_update, vgg_pred, vgg_vars = vgg_graph(
            sess, phs)
        rl_loss, rl_prob, rl_update, rl_vars = rl_graph(sess, phrl)
        c_loss, c_value, c_update, c_vars = c_graph(sess, phc)
        vgg_init = tf.variables_initializer(var_list=vgg_vars)
        saver = tf.train.Saver(vgg_vars)
        all_saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        sess.run(init)

        # for epoch in range(4):
        #   for t in range(train_iters):
        #     if t % 50==0: print("pretrain:", t)
        #     tr_images, tr_labels = sess.run([train_images,train_labels])
        #     pre_dict = {phs['batch_images']: tr_images,
        #           phs['batch_labels']: tr_labels,
        #           phs['is_training_ph']: True}
        #     sess.run(vgg_update, feed_dict=pre_dict)
        # saver.save(sess,LOG_PATH+'/vgg.ckpt')
        # valid_acc = 0.0
        # y_pred =[]
        # y_label = []
        # y_att = []
        # for k in range(valid_iters):
        #   va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
        #   valid_dict = {phs['batch_images']: va_images,
        #           phs['batch_labels']: va_labels,
        #           phs['is_training_ph']: False}
        #   batch_acc, batch_pred = sess.run([vgg_acc,vgg_pred], feed_dict=valid_dict)
        #   valid_acc += batch_acc
        #   y_pred += batch_pred.tolist()
        #   y_label += va_labels.tolist()
        #   y_att += va_att.tolist()
        # valid_acc = valid_acc / float(valid_iters)
        # valid_eo = data.cal_eo(y_att, y_label, y_pred)
        # log_string('====pretrain: valid_acc=%.4f, valid_eo=%.4f' % (valid_acc, valid_eo[-1]))
        # print(valid_eo)

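        # fetch one fixed validation batch, reused at every step to compute the reward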
        va_images, va_labels, va_att = sess.run(
            [valid_images, valid_labels, valid_att])
        for i in range(N_EPISODE):
            sess.run(vgg_init)
            # saver.restore(sess,LOG_PATH+'/vgg.ckpt')
            train_loss = []
            for j in range(train_iters * 20):
                tr_images, tr_labels, tr_att = sess.run(
                    [train_images, train_labels, train_att])
                fa_images, fa_labels, fa_att = sess.run(
                    [fake_images, fake_labels, fake_att])

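                # compute state features from the student model on the current real batch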
                train_dict = {
                    phs['batch_images']: tr_images,
                    phs['batch_labels']: tr_labels,
                    phs['is_training_ph']: False
                }
                ce, acc, prob, pred = sess.run(
                    [vgg_ce, vgg_acc, vgg_prob, vgg_pred],
                    feed_dict=train_dict)
                ce = np.clip(ce, 0, 10) / 10.0
                train_loss.append(np.mean(ce))
                model_stat = list(data.cal_eo(tr_att, tr_labels, pred))  # shape [5,]
                model_stat.append(np.mean(ce))
                model_stat.append(j / (train_iters * 20))
                # model_stat.append(np.mean(train_loss))
                c_state = np.array(model_stat)[np.newaxis, :]

                # model_stat = np.tile(model_stat,(BATCH_SIZE,1))
                state = np.concatenate(
                    (tr_labels[:, np.newaxis], tr_att[:, np.newaxis]), axis=1)

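                # sample a per-example keep/replace action from the RL policy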
                rl_dict = {
                    phrl['states_rl']: state,
                    phrl['is_training_rl']: False
                }
                action = choose_action(sess.run(rl_prob, feed_dict=rl_dict))

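                # critic predicts the baseline value for the current model state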
                c_dict = {
                    phc['states_c']: c_state,
                    phc['is_training_c']: False
                }
                base = sess.run(c_value, feed_dict=c_dict)

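                # compose the update batch: real examples where action=1, fake examples where action=0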
                bool_train = list(map(bool, action))
                bool_fake = list(map(bool, 1 - action))
                co_images = np.concatenate(
                    (tr_images[bool_train], fa_images[bool_fake]), axis=0)
                co_labels = np.concatenate(
                    (tr_labels[bool_train], fa_labels[bool_fake]), axis=0)

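                # update the student model on the composed batch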
                update_dict = {
                    phs['batch_images']: co_images,
                    phs['batch_labels']: co_labels,
                    phs['is_training_ph']: True
                }
                _, ce, acc = sess.run([vgg_update, vgg_ce, vgg_acc],
                                      feed_dict=update_dict)

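                # evaluate on the fixed validation batch; reward is the value minus the critic baseline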
                valid_dict = {
                    phs['batch_images']: va_images,
                    phs['batch_labels']: va_labels,
                    phs['is_training_ph']: False
                }
                valid_acc, y_pred = sess.run([vgg_acc, vgg_pred],
                                             feed_dict=valid_dict)
                valid_eo = data.cal_eo(va_att, va_labels, y_pred)
                if valid_eo[-1] <= 0.05:
                    value = -2
                else:
                    value = -np.log(valid_eo[-1])
                reward = value - base[0]

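                # train the critic towards the observed value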
                c_dict = {
                    phc['states_c']: c_state,
                    phc['values_c']: [value],
                    phc['is_training_c']: True
                }
                _, cri_loss = sess.run([c_update, c_loss], feed_dict=c_dict)

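                # update the RL teacher with the sampled actions and the advantage-style reward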
                final_reward = np.repeat(reward, BATCH_SIZE)
                learn_dict = {
                    phrl['states_rl']: state,
                    phrl['actions_rl']: action,
                    phrl['values_rl']: final_reward,
                    phrl['is_training_rl']: True
                }
                sess.run(rl_update, feed_dict=learn_dict)

                if j % 10 == 0:
                    log_string(
                        '====epoch_%d====iter_%d: student_loss=%.4f, train_acc=%.4f'
                        % (i, j, np.mean(ce), acc))
                    log_string(
                        '===============: critic_loss=%.4f, reward=%.4f, valid_acc=%.4f, valid_eo=%.4f'
                        % (cri_loss, reward, valid_acc, valid_eo[-1]))
                    print('eo: ', valid_eo[0], valid_eo[1])
                    print('eo: ', valid_eo[2], valid_eo[3])
                    print(action, np.sum(action))

        all_saver.save(sess, LOG_PATH + '/all.ckpt')
Code example #2
def train():
  batch_images = tf.placeholder(tf.float32,[None,128,128,3])
  batch_labels = tf.placeholder(tf.int32,[None,])
  is_training_ph = tf.placeholder(tf.bool)
  lr_ph = tf.placeholder(tf.float32)

  states_rl = tf.placeholder(tf.float32,[None,11])
  actions_rl = tf.placeholder(tf.int32,[None,])
  values_rl = tf.placeholder(tf.float32,[None,])
  is_training_rl = tf.placeholder(tf.bool)
  lr_rl = tf.placeholder(tf.float32)

  phs = {'batch_images': batch_images,
       'batch_labels': batch_labels,
       'is_training_ph': is_training_ph,
       'lr_ph': lr_ph}

  phrl = {'states_rl': states_rl,
       'actions_rl': actions_rl,
       'values_rl': values_rl,
       'is_training_rl': is_training_rl,
       'lr_rl': lr_rl}

  with tf.Session() as sess:
    vgg_loss, vgg_acc, vgg_ce, vgg_prob, vgg_update, vgg_pred, vgg_vars = vgg_graph(sess, phs)
    rl_loss, rl_prob, rl_update, rl_vars = rl_graph(sess, phrl)
    vgg_init = tf.variables_initializer(var_list=vgg_vars)
    saver = tf.train.Saver(vgg_vars)
    all_saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    sess.run(init)




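    # each episode restores the saved checkpoints, then trains the student on batches selected by the (fixed) RL teacher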
    for i in range(N_EPISODE):
      # sess.run(vgg_init)
      all_saver.restore(sess,LOG_PATH+'/all.ckpt')
      saver.restore(sess,LOG_PATH+'/vgg.ckpt')
      # state_list = []
      # action_list = []
      # reward_list = []
      for j in range(train_iters*20):
        tr_images, tr_labels, tr_att = sess.run([train_images,train_labels, train_att])
        fa_images, fa_labels, fa_att = sess.run([fake_images,fake_labels, fake_att])

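        # build state features from the student's predictions on this batch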
        train_dict = {phs['batch_images']: tr_images,
                phs['batch_labels']: tr_labels,
                phs['is_training_ph']: False}
        ce, acc, prob, pred = sess.run([vgg_ce, vgg_acc, vgg_prob, vgg_pred], feed_dict=train_dict)
        ce = np.clip(ce, 0, 10)/10.0
        model_stat = list(data.cal_eo(tr_att, tr_labels, pred))
        model_stat.append(np.mean(ce))
        model_stat = np.tile(model_stat,(BATCH_SIZE,1))
        state = np.concatenate((tr_labels[:, np.newaxis], tr_att[:, np.newaxis], prob, ce[:, np.newaxis], model_stat), axis=1)



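        # sample per-example actions from the RL teacher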
        rl_dict = {phrl['states_rl']: state,
               phrl['is_training_rl']: False}
        action = choose_action(sess.run(rl_prob, feed_dict=rl_dict))



        bool_train = list(map(bool,action))
        bool_fake = list(map(bool,1-action))
        co_images = np.concatenate((tr_images[bool_train],fa_images[bool_fake]),axis=0)
        co_labels = np.concatenate((tr_labels[bool_train],fa_labels[bool_fake]),axis=0)


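        # update the student on the action-selected mix of real and fake examples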
        update_dict = {phs['batch_images']: co_images,
                phs['batch_labels']: co_labels,
                phs['is_training_ph']: True}
        _, ce, acc = sess.run([vgg_update, vgg_ce, vgg_acc], feed_dict=update_dict)


        if j % 100 == 0:
          print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (i, j, np.mean(ce), acc))
          print(action, np.sum(action))


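      # full validation pass at the end of the episode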
      valid_acc = 0.0
      y_pred =[]
      y_label = []
      y_att = []
      for k in range(valid_iters):
        va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
        valid_dict = {phs['batch_images']: va_images,
                phs['batch_labels']: va_labels,
                phs['is_training_ph']: False}
        batch_acc, batch_pred = sess.run([vgg_acc,vgg_pred], feed_dict=valid_dict)
        valid_acc += batch_acc
        y_pred += batch_pred.tolist()
        y_label += va_labels.tolist()
        y_att += va_att.tolist()
      valid_acc = valid_acc / float(valid_iters)
      valid_eo = data.cal_eo(y_att, y_label, y_pred)
      log_string('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (i, valid_acc, valid_eo[-1]))
      print('eo: ',valid_eo[0],valid_eo[1])
      print('eo: ',valid_eo[2],valid_eo[3])
Code example #3
                    lr_ph: lr,
                    is_training: True
                })
            if j % 50 == 0:
                print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' %
                      (i, j, loss, acc))

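        # full validation pass: accumulate accuracy and collect predictions, labels, and attributes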
        valid_acc = 0.0
        y_pred = []
        y_label = []
        y_att = []
        for k in range(valid_iters):
            va_images, va_labels, va_att = sess.run(
                [valid_images, valid_labels, valid_att])
            batch_acc, batch_pred = sess.run(
                [acc_op, Y_pred], {
                    batch_images: va_images,
                    batch_labels: va_labels,
                    is_training: False
                })
            valid_acc += batch_acc
            y_pred += batch_pred.tolist()
            y_label += va_labels.tolist()
            y_att += va_att.tolist()
        valid_acc = valid_acc / float(valid_iters)
        valid_eo = data.cal_eo(y_att, y_label, y_pred)
        print('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' %
              (i, valid_acc, valid_eo[-1]))
        print('eo: ', valid_eo[0], valid_eo[1])
        print('eo: ', valid_eo[2], valid_eo[3])
Code example #4
def train():
  batch_images = tf.placeholder(tf.float32,[None,128,128,3])
  batch_labels = tf.placeholder(tf.int32,[None,])
  is_training_ph = tf.placeholder(tf.bool)
  lr_ph = tf.placeholder(tf.float32)

  states_rl = tf.placeholder(tf.float32,[None,11])
  actions_rl = tf.placeholder(tf.int32,[None,])
  values_rl = tf.placeholder(tf.float32,[None,])
  is_training_rl = tf.placeholder(tf.bool)
  lr_rl = tf.placeholder(tf.float32)

  phs = {'batch_images': batch_images,
       'batch_labels': batch_labels,
       'is_training_ph': is_training_ph,
       'lr_ph': lr_ph}

  phrl = {'states_rl': states_rl,
       'actions_rl': actions_rl,
       'values_rl': values_rl,
       'is_training_rl': is_training_rl,
       'lr_rl': lr_rl}

  with tf.Session() as sess:
    # tf.reset_default_graph()
    vgg_loss, vgg_acc, vgg_ce, vgg_prob, vgg_update, vgg_pred, vgg_vars = vgg_graph(sess, phs)
    rl_loss, rl_prob, rl_update, rl_vars = rl_graph(sess, phrl)
    vgg_init = tf.variables_initializer(var_list=vgg_vars)
    saver = tf.train.Saver(vgg_vars)
    all_saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    sess.run(init)

    #####################  pre-train student model  #####################
    for epoch in range(4):
      for t in range(train_iters):
        if t % 50==0: print("pretrain:", t)
        tr_images, tr_labels = sess.run([train_images,train_labels])
        pre_dict = {phs['batch_images']: tr_images,
              phs['batch_labels']: tr_labels,
              phs['is_training_ph']: True}
        sess.run(vgg_update, feed_dict=pre_dict)
    saver.save(sess,LOG_PATH+'/vgg.ckpt')
    valid_acc = 0.0
    y_pred =[]
    y_label = []
    y_att = []
    for k in range(valid_iters):
      va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
      valid_dict = {phs['batch_images']: va_images,
              phs['batch_labels']: va_labels,
              phs['is_training_ph']: False}
      batch_acc, batch_pred = sess.run([vgg_acc,vgg_pred], feed_dict=valid_dict)
      valid_acc += batch_acc
      y_pred += batch_pred.tolist()
      y_label += va_labels.tolist()
      y_att += va_att.tolist()
    valid_acc = valid_acc / float(valid_iters)
    valid_eo = data.cal_eo(y_att, y_label, y_pred)
    log_string('====pretrain: valid_acc=%.4f, valid_eo=%.4f' % (valid_acc, valid_eo[-1]))
    print(valid_eo)


    #####################  train teacher model  #####################
    for i in range(N_EPISODE):
      # sess.run(vgg_init)
      saver.restore(sess,LOG_PATH+'/vgg.ckpt')
      state_list = []
      action_list = []
      reward_list = []
      for j in range(train_iters*20):
        tr_images, tr_labels, tr_att = sess.run([train_images,train_labels, train_att])
        fa_images, fa_labels, fa_att = sess.run([fake_images,fake_labels, fake_att])
        # va_images, va_labels, va_att = sess.run([valid_images,valid_labels, valid_att])

        #####################  generate state info from student model & data  #####################
        train_dict = {phs['batch_images']: tr_images,
                phs['batch_labels']: tr_labels,
                phs['is_training_ph']: False}
        ce, acc, prob, pred = sess.run([vgg_ce, vgg_acc, vgg_prob, vgg_pred], feed_dict=train_dict)
        ce = np.clip(ce, 0, 10)/10.0
        model_stat = list(data.cal_eo(tr_att, tr_labels, pred))
        model_stat.append(np.mean(ce))
        model_stat = np.tile(model_stat,(BATCH_SIZE,1))
        state = np.concatenate((tr_labels[:, np.newaxis], tr_att[:, np.newaxis], prob, ce[:, np.newaxis], model_stat), axis=1)
        state_list.append(state)

        #####################  sample action for this batch  #####################
        rl_dict = {phrl['states_rl']: state,
               phrl['is_training_rl']: False}
        action = choose_action(sess.run(rl_prob, feed_dict=rl_dict))
        action_list.append(action)
        bool_train = list(map(bool,action))
        bool_fake = list(map(bool,1-action))
        co_images = np.concatenate((tr_images[bool_train],fa_images[bool_fake]),axis=0)
        co_labels = np.concatenate((tr_labels[bool_train],fa_labels[bool_fake]),axis=0)

        #####################  update student model with new data  #####################
        update_dict = {phs['batch_images']: co_images,
                phs['batch_labels']: co_labels,
                phs['is_training_ph']: True}
        _, ce, acc = sess.run([vgg_update, vgg_ce, vgg_acc], feed_dict=update_dict)


        if j % 100 == 0:
          print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (i, j, np.mean(ce), acc))
          print(action, np.sum(action))

      #####################  generate terminal reward after 20 epoch of student training  #####################
      valid_acc = 0.0
      y_pred =[]
      y_label = []
      y_att = []
      for k in range(valid_iters):
        va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
        valid_dict = {phs['batch_images']: va_images,
                phs['batch_labels']: va_labels,
                phs['is_training_ph']: False}
        batch_acc, batch_pred = sess.run([vgg_acc,vgg_pred], feed_dict=valid_dict)
        valid_acc += batch_acc
        y_pred += batch_pred.tolist()
        y_label += va_labels.tolist()
        y_att += va_att.tolist()
      valid_acc = valid_acc / float(valid_iters)
      valid_eo = data.cal_eo(y_att, y_label, y_pred)
      log_string('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (i, valid_acc, valid_eo[-1]))
      print('eo: ',valid_eo[0],valid_eo[1])
      print('eo: ',valid_eo[2],valid_eo[3])

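      # terminal reward: large penalty when accuracy is too low, otherwise -log of the final eo gap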
      if valid_acc<0.72:
        value = -5
      else:
        value = -np.log(valid_eo[-1]+1e-4)

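      # checkpoint all variables once accuracy and eo pass the thresholds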
      if valid_acc>0.7 and valid_eo[-1]<0.2:
        all_saver.save(sess,LOG_PATH+'/all.ckpt')

      #####################  update teacher model  #####################
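      # exponential moving-average baseline; the policy update is skipped on the first episode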
      if i == 0:
        base = value
      else:
        base = base * 0.99 + value * 0.01
        reward = value - base
        print('reward: ',reward)

        final_state = np.reshape(state_list, (-1,11))
        final_action = np.reshape(action_list, (-1))
        final_reward = np.repeat(reward, final_state.shape[0])
        learn_dict = {phrl['states_rl']: final_state,
                phrl['actions_rl']: final_action,
                phrl['values_rl']: final_reward,
                phrl['is_training_rl']: True}
        sess.run(rl_update, feed_dict=learn_dict)