def predict_on_fullimage(test_set_x, params):
    """Predict with the pretrained model over ``test_set_x`` in whole batches.

    The input is padded with copies of its last row so that its length is a
    multiple of ``params["batch_size"]``; the predictions produced for the
    padding rows are dropped again before returning.

    For each batch the "F" and "S" inputs are loaded through
    ``du.asyn_load_batch_images``. When ``params["model_type"] == 4`` the
    model is fed the single array ``S - F``; otherwise it is fed the pair
    ``[F, S]``.

    Args:
        test_set_x: 2-D array-like of test rows (it is stacked with
            ``np.vstack`` when padding is required).
        params: configuration dict; this function reads ``"batch_size"`` and
            ``"model_type"`` and forwards the whole dict to the model
            provider and the loaders.

    Returns:
        The per-batch predictions joined along axis 0 with padding rows
        removed, or an empty list when ``test_set_x`` is empty.
    """
    model = model_provider.get_model_pretrained(params)
    batch_size = params["batch_size"]

    # Pad to a whole number of batches by repeating the last row.
    ash = len(test_set_x) % batch_size
    if ash > 0:
        padding = np.tile(test_set_x[-1], (batch_size - ash, 1))
        test_set_x = np.vstack((test_set_x, padding))
    # Floor division keeps the batch count an int under true division too.
    n_test_batches = len(test_set_x) // batch_size

    print("Number of parameters: %s" % (model.count_params()))
    print("Prediction on test images")

    # Original author's note: "we are not using this for image".
    patch_loc = u.get_patch_loc(params)
    map_index = 0

    batch_preds = []
    for i in range(n_test_batches):
        Fx = test_set_x[i * batch_size: (i + 1) * batch_size]
        argu = [(params, "F", Fx, patch_loc, map_index),
                (params, "S", Fx, patch_loc, map_index)]
        results = du.asyn_load_batch_images(argu)
        data_Fx = results[0]
        data_Sx = results[1]
        if params["model_type"] == 4:
            res = model.predict(data_Sx - data_Fx)
        else:
            res = model.predict([data_Fx, data_Sx])
        batch_preds.append(res)

    # Join once at the end instead of re-copying the running result per batch.
    if not batch_preds:
        y_pred = []
    elif len(batch_preds) == 1:
        y_pred = batch_preds[0]
    else:
        y_pred = np.concatenate(batch_preds)

    if ash > 0:
        # Drop the predictions that belong to the padding rows.
        y_pred = y_pred[0:-(batch_size - ash)]
    return y_pred
def predict_on_multi_input(test_set_x, params):
    """Predict each test row as the mean over several patch predictions.

    The input is padded with copies of its last row to a whole number of
    batches. For every batch, ``n_patch * n_repeat`` predictions are made,
    each with a fresh ``u.get_patch_loc(params)``, and averaged. The
    individual patch predictions of every sample are collected in a dict
    and handed to ``pd.plot_patch``.

    Args:
        test_set_x: 2-D array-like of test rows (it is stacked with
            ``np.vstack`` when padding is required).
        params: configuration dict; this function reads ``"batch_size"``,
            ``"n_patch"``, ``"n_repeat"`` and ``"n_output"`` and forwards the
            whole dict to the model provider and the loaders.

    Returns:
        Array of averaged predictions with the padding rows removed, or an
        empty list when ``test_set_x`` is empty.
    """
    model = model_provider.get_model_pretrained(params)
    batch_size = params["batch_size"]

    # Pad to a whole number of batches by repeating the last row.
    ash = len(test_set_x) % batch_size
    if ash > 0:
        padding = np.tile(test_set_x[-1], (batch_size - ash, 1))
        test_set_x = np.vstack((test_set_x, padding))
    # Floor division keeps the batch count an int under true division too.
    n_test_batches = len(test_set_x) // batch_size

    print("Prediction on test images")
    n_patch = params["n_patch"]
    n_repeat = params["n_repeat"]
    n_output = params['n_output']
    n = n_patch * n_repeat

    # Sample index -> (n, n_output) matrix of that sample's patch predictions.
    # Padding rows are included, as nothing here trims them before plotting.
    pred_mat = {}
    batch_preds = []
    for i in range(n_test_batches):
        # Previously looked up in range(batch_size * n_repeat), which holds
        # the identity mapping but raised IndexError for larger batch counts.
        map_index = i % n_repeat
        Fx = test_set_x[i * batch_size: (i + 1) * batch_size]

        # Key by global sample index. The earlier `i + k` key made
        # neighbouring batches overwrite each other whenever batch_size > 1.
        offset = i * batch_size
        for k in range(batch_size):
            pred_mat[offset + k] = np.zeros((n, n_output))

        pred = np.zeros((batch_size, n_output))
        for patch_index in range(n):
            patch_loc = u.get_patch_loc(params)
            argu = [(params, "F", Fx, patch_loc, map_index),
                    (params, "S", Fx, patch_loc, map_index)]
            results = du.asyn_load_batch_images(argu)
            data_Fx = results[0]
            data_Sx = results[1]
            prd = model.predict([data_Fx, data_Sx])
            for j, row in enumerate(prd):
                pred_mat[offset + j][patch_index] = row
            pred = np.add(pred, prd)
        pred /= n
        batch_preds.append(pred)

    pd.plot_patch(pred_mat)

    # Join once at the end instead of re-copying the running result per batch.
    y_pred = np.concatenate(batch_preds) if batch_preds else []
    if ash > 0:
        # Drop the predictions that belong to the padding rows.
        y_pred = y_pred[0:-(batch_size - ash)]
    return y_pred
