Ejemplo n.º 1
0
def _zero_lstm_states(params, batch_size):
    """Return a fresh list of all-zero recurrent state arrays.

    The list holds ``params['nlayer'] * 2`` arrays, each of shape
    ``(batch_size, params['n_hidden'])``. Presumably this is one hidden and one
    cell state per layer (TODO confirm against the model). The arrays use the
    module-level ``dtype``, as the original inline code did.
    """
    return [np.zeros(shape=(batch_size, params['n_hidden']), dtype=dtype)
            for _ in range(params['nlayer'] * 2)]


def train_rnn(params):
    """Train the recurrent model configured by ``params``, validating every epoch.

    Steps:
      1. Load the pose data with ``du.load_pose`` and store the train/test
         lengths in ``params["len_train"]`` / ``params["len_test"]``.
      2. Build a fresh model with ``model_provider.get_model``, or load a
         pretrained one when ``params['run_mode'] == 1``.
      3. For each epoch, run ``params['n_epochs']`` passes of training over
         full mini-batches. The recurrent state returned by the model is
         carried from one batch to the next.
      4. After each training pass, evaluate on the test split.
      5. Write a checkpoint through ``u.write_params``:
         ``"<epoch>_<loss>_best.p"`` when the validation loss improves,
         otherwise alternately ``"0.p"`` / ``"1.p"``.

    Args:
        params: mutable configuration mapping. It is read for data, model and
            training settings, and it is updated in place with
            ``len_train`` / ``len_test``.

    Raises:
        ValueError: if the training split does not fill even one batch of
            ``params['batch_size']`` sequences.
    """
    rng = RandomStreams(seed=1234)
    (X_train, Y_train, S_Train_list, F_list_train, G_list_train,
     X_test, Y_test, S_Test_list, F_list_test, G_list_test) = du.load_pose(params)
    params["len_train"] = Y_train.shape[0] * Y_train.shape[1]
    params["len_test"] = Y_test.shape[0] * Y_test.shape[1]
    u.start_log(params)
    index_train_list, S_Train_list = du.get_seq_indexes(params, S_Train_list)
    index_test_list, S_Test_list = du.get_seq_indexes(params, S_Test_list)
    batch_size = params['batch_size']
    # Use floor division: range() below needs an int. ``/=`` yields a float on
    # Python 3. Any trailing partial batch is dropped, exactly as the integer
    # division did on Python 2.
    n_train_batches = len(index_train_list) // batch_size
    n_test_batches = len(index_test_list) // batch_size

    nb_epochs = params['n_epochs']

    print("Batch size: %i, train batch size: %i, test batch size: %i"
          % (batch_size, n_train_batches, n_test_batches))
    if n_train_batches == 0:
        # Without this check, the per-epoch loss average below divides by zero.
        raise ValueError("Not enough training sequences for a single batch of size %i"
                         % batch_size)

    u.log_write("Model build started", params)
    if params['run_mode'] == 1:
        model = model_provider.get_model_pretrained(params, rng)
        u.log_write("Pretrained loaded: %s" % (params['mfile']), params)
    else:
        model = model_provider.get_model(params, rng)
    u.log_write("Number of parameters: %s" % (model.n_param), params)
    # Zero-initialised; np.ndarray(n) would return uninitialised memory.
    train_errors = np.zeros(nb_epochs)
    u.log_write("Training started", params)
    val_counter = 0
    # "No best yet" sentinel. With a finite constant, no best checkpoint was
    # ever written while the loss stayed above it.
    best_loss = float('inf')
    for epoch_counter in range(nb_epochs):
        # ---- training pass ----
        batch_loss = 0.
        LStateList_t = _zero_lstm_states(params, batch_size)
        LStateList_pre = _zero_lstm_states(params, batch_size)
        state_reset_counter_lst = [0] * batch_size
        is_train = 1
        for minibatch_index in range(n_train_batches):
            state_reset_counter_lst = [s + 1 for s in state_reset_counter_lst]
            (LStateList_b, x, y, state_reset_counter_lst) = du.prepare_lstm_batch(
                index_train_list, minibatch_index, batch_size, S_Train_list,
                LStateList_t, LStateList_pre, F_list_train, params, Y_train,
                X_train, state_reset_counter_lst)
            LStateList_pre = LStateList_b
            result = model.train(x, y, is_train, *LStateList_b)
            batch_loss += result[0]
            # The outputs after the loss are the states carried into the next batch.
            LStateList_t = result[1:]
        if params['shufle_data'] == 1:
            # NOTE(review): only X/Y are shuffled here; F_list_train and the
            # sequence indexes are not. Confirm they are meant to stay aligned.
            X_train, Y_train = du.shuffle_in_unison_inplace(X_train, Y_train)
        # NOTE(review): this stores the summed loss, while the log line below
        # reports the per-batch mean.
        train_errors[epoch_counter] = batch_loss
        batch_loss /= n_train_batches
        s = 'TRAIN--> epoch %i | error %f' % (epoch_counter, batch_loss)
        u.log_write(s, params)

        # ---- validation pass (every epoch) ----
        is_train = 0
        print("Model testing")
        batch_loss3d = []
        LStateList_t = _zero_lstm_states(params, batch_size)
        LStateList_pre = _zero_lstm_states(params, batch_size)
        state_reset_counter_lst = [0] * batch_size
        for minibatch_index in range(n_test_batches):
            state_reset_counter_lst = [s + 1 for s in state_reset_counter_lst]
            (LStateList_b, x, y, state_reset_counter_lst) = du.prepare_lstm_batch(
                index_test_list, minibatch_index, batch_size, S_Test_list,
                LStateList_t, LStateList_pre, F_list_test, params, Y_test,
                X_test, state_reset_counter_lst)
            LStateList_pre = LStateList_b
            result = model.predictions(x, is_train, *LStateList_b)
            pred = result[0]
            LStateList_t = result[1:]
            batch_loss3d.append(u.get_loss(params, y, pred))
        batch_loss3d = np.nanmean(batch_loss3d)
        if batch_loss3d < best_loss:
            best_loss = batch_loss3d
            ext = str(epoch_counter) + "_" + str(batch_loss3d) + "_best.p"
        else:
            # Alternate between two rolling checkpoint files.
            ext = str(val_counter % 2) + ".p"
        u.write_params(model.params, params, ext)

        val_counter += 1
        s = 'VAL--> epoch %i | error %f, %f' % (val_counter, batch_loss3d, n_test_batches)
        u.log_write(s, params)
