# train(): student (VGG) + teacher (RL policy) + per-step critic baseline variant
def train():
    batch_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
    batch_labels = tf.placeholder(tf.int32, [None, ])
    is_training_ph = tf.placeholder(tf.bool)
    lr_ph = tf.placeholder(tf.float32)

    states_rl = tf.placeholder(tf.float32, [None, 2])
    actions_rl = tf.placeholder(tf.int32, [None, ])
    values_rl = tf.placeholder(tf.float32, [None, ])
    is_training_rl = tf.placeholder(tf.bool)
    lr_rl = tf.placeholder(tf.float32)

    states_c = tf.placeholder(tf.float32, [None, 7])
    values_c = tf.placeholder(tf.float32, [None, ])
    is_training_c = tf.placeholder(tf.bool)
    lr_c = tf.placeholder(tf.float32)

    phs = {'batch_images': batch_images,
           'batch_labels': batch_labels,
           'is_training_ph': is_training_ph,
           'lr_ph': lr_ph}

    phrl = {'states_rl': states_rl,
            'actions_rl': actions_rl,
            'values_rl': values_rl,
            'is_training_rl': is_training_rl,
            'lr_rl': lr_rl}

    phc = {'states_c': states_c,
           'values_c': values_c,
           'is_training_c': is_training_c,
           'lr_c': lr_c}

    with tf.Session() as sess:
        vgg_loss, vgg_acc, vgg_ce, vgg_prob, vgg_update, vgg_pred, vgg_vars = vgg_graph(sess, phs)
        rl_loss, rl_prob, rl_update, rl_vars = rl_graph(sess, phrl)
        c_loss, c_value, c_update, c_vars = c_graph(sess, phc)

        vgg_init = tf.variables_initializer(var_list=vgg_vars)
        saver = tf.train.Saver(vgg_vars)
        all_saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        sess.run(init)

        # for epoch in range(4):
        #     for t in range(train_iters):
        #         if t % 50 == 0:
        #             print("pretrain:", t)
        #         tr_images, tr_labels = sess.run([train_images, train_labels])
        #         pre_dict = {phs['batch_images']: tr_images,
        #                     phs['batch_labels']: tr_labels,
        #                     phs['is_training_ph']: True}
        #         sess.run(vgg_update, feed_dict=pre_dict)
        # saver.save(sess, LOG_PATH + '/vgg.ckpt')

        # valid_acc = 0.0
        # y_pred = []
        # y_label = []
        # y_att = []
        # for k in range(valid_iters):
        #     va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
        #     valid_dict = {phs['batch_images']: va_images,
        #                   phs['batch_labels']: va_labels,
        #                   phs['is_training_ph']: False}
        #     batch_acc, batch_pred = sess.run([vgg_acc, vgg_pred], feed_dict=valid_dict)
        #     valid_acc += batch_acc
        #     y_pred += batch_pred.tolist()
        #     y_label += va_labels.tolist()
        #     y_att += va_att.tolist()
        # valid_acc = valid_acc / float(valid_iters)
        # valid_eo = data.cal_eo(y_att, y_label, y_pred)
        # log_string('====pretrain: valid_acc=%.4f, valid_eo=%.4f' % (valid_acc, valid_eo[-1]))
        # print(valid_eo)

        # a fixed validation batch used for the per-step reward signal
        va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])

        for i in range(N_EPISODE):
            sess.run(vgg_init)
            # saver.restore(sess, LOG_PATH + '/vgg.ckpt')
            train_loss = []
            for j in range(train_iters * 20):
                tr_images, tr_labels, tr_att = sess.run([train_images, train_labels, train_att])
                fa_images, fa_labels, fa_att = sess.run([fake_images, fake_labels, fake_att])

                # generate state info from the current student model and batch
                train_dict = {phs['batch_images']: tr_images,
                              phs['batch_labels']: tr_labels,
                              phs['is_training_ph']: False}
                ce, acc, prob, pred = sess.run([vgg_ce, vgg_acc, vgg_prob, vgg_pred], feed_dict=train_dict)
                ce = np.clip(ce, 0, 10) / 10.0
                train_loss.append(np.mean(ce))

                model_stat = list(data.cal_eo(tr_att, tr_labels, pred))  # shape [5,]
                model_stat.append(np.mean(ce))
                model_stat.append(j / (train_iters * 20))
                # model_stat.append(np.mean(train_loss))
                c_state = np.array(model_stat)[np.newaxis, :]
                # model_stat = np.tile(model_stat, (BATCH_SIZE, 1))
                state = np.concatenate((tr_labels[:, np.newaxis], tr_att[:, np.newaxis]), axis=1)

                # sample a keep-real / use-fake action for each example in the batch
                rl_dict = {phrl['states_rl']: state, phrl['is_training_rl']: False}
                action = choose_action(sess.run(rl_prob, feed_dict=rl_dict))

                # critic baseline for the current student state
                c_dict = {phc['states_c']: c_state, phc['is_training_c']: False}
                base = sess.run(c_value, feed_dict=c_dict)

                bool_train = list(map(bool, action))
                bool_fake = list(map(bool, 1 - action))
                co_images = np.concatenate((tr_images[bool_train], fa_images[bool_fake]), axis=0)
                co_labels = np.concatenate((tr_labels[bool_train], fa_labels[bool_fake]), axis=0)

                # update the student model on the re-composed batch
                update_dict = {phs['batch_images']: co_images,
                               phs['batch_labels']: co_labels,
                               phs['is_training_ph']: True}
                _, ce, acc = sess.run([vgg_update, vgg_ce, vgg_acc], feed_dict=update_dict)

                # evaluate on the held-out batch to compute the per-step reward
                valid_dict = {phs['batch_images']: va_images,
                              phs['batch_labels']: va_labels,
                              phs['is_training_ph']: False}
                valid_acc, y_pred = sess.run([vgg_acc, vgg_pred], feed_dict=valid_dict)
                valid_eo = data.cal_eo(va_att, va_labels, y_pred)

                if valid_eo[-1] <= 0.05:
                    value = -2
                else:
                    value = -np.log(valid_eo[-1])
                reward = value - base[0]

                # update the critic towards the observed value
                c_dict = {phc['states_c']: c_state,
                          phc['values_c']: [value],
                          phc['is_training_c']: True}
                _, cri_loss = sess.run([c_update, c_loss], feed_dict=c_dict)

                # update the teacher policy with the advantage as the return
                final_reward = np.repeat(reward, BATCH_SIZE)
                learn_dict = {phrl['states_rl']: state,
                              phrl['actions_rl']: action,
                              phrl['values_rl']: final_reward,
                              phrl['is_training_rl']: True}
                sess.run(rl_update, feed_dict=learn_dict)

                if j % 10 == 0:
                    log_string('====epoch_%d====iter_%d: student_loss=%.4f, train_acc=%.4f' % (i, j, np.mean(ce), acc))
                    log_string('===============: critic_loss=%.4f, reward=%.4f, valid_acc=%.4f, valid_eo=%.4f' % (cri_loss, reward, valid_acc, valid_eo[-1]))
                    print('eo: ', valid_eo[0], valid_eo[1])
                    print('eo: ', valid_eo[2], valid_eo[3])
                    print(action, np.sum(action))

        all_saver.save(sess, LOG_PATH + '/all.ckpt')
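# ----------------------------------------------------------------------------
# `choose_action` is called above but not defined in this file. A minimal
# sketch is given below, assuming `rl_prob` evaluates to per-example action
# probabilities of shape [BATCH_SIZE, 2] (a softmax over {use fake, keep real});
# the project's actual implementation may differ.
def choose_action(prob):
    """Sample one binary action per example from its action distribution."""
    prob = np.asarray(prob)
    actions = np.array([np.random.choice(prob.shape[1], p=p) for p in prob])
    return actions.astype(np.int32)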
# train(): student + teacher variant that restores both checkpoints each episode
def train():
    batch_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
    batch_labels = tf.placeholder(tf.int32, [None, ])
    is_training_ph = tf.placeholder(tf.bool)
    lr_ph = tf.placeholder(tf.float32)

    states_rl = tf.placeholder(tf.float32, [None, 11])
    actions_rl = tf.placeholder(tf.int32, [None, ])
    values_rl = tf.placeholder(tf.float32, [None, ])
    is_training_rl = tf.placeholder(tf.bool)
    lr_rl = tf.placeholder(tf.float32)

    phs = {'batch_images': batch_images,
           'batch_labels': batch_labels,
           'is_training_ph': is_training_ph,
           'lr_ph': lr_ph}

    phrl = {'states_rl': states_rl,
            'actions_rl': actions_rl,
            'values_rl': values_rl,
            'is_training_rl': is_training_rl,
            'lr_rl': lr_rl}

    with tf.Session() as sess:
        vgg_loss, vgg_acc, vgg_ce, vgg_prob, vgg_update, vgg_pred, vgg_vars = vgg_graph(sess, phs)
        rl_loss, rl_prob, rl_update, rl_vars = rl_graph(sess, phrl)
        vgg_init = tf.variables_initializer(var_list=vgg_vars)
        saver = tf.train.Saver(vgg_vars)
        all_saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        sess.run(init)

        for i in range(N_EPISODE):
            # sess.run(vgg_init)
            all_saver.restore(sess, LOG_PATH + '/all.ckpt')
            saver.restore(sess, LOG_PATH + '/vgg.ckpt')
            # state_list = []
            # action_list = []
            # reward_list = []
            for j in range(train_iters * 20):
                tr_images, tr_labels, tr_att = sess.run([train_images, train_labels, train_att])
                fa_images, fa_labels, fa_att = sess.run([fake_images, fake_labels, fake_att])

                # generate state info from the student model and the current batch
                train_dict = {phs['batch_images']: tr_images,
                              phs['batch_labels']: tr_labels,
                              phs['is_training_ph']: False}
                ce, acc, prob, pred = sess.run([vgg_ce, vgg_acc, vgg_prob, vgg_pred], feed_dict=train_dict)
                ce = np.clip(ce, 0, 10) / 10.0

                model_stat = list(data.cal_eo(tr_att, tr_labels, pred))
                model_stat.append(np.mean(ce))
                model_stat = np.tile(model_stat, (BATCH_SIZE, 1))
                state = np.concatenate((tr_labels[:, np.newaxis], tr_att[:, np.newaxis],
                                        prob, ce[:, np.newaxis], model_stat), axis=1)

                # sample a keep-real / use-fake action for each example
                rl_dict = {phrl['states_rl']: state, phrl['is_training_rl']: False}
                action = choose_action(sess.run(rl_prob, feed_dict=rl_dict))

                bool_train = list(map(bool, action))
                bool_fake = list(map(bool, 1 - action))
                co_images = np.concatenate((tr_images[bool_train], fa_images[bool_fake]), axis=0)
                co_labels = np.concatenate((tr_labels[bool_train], fa_labels[bool_fake]), axis=0)

                # update the student model on the re-composed batch
                update_dict = {phs['batch_images']: co_images,
                               phs['batch_labels']: co_labels,
                               phs['is_training_ph']: True}
                _, ce, acc = sess.run([vgg_update, vgg_ce, vgg_acc], feed_dict=update_dict)

                if j % 100 == 0:
                    print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (i, j, np.mean(ce), acc))
                    print(action, np.sum(action))

            # evaluate on the full validation set after each episode
            valid_acc = 0.0
            y_pred = []
            y_label = []
            y_att = []
            for k in range(valid_iters):
                va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
                valid_dict = {phs['batch_images']: va_images,
                              phs['batch_labels']: va_labels,
                              phs['is_training_ph']: False}
                batch_acc, batch_pred = sess.run([vgg_acc, vgg_pred], feed_dict=valid_dict)
                valid_acc += batch_acc
                y_pred += batch_pred.tolist()
                y_label += va_labels.tolist()
                y_att += va_att.tolist()
            valid_acc = valid_acc / float(valid_iters)
            valid_eo = data.cal_eo(y_att, y_label, y_pred)
            log_string('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (i, valid_acc, valid_eo[-1]))
            print('eo: ', valid_eo[0], valid_eo[1])
            print('eo: ', valid_eo[2], valid_eo[3])
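# ----------------------------------------------------------------------------
# `data.cal_eo` lives in the project's data module and is not shown here. From
# its usage above (a 5-element result whose entries are printed pairwise and
# whose last entry is the fairness gap driving the reward), a plausible sketch
# is the following, assuming a binary label and a binary protected attribute;
# the real definition in `data.py` may differ.
def cal_eo(att, labels, pred):
    """Equalized-odds statistics: accuracy of each (attribute, label) group
    plus the largest gap between the two attribute groups."""
    att = np.asarray(att)
    labels = np.asarray(labels)
    pred = np.asarray(pred)
    rates = []
    for a in (0, 1):
        for y in (0, 1):
            mask = (att == a) & (labels == y)
            # group-conditional accuracy; 0 if the group is absent in this batch
            rates.append(np.mean(pred[mask] == y) if mask.any() else 0.0)
    # gap between attribute groups, maximised over the two label values
    gap = max(abs(rates[0] - rates[2]), abs(rates[1] - rates[3]))
    return rates[0], rates[1], rates[2], rates[3], gap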
                               lr_ph: lr,
                               is_training: True})
            if j % 50 == 0:
                print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (i, j, loss, acc))

        valid_acc = 0.0
        y_pred = []
        y_label = []
        y_att = []
        for k in range(valid_iters):
            va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
            batch_acc, batch_pred = sess.run([acc_op, Y_pred],
                                             {batch_images: va_images,
                                              batch_labels: va_labels,
                                              is_training: False})
            valid_acc += batch_acc
            y_pred += batch_pred.tolist()
            y_label += va_labels.tolist()
            y_att += va_att.tolist()
        valid_acc = valid_acc / float(valid_iters)
        valid_eo = data.cal_eo(y_att, y_label, y_pred)
        print('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (i, valid_acc, valid_eo[-1]))
        print('eo: ', valid_eo[0], valid_eo[1])
        print('eo: ', valid_eo[2], valid_eo[3])
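# ----------------------------------------------------------------------------
# The batch tensors consumed throughout these training loops (`train_images`,
# `train_labels`, `train_att`, and the `fake_*` / `valid_*` counterparts) are
# built by the data pipeline, which is not part of this file. A minimal sketch
# with the TF1 tf.data API is given below; `data.load_split` is a hypothetical
# helper standing in for however the project actually loads numpy arrays.
def make_batch_tensors(images, labels, att, batch_size, shuffle=True):
    dataset = tf.data.Dataset.from_tensor_slices((images, labels, att))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.repeat().batch(batch_size, drop_remainder=True)
    # TF1-style one-shot iterator: each sess.run() yields the next batch
    return dataset.make_one_shot_iterator().get_next()

# e.g. train_images, train_labels, train_att = make_batch_tensors(
#          *data.load_split('train'), batch_size=BATCH_SIZE)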
# train(): student pre-training followed by teacher training with a terminal reward
def train():
    batch_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
    batch_labels = tf.placeholder(tf.int32, [None, ])
    is_training_ph = tf.placeholder(tf.bool)
    lr_ph = tf.placeholder(tf.float32)

    states_rl = tf.placeholder(tf.float32, [None, 11])
    actions_rl = tf.placeholder(tf.int32, [None, ])
    values_rl = tf.placeholder(tf.float32, [None, ])
    is_training_rl = tf.placeholder(tf.bool)
    lr_rl = tf.placeholder(tf.float32)

    phs = {'batch_images': batch_images,
           'batch_labels': batch_labels,
           'is_training_ph': is_training_ph,
           'lr_ph': lr_ph}

    phrl = {'states_rl': states_rl,
            'actions_rl': actions_rl,
            'values_rl': values_rl,
            'is_training_rl': is_training_rl,
            'lr_rl': lr_rl}

    with tf.Session() as sess:
        # tf.reset_default_graph()
        vgg_loss, vgg_acc, vgg_ce, vgg_prob, vgg_update, vgg_pred, vgg_vars = vgg_graph(sess, phs)
        rl_loss, rl_prob, rl_update, rl_vars = rl_graph(sess, phrl)
        vgg_init = tf.variables_initializer(var_list=vgg_vars)
        saver = tf.train.Saver(vgg_vars)
        all_saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        sess.run(init)

        ##################### pre-train student model #####################
        for epoch in range(4):
            for t in range(train_iters):
                if t % 50 == 0:
                    print("pretrain:", t)
                tr_images, tr_labels = sess.run([train_images, train_labels])
                pre_dict = {phs['batch_images']: tr_images,
                            phs['batch_labels']: tr_labels,
                            phs['is_training_ph']: True}
                sess.run(vgg_update, feed_dict=pre_dict)
        saver.save(sess, LOG_PATH + '/vgg.ckpt')

        valid_acc = 0.0
        y_pred = []
        y_label = []
        y_att = []
        for k in range(valid_iters):
            va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
            valid_dict = {phs['batch_images']: va_images,
                          phs['batch_labels']: va_labels,
                          phs['is_training_ph']: False}
            batch_acc, batch_pred = sess.run([vgg_acc, vgg_pred], feed_dict=valid_dict)
            valid_acc += batch_acc
            y_pred += batch_pred.tolist()
            y_label += va_labels.tolist()
            y_att += va_att.tolist()
        valid_acc = valid_acc / float(valid_iters)
        valid_eo = data.cal_eo(y_att, y_label, y_pred)
        log_string('====pretrain: valid_acc=%.4f, valid_eo=%.4f' % (valid_acc, valid_eo[-1]))
        print(valid_eo)

        ##################### train teacher model #####################
        for i in range(N_EPISODE):
            # sess.run(vgg_init)
            saver.restore(sess, LOG_PATH + '/vgg.ckpt')
            state_list = []
            action_list = []
            reward_list = []
            for j in range(train_iters * 20):
                tr_images, tr_labels, tr_att = sess.run([train_images, train_labels, train_att])
                fa_images, fa_labels, fa_att = sess.run([fake_images, fake_labels, fake_att])
                # va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])

                ##################### generate state info from student model & data #####################
                train_dict = {phs['batch_images']: tr_images,
                              phs['batch_labels']: tr_labels,
                              phs['is_training_ph']: False}
                ce, acc, prob, pred = sess.run([vgg_ce, vgg_acc, vgg_prob, vgg_pred], feed_dict=train_dict)
                ce = np.clip(ce, 0, 10) / 10.0
                model_stat = list(data.cal_eo(tr_att, tr_labels, pred))
                model_stat.append(np.mean(ce))
                model_stat = np.tile(model_stat, (BATCH_SIZE, 1))
                state = np.concatenate((tr_labels[:, np.newaxis], tr_att[:, np.newaxis],
                                        prob, ce[:, np.newaxis], model_stat), axis=1)
                state_list.append(state)

                ##################### sample action for this batch #####################
                rl_dict = {phrl['states_rl']: state, phrl['is_training_rl']: False}
                action = choose_action(sess.run(rl_prob, feed_dict=rl_dict))
                action_list.append(action)

                bool_train = list(map(bool, action))
                bool_fake = list(map(bool, 1 - action))
                co_images = np.concatenate((tr_images[bool_train], fa_images[bool_fake]), axis=0)
                co_labels = np.concatenate((tr_labels[bool_train], fa_labels[bool_fake]), axis=0)

                ##################### update student model with new data #####################
                update_dict = {phs['batch_images']: co_images,
                               phs['batch_labels']: co_labels,
                               phs['is_training_ph']: True}
                _, ce, acc = sess.run([vgg_update, vgg_ce, vgg_acc], feed_dict=update_dict)

                if j % 100 == 0:
                    print('====epoch_%d====iter_%d: loss=%.4f, train_acc=%.4f' % (i, j, np.mean(ce), acc))
                    print(action, np.sum(action))

            ##################### generate terminal reward after 20 epochs of student training #####################
            valid_acc = 0.0
            y_pred = []
            y_label = []
            y_att = []
            for k in range(valid_iters):
                va_images, va_labels, va_att = sess.run([valid_images, valid_labels, valid_att])
                valid_dict = {phs['batch_images']: va_images,
                              phs['batch_labels']: va_labels,
                              phs['is_training_ph']: False}
                batch_acc, batch_pred = sess.run([vgg_acc, vgg_pred], feed_dict=valid_dict)
                valid_acc += batch_acc
                y_pred += batch_pred.tolist()
                y_label += va_labels.tolist()
                y_att += va_att.tolist()
            valid_acc = valid_acc / float(valid_iters)
            valid_eo = data.cal_eo(y_att, y_label, y_pred)
            log_string('====epoch_%d: valid_acc=%.4f, valid_eo=%.4f' % (i, valid_acc, valid_eo[-1]))
            print('eo: ', valid_eo[0], valid_eo[1])
            print('eo: ', valid_eo[2], valid_eo[3])

            if valid_acc < 0.72:
                value = -5
            else:
                value = -np.log(valid_eo[-1] + 1e-4)
            if valid_acc > 0.7 and valid_eo[-1] < 0.2:
                all_saver.save(sess, LOG_PATH + '/all.ckpt')

            ##################### update teacher model #####################
            if i == 0:
                base = value
            else:
                base = base * 0.99 + value * 0.01
            reward = value - base
            print('reward: ', reward)

            final_state = np.reshape(state_list, (-1, 11))
            final_action = np.reshape(action_list, (-1))
            final_reward = np.repeat(reward, final_state.shape[0])

            learn_dict = {phrl['states_rl']: final_state,
                          phrl['actions_rl']: final_action,
                          phrl['values_rl']: final_reward,
                          phrl['is_training_rl']: True}
            sess.run(rl_update, feed_dict=learn_dict)
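# ----------------------------------------------------------------------------
# `rl_graph` builds the teacher network and its update op; its internals are
# defined elsewhere. The feed above (states, sampled actions, and a scalar
# advantage repeated per example) matches a REINFORCE-style objective. A
# minimal sketch is given below; the layer sizes, learning rate, and scope
# name are assumptions, not the project's actual definition.
def rl_graph_sketch(phrl, hidden=64, n_actions=2):
    with tf.variable_scope('teacher_sketch'):
        h = tf.layers.dense(phrl['states_rl'], hidden, activation=tf.nn.relu)
        logits = tf.layers.dense(h, n_actions)
        rl_prob = tf.nn.softmax(logits)
        # -log pi(a|s) of the sampled actions
        neg_log_pi = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=phrl['actions_rl'], logits=logits)
        # policy gradient: maximise advantage-weighted log-likelihood
        rl_loss = tf.reduce_mean(neg_log_pi * phrl['values_rl'])
        rl_update = tf.train.AdamOptimizer(1e-4).minimize(rl_loss)
        rl_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='teacher_sketch')
    return rl_loss, rl_prob, rl_update, rl_vars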