def train(tracker, params):
    """Train ``tracker`` and evaluate it on the test split after every epoch.

    Besides its arguments, the function reads these module-level names:
    ``num_epochs``, ``gpu_config``, ``n_train_batches``, ``n_test_batches``,
    ``index_train_list``, ``index_test_list``, the train/test data holders
    (``X_*``, ``Y_*``, ``S_*_list``, ``R_L_*_list``, ``F_list_*``),
    ``test_data`` and ``full_median_result_lst``.

    Args:
        tracker: model object; this function uses its ``lr``, ``cost``,
            ``states`` and ``train_op`` attributes.
        params: configuration mapping. Keys read here: ``n_output``,
            ``batch_size``, ``model``, ``noise_std``, ``lr``, ``shufle_data``,
            ``reset_state`` and ``noise_schedule``.

    Returns:
        ``(median_result_lst, mean_result_lst)``: the per-epoch median and
        mean results reported by ``test_data``.

    Side effects:
        Appends every epoch's median result to the module-level
        ``full_median_result_lst`` and writes noise-schedule events through
        ``ut.log_write``.
    """
    batch_size = params["batch_size"]
    # One n_output x n_output identity matrix per batch element; it is passed
    # to th.get_feed unchanged on every step.
    I = np.asarray(
        [np.diag([1.0] * params['n_output']) for _ in range(batch_size)],
        dtype=np.float32)
    decay_rate = 0.95
    deca_start = 10  # first epoch at which the learning rate is decayed

    with tf.Session(config=gpu_config) as sess:
        tf.global_variables_initializer().run()
        # NOTE: checkpoint writing is disabled in this variant; the Saver is
        # still constructed exactly as before so it can be re-enabled easily.
        saver = tf.train.Saver()
        print('Training model:' + params["model"])

        noise_std = params['noise_std']
        new_noise_std = 0.0  # stays 0 until the noise schedule first fires
        median_result_lst = []
        mean_result_lst = []

        for e in range(num_epochs):
            # Constant learning rate for the first `deca_start` epochs,
            # exponential decay (in the epoch index) afterwards.
            learning_rate = params['lr']
            if e >= deca_start:
                learning_rate = learning_rate * (decay_rate ** e)
            sess.run(tf.assign(tracker.lr, learning_rate))

            state_reset_counter_lst = [0] * batch_size
            dic_state = ut.get_state_list(params)
            index_train_list_s = index_train_list
            if params["shufle_data"] == 1 and params['reset_state'] == 1:
                index_train_list_s = ut.shufle_data(index_train_list)

            for minibatch_index in xrange(n_train_batches):
                is_test = 0
                state_reset_counter_lst = [
                    cnt + 1 for cnt in state_reset_counter_lst
                ]
                (dic_state, x, y, r, _, _, state_reset_counter_lst, _) = \
                    th.prepare_batch(
                        is_test, index_train_list_s, minibatch_index,
                        batch_size, S_Train_list, dic_state, params,
                        Y_train, X_train, R_L_Train_list, F_list_train,
                        state_reset_counter_lst)

                if noise_std > 0.0:
                    # Global update counter across epochs.
                    u_cnt = e * n_train_batches + minibatch_index
                    if u_cnt in params['noise_schedule']:
                        # NOTE(review): with integer operands under Python 2
                        # this ratio is floor division -- confirm intended.
                        new_noise_std = noise_std * (
                            u_cnt / (params['noise_schedule'][0]))
                        s = 'NOISE --> u_cnt %i | error %f' % (
                            u_cnt, new_noise_std)
                        ut.log_write(s, params)
                    if new_noise_std > 0.0:
                        noise = np.random.normal(0.0, new_noise_std, x.shape)
                        x = noise + x

                feed = th.get_feed(tracker, params, r, x, y, I, dic_state,
                                   is_training=1)
                _, states, _ = sess.run(
                    [tracker.cost, tracker.states, tracker.train_op], feed)
                # Carry the recurrent state over to the next minibatch.
                for key in states.keys():
                    dic_state[key] = states[key]

            pre_test = "TEST_Data"
            (_, median_result, mean_result, _, _, _) = test_data(
                sess, params, X_test, Y_test, index_test_list, S_Test_list,
                R_L_Test_list, F_list_test, e, pre_test, n_test_batches)

            # Announce a new best (first component of the median result).
            # NOTE(review): the `> 1` guard skips the comparison while only
            # one earlier result exists -- kept as in the original.
            if (len(full_median_result_lst) > 1 and
                    median_result[0] <
                    np.min(full_median_result_lst, axis=0)[0]):
                print('Writing estimations....')

            full_median_result_lst.append(median_result)
            median_result_lst.append(median_result)
            mean_result_lst.append(mean_result)

    return median_result_lst, mean_result_lst