Ejemplo n.º 2
0
import numpy as np
import tensorflow as tf

from helper import config
from helper import dt_utils as dut
from helper import utils as ut

# from model_runner.rnn_lstm import  Model
from model_runner.lstm.rnn_lstm_2layer import Model

# Configure a run of the "lstmv2_2layer" model, compute the sequence/batch
# bookkeeping for the train and test splits, and start the run log.
params = config.get_params()
params["model"] = "lstmv2_2layer"
params = config.update_params(params)
(F_names_training, S_Train_list, F_names_test,
 S_Test_list) = dut.prepare_training_set_fnames(params)
index_train_list, S_Train_list = dut.get_seq_indexes(params, S_Train_list)
index_test_list, S_Test_list = dut.get_seq_indexes(params, S_Test_list)

batch_size = params['batch_size']
# Use floor division so the batch counts stay integers on Python 3. Plain
# ``/=`` produces floats, which range() rejects. A trailing partial batch is
# dropped, matching the old Python 2 integer division.
n_train_batches = len(index_train_list) // batch_size
n_test_batches = len(index_test_list) // batch_size

params['training_size'] = len(F_names_training) * params['seq_length']
params['test_size'] = len(F_names_test) * params['seq_length']
ut.start_log(params)
ut.log_write("Model training started", params)
# summary_writer = tf.train.SummaryWriter(params["sm"])
Ejemplo n.º 3
0
    print("Model loaded:%s" % params["model"])
    loss = 0.
    total_cnt = 0.
    test_write = True
    for action in lst_action:
        params["action"] = action
        if test_write == True:
            X, Y, F_list, G_list, S_list, R_L_list = dut.get_action_dataset(
                params, X_test, Y_test, F_list_test, G_list_test, S_Test_list,
                R_L_Test_list)
        else:
            X, Y, F_list, G_list, S_list, R_L_list = dut.get_action_dataset(
                params, X_train, Y_train, F_list_train, G_list_train,
                S_Train_list, R_L_Train_list)

        index_list, S_list = dut.get_seq_indexes(params, S_list)

        batch_size = params['batch_size']
        n_batches = len(index_list)
        n_batches /= batch_size

        LStateList_t = ut.get_zero_state(params)
        LStateList_pre = ut.get_zero_state(params)
        state_reset_counter_lst = [0 for i in range(batch_size)]
        total_loss = 0.0
        total_n_count = 0.0
        for minibatch_index in xrange(n_batches):
            state_reset_counter_lst = [s + 1 for s in state_reset_counter_lst]
            # (LStateList_b,x,y,r,f,state_reset_counter_lst)=dut.prepare_lstm_batch(index_list, minibatch_index, batch_size, S_list,LStateList_t,LStateList_pre, params, Y, X,R_L_list,F_list,state_reset_counter_lst)
            (LStateList_b,x,y,r,f,state_reset_counter_lst)=\
            dut.prepare_lstm_batch(index_list, minibatch_index, batch_size,