def _mean_patch_loss(step, params, Fx, data_y, map_index, n_patch):
    """Return the mean loss of ``step`` over ``n_patch`` freshly drawn patches.

    ``step`` is a bound ``model.train_on_batch`` or ``model.test_on_batch``.
    When ``params["model_type"] == 4`` it receives the single array
    ``S - F``; otherwise it receives the pair ``[F, S]``. If ``step``
    returns a list, its first element is taken as the loss.
    """
    total = 0.0
    for _ in range(n_patch):
        patch_loc = utils.get_patch_loc(params)
        argu = [(params, "F", Fx, patch_loc, map_index),
                (params, "S", Fx, patch_loc, map_index)]
        results = dt_utils.asyn_load_batch_images(argu)
        data_Fx = results[0]
        data_Sx = results[1]
        if params["model_type"] == 4:
            loss = step(data_Sx - data_Fx, data_y)
        else:
            loss = step([data_Fx, data_Sx], data_y)
        total += loss[0] if isinstance(loss, list) else loss
    return total / n_patch


def train_model(params):
    """Train the model described by ``params``, validating after each epoch.

    Every epoch walks the training set ``n_repeat`` times (original author's
    note: the epoch is repeated instead of extracting many batches, since
    each pass draws different patches). Each minibatch is trained on
    ``n_patch`` patches and the mean loss is logged. Weights are saved:

    * after every completed pass inside an epoch (``_m_`` files),
    * at the end of every epoch (``_e_`` files, cycling over 10 slots),
    * whenever the validation loss improves (``__best_`` files).

    ``params["run_mode"] == 1`` stops after a single training batch, a
    single validation batch and a single epoch. ``params["validate"] == 0``
    skips validation entirely.

    Args:
        params: configuration dict. Keys read here: ``"rn_id"``,
            ``"im_type"``, ``"batch_size"``, ``"n_epochs"``, ``"run_mode"``,
            ``"n_patch"``, ``"n_repeat"``, ``"model_file"``, ``"model"``,
            ``"model_type"`` and ``"validate"``. The whole dict is also
            forwarded to the loaders, the model provider and the logger.

    Returns:
        None.
    """
    rn_id = params["rn_id"]
    im_type = params["im_type"]
    batch_size = params["batch_size"]
    n_epochs = params["n_epochs"]

    datasets = data_loader.load_data(params)
    utils.start_log(datasets, params)
    X_train, y_train, _ = datasets[0]
    X_val, y_val, _ = datasets[1]

    # Number of whole minibatches; floor division keeps these ints.
    n_train_batches = len(X_train) // batch_size
    n_valid_batches = len(X_val) // batch_size

    y_val_mean = np.mean(y_val)
    y_val_abs_mean = np.mean(np.abs(y_val))

    utils.log_write("Model build started", params)
    model = model_provider.get_model(params)
    utils.log_write("Number of parameters: %s" % (model.count_params()), params)
    run_mode = params["run_mode"]
    utils.log_write("Model build ended", params)
    utils.log_write("Training started", params)

    best_validation_loss = np.inf
    epoch_counter = 0
    n_patch = params["n_patch"]
    n_repeat = params["n_repeat"]
    n_train_steps = n_train_batches * n_repeat
    n_valid_steps = n_valid_batches * n_repeat

    def weights_path(suffix):
        # Common prefix of every checkpoint file written below.
        return params["model_file"] + params["model"] + "_" + im_type + suffix

    while epoch_counter < n_epochs:
        epoch_counter += 1

        print("Training model...")
        map_list = list(range(n_train_steps))
        random.shuffle(map_list)
        for index in range(n_train_steps):
            minibatch_index = index % n_train_batches
            map_index = map_list[index] % n_repeat

            # A full pass over the data just finished: checkpoint and
            # reshuffle so the next repeat sees different batches.
            if index > 0 and minibatch_index == 0:
                model.save_weights(
                    weights_path("_m_" + str(index % 5) + ".hdf5"),
                    overwrite=True)
                X_train, y_train = dt_utils.shuffle_in_unison_inplace(
                    X_train, y_train)

            # Count all steps of earlier epochs, including the repeats, so
            # the progress counter keeps increasing across epochs.
            iteration = (epoch_counter - 1) * n_train_steps + index
            if iteration % 100 == 0:
                print("training @ iter =  %d" % iteration)

            lo = minibatch_index * batch_size
            hi = (minibatch_index + 1) * batch_size
            batch_loss = _mean_patch_loss(
                model.train_on_batch, params,
                X_train[lo:hi], y_train[lo:hi], map_index, n_patch)

            s = 'TRAIN--> epoch %i | batch_index %i/%i | error %f' % (
                epoch_counter, index + 1, n_train_steps, batch_loss)
            utils.log_write(s, params)
            if run_mode == 1:
                break

        # Original author's note: shuffle once more to be sure.
        X_train, y_train = dt_utils.shuffle_in_unison_inplace(X_train, y_train)
        model.save_weights(
            weights_path("_e_" + str(rn_id) + "_"
                         + str(epoch_counter % 10) + ".hdf5"),
            overwrite=True)

        if params['validate'] == 0:
            print("Validation skipped...")
            if run_mode == 1:
                break
            continue

        print("Validating model...")
        this_validation_loss = 0.0
        n_evaluated = 0
        map_list = list(range(n_valid_steps))
        random.shuffle(map_list)
        for index in range(n_valid_steps):
            i = index % n_valid_batches
            map_index = map_list[index] % n_repeat
            lo = i * batch_size
            hi = (i + 1) * batch_size
            this_validation_loss += _mean_patch_loss(
                model.test_on_batch, params,
                X_val[lo:hi], y_val[lo:hi], map_index, n_patch)
            n_evaluated += 1
            if run_mode == 1:
                break

        if n_evaluated == 0:
            # Fewer validation rows than one batch: there is nothing to
            # average, so neither report a loss nor save a "best" model.
            utils.log_write(
                'VAL--> epoch %i | skipped, no full validation batch'
                % epoch_counter, params)
        else:
            # Average over the batches actually evaluated; run_mode 1 stops
            # after the first one.
            this_validation_loss /= n_evaluated
            s = 'VAL--> epoch %i | error %f | data mean/abs %f/%f' % (
                epoch_counter, this_validation_loss,
                y_val_mean, y_val_abs_mean)
            utils.log_write(s, params)
            if this_validation_loss < best_validation_loss:
                best_validation_loss = this_validation_loss
                # The double underscore is kept so file names stay unchanged.
                model.save_weights(
                    weights_path("_" + "_best_" + str(rn_id) + "_"
                                 + str(epoch_counter) + ".hdf5"),
                    overwrite=True)

        if run_mode == 1:
            break

    utils.log_write("Training ended", params)