def train(Model,params):
    """Train ``Model``, evaluate it each epoch and checkpoint the weights.

    Besides its arguments, the function reads these module-level names:
    ``n_train_batches``, ``n_test_batches``, ``index_train_list``,
    ``index_test_list``, the train/test data holders (``X_*``, ``Y_*``,
    ``S_*_list``, ``R_L_*_list``, ``F_list_*``) and ``test_data``.

    Args:
        Model: model object; this function uses its ``lr``, ``cost``,
            ``states`` and ``train_op`` attributes.
        params: configuration mapping. Keys read here: ``n_output``,
            ``batch_size``, ``model``, ``noise_std``, ``lr``, ``shufle_data``,
            ``reset_state``, ``noise_schedule``, ``cp_file`` and ``rn_id``.

    Returns:
        None. Progress is reported through ``print`` and ``ut.log_write``;
        checkpoints are written below ``params["cp_file"]``.
    """
    batch_size = params["batch_size"]
    # One n_output x n_output identity matrix per batch element; it is passed
    # to th.get_feed unchanged on every step.
    I = np.asarray(
        [np.diag([1.0] * params['n_output']) for _ in range(batch_size)],
        dtype=np.float32)
    num_epochs = 100000
    decay_rate = 0.9
    show_every = 100      # print the batch loss every `show_every` batches
    deca_start = 3        # first epoch at which the learning rate is decayed
    pre_best_loss = 10000  # best test loss seen so far

    with tf.Session() as sess:  # config=gpu_config
        tf.global_variables_initializer().run()
        # The checkpointing code at the end of each epoch calls saver.save();
        # the Saver construction had been commented out, which left `saver`
        # undefined inside this function.
        saver = tf.train.Saver()
        # NOTE: restoring a checkpoint / pre-trained weights for the
        # "kfl_QRf" model was commented out in the original and is not done.
        print('Training model:' + params["model"])

        noise_std = params['noise_std']
        new_noise_std = 0.0  # stays 0 until the noise schedule first fires

        for e in range(num_epochs):
            # Constant learning rate for the first `deca_start` epochs,
            # exponential decay (in the epoch index) afterwards.
            learning_rate = params['lr']
            if e >= deca_start:
                learning_rate = learning_rate * (decay_rate ** e)
            sess.run(tf.assign(Model.lr, learning_rate))

            total_train_loss = 0
            state_reset_counter_lst = [0] * batch_size
            dic_state = ut.get_state_list(params)
            index_train_list_s = index_train_list
            if params["shufle_data"] == 1 and params['reset_state'] == 1:
                index_train_list_s = ut.shufle_data(index_train_list)

            for minibatch_index in xrange(n_train_batches):
                is_test = 0
                state_reset_counter_lst = [
                    cnt + 1 for cnt in state_reset_counter_lst
                ]
                (dic_state, x, y, r, _, _, state_reset_counter_lst, _) = \
                    th.prepare_batch(
                        is_test, index_train_list_s, minibatch_index,
                        batch_size, S_Train_list, dic_state, params,
                        Y_train, X_train, R_L_Train_list, F_list_train,
                        state_reset_counter_lst)

                if noise_std > 0.0:
                    # Global update counter across epochs.
                    u_cnt = e * n_train_batches + minibatch_index
                    if u_cnt in params['noise_schedule']:
                        if u_cnt == params['noise_schedule'][0]:
                            new_noise_std = noise_std
                        else:
                            # NOTE(review): with integer operands under
                            # Python 2 this ratio is floor division --
                            # confirm intended.
                            new_noise_std = noise_std * (
                                u_cnt / (params['noise_schedule'][1]))
                        s = 'NOISE --> u_cnt %i | error %f' % (
                            u_cnt, new_noise_std)
                        ut.log_write(s, params)
                    if new_noise_std > 0.0:
                        noise = np.random.normal(0.0, new_noise_std, x.shape)
                        x = noise + x

                feed = th.get_feed(Model, params, r, x, y, I, dic_state,
                                   is_training=1)
                train_loss, states, _ = sess.run(
                    [Model.cost, Model.states, Model.train_op], feed)
                # Carry the recurrent state over to the next minibatch.
                for key in states.keys():
                    dic_state[key] = states[key]
                total_train_loss += train_loss
                if minibatch_index % show_every == 0:
                    print("Training batch loss: (%i / %i / %i) %f" % (
                        e, minibatch_index, n_train_batches, train_loss))

            total_train_loss = total_train_loss / n_train_batches
            s = 'TRAIN --> epoch %i | error %f' % (e, total_train_loss)
            ut.log_write(s, params)

            # Evaluation on the training split; its return value was never
            # used (it was immediately overwritten), so it is not bound.
            pre_test = "TRAINING_Data"
            test_data(sess, params, X_train, Y_train, index_train_list,
                      S_Train_list, R_L_Train_list, F_list_train, e,
                      pre_test, n_train_batches)
            pre_test = "TEST_Data"
            total_loss = test_data(
                sess, params, X_test, Y_test, index_test_list, S_Test_list,
                R_L_Test_list, F_list_test, e, pre_test, n_test_batches)

            # Checkpoint: always on a new best test loss, otherwise every
            # third epoch.
            base_cp_path = params["cp_file"] + "/"
            lss_str = '%.5f' % total_loss
            name_stem = (lss_str + "_" + str(e) + "_" +
                         str(params["rn_id"]) + params["model"])
            saved_path = ""  # stays empty when nothing is written this epoch
            if pre_best_loss > total_loss:
                pre_best_loss = total_loss
                saved_path = saver.save(
                    sess, base_cp_path + name_stem + "_best_model.ckpt")
            elif e % 3 == 0:
                saved_path = saver.save(
                    sess, base_cp_path + name_stem + "_model.ckpt")
            # Log only when a checkpoint was actually written; the previous
            # `False != ""` test was always true.
            if saved_path:
                s = 'MODEL_Saved --> epoch %i | error %f path %s' % (
                    e, total_loss, saved_path)
                ut.log_write(s, params)
def train():
    """Run the noise-KLSTM training loop with per-epoch test evaluation.

    Takes no arguments; the model (``tracker``), the configuration
    (``params``), the session config (``gpu_config``), the dataset holders
    and batch counts, and ``test_data`` are all read from module scope.
    Returns None; progress is reported through ``print`` and
    ``ut.log_write``.
    """
    batch_size = params["batch_size"]
    num_epochs = 1000
    decay_rate = 0.5
    show_every = 100   # print the batch loss every `show_every` batches
    deca_start = 2     # first epoch at which the learning rate is decayed

    with tf.Session(config=gpu_config) as sess:
        tf.global_variables_initializer().run()
        print('Training Noise KLSTM')
        base_noise_std = params['noise_std']
        active_noise_std = 0.0  # stays 0 until the noise schedule first fires

        for epoch in range(num_epochs):
            if epoch >= deca_start:
                # NOTE(review): with Python 2 integer semantics the exponent
                # is floor((epoch - deca_start) / 2) -- confirm that no
                # `from __future__ import division` is active in this file.
                epoch_lr = params['lr'] * (
                    decay_rate ** ((epoch - deca_start) / 2))
            else:
                epoch_lr = params['lr']
            sess.run(tf.assign(tracker.lr, epoch_lr))

            epoch_loss = 0
            # Fresh zero states for both recurrent parts at every epoch.
            f_state_t = ut.get_zero_state(params)
            f_state_pre = ut.get_zero_state(params)
            k_state_t = ut.get_zero_state(params, t='K')
            k_state_pre = ut.get_zero_state(params, t='K')
            reset_counters = [0] * batch_size

            if params["shufle_data"] == 1 and params['reset_state'] == 1:
                train_order = ut.shufle_data(index_train_list)
            else:
                train_order = index_train_list

            for minibatch_index in xrange(n_train_batches):
                reset_counters = [count + 1 for count in reset_counters]
                (f_state_pre, k_state_pre, _, x, y, r, _,
                 reset_counters) = dut.prepare_kfl_QRFf_batch(
                    train_order, minibatch_index, batch_size, S_Train_list,
                    f_state_t, f_state_pre, k_state_t, k_state_pre,
                    None, None, params, Y_train, X_train,
                    R_L_Train_list, F_list_train, reset_counters)

                if base_noise_std > 0.0:
                    # Global update counter across epochs.
                    update_idx = epoch * n_train_batches + minibatch_index
                    if update_idx in params['noise_schedule']:
                        active_noise_std = base_noise_std * (
                            update_idx / (params['noise_schedule'][0]))
                        ut.log_write(
                            'NOISE --> u_cnt %i | error %f' % (
                                update_idx, active_noise_std),
                            params)
                    if active_noise_std > 0.0:
                        x = np.random.normal(
                            0.0, active_noise_std, x.shape) + x

                # x is fed as the measurement, y as the target.
                feed = {
                    tracker._z: x,
                    tracker.target_data: y,
                    tracker.repeat_data: r,
                    tracker.initial_state: f_state_pre,
                    tracker.initial_state_K: k_state_pre,
                    tracker.output_keep_prob: params['rnn_keep_prob'],
                }
                batch_loss, f_state_out, k_state_out, _ = sess.run(
                    [tracker.cost, tracker.final_state_F,
                     tracker.final_state_Q, tracker.train_op],
                    feed)

                # Flatten each returned state item into consecutive
                # [c, h, c, h, ...] entries for the next prepare call.
                f_state_t = [
                    part for item in f_state_out
                    for part in (item.c, item.h)
                ]
                k_state_t = [
                    part for item in k_state_out
                    for part in (item.c, item.h)
                ]

                epoch_loss += batch_loss
                if minibatch_index % show_every == 0:
                    print("Training batch loss: (%i / %i / %i) %f" % (
                        epoch, minibatch_index, n_train_batches, batch_loss))

            epoch_loss = epoch_loss / n_train_batches
            ut.log_write(
                'TRAIN --> epoch %i | error %f' % (epoch, epoch_loss), params)

            pre_test = "TEST_Data"
            total_loss = test_data(
                sess, X_test, Y_test, index_test_list, S_Test_list,
                R_L_Test_list, F_list_test, epoch, pre_test, n_test_batches)