def run_hyperparam(modelStr, numEpochs):
    # Hyperparameter search
    import itertools

    #---------------------------------------------
    param_key = ['lr']
    lr = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
    #---------------------------------------------

    #----set variables to emulate run_model()-----
    runMode = 'train'
    ckptDir = 'hyperparam_tmpckpt'
    # use the shorter epoch budget when the caller left the default
    if numEpochs == NUM_EPOCHS:
        numEpochs = NUM_EPOCHS_HYPERPARAM
    #---------------------------------------------

    ps = [lr]
    params = list(itertools.product(*ps))
    logDir, logName = createLog('_hyperparam')

    count = 0
    best_val = float('inf')  # lowest validation loss seen so far
    best_param = None
    for param in params:
        # build a fresh computational graph for every hyperparameter setting
        tf.reset_default_graph()
        input_sz = [IMG_DIM, IMG_DIM, 1]
        output_sz = [IMG_DIM, IMG_DIM, 3]
        curModel = Unet(input_sz, output_sz, verbose=False)
        curModel.create_model()
        curModel.metrics()

        count += 1
        printSeparator('Running #%d of %d runs...' % (count, len(params)))
        # NOTE: `param` is only logged here; the model above is built with
        # its default settings
        print(getParamStr(param_key, param))
        with tf.Session(config=GPU_CONFIG) as sess:
            # train the network
            dataset_filenames = getDataFileNames(
                TRAIN_DATA, excludeFnames=['.filepart', 'test'])
            for i in range(numEpochs):
                random.shuffle(dataset_filenames)
                train_loss = run_one_epoch(sess, curModel, dataset_filenames,
                                           modelStr, is_training=True)
                print('#%d training loss: %f' % (i, train_loss))

            # run the trained network on the validation set
            dataset_filenames = getDataFileNames(VALIDATION_DATA)
            val_loss = run_one_epoch(sess, curModel, dataset_filenames,
                                     modelStr, is_training=False)
            logToFile(logName, 'train loss: %f, val loss: %f\n' %
                      (train_loss, val_loss))
            # keep the hyperparameter setting with the lowest validation loss
            if val_loss < best_val:
                best_val = val_loss
                best_param = param

    logToFile(logName, 'Best validation loss: %f' % best_val)
    logToFile(logName, getParamStr(param_key, best_param))
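
# A minimal sketch (not part of the original training code) of how the grid
# above generalizes to several hyperparameters: each list passed to
# itertools.product() adds one axis to the grid, and param_key must name the
# axes in the same order so getParamStr() can label each run. The weight-decay
# values below are illustrative assumptions, not values from constants.py.
def _hyperparam_grid_example():
    import itertools
    param_key = ['lr', 'wd']
    lr = [1e-2, 1e-3, 1e-4]  # learning rates to sweep
    wd = [0.0, 1e-4]         # hypothetical weight-decay values
    params = list(itertools.product(lr, wd))
    # 3 * 2 = 6 runs: (1e-2, 0.0), (1e-2, 1e-4), (1e-3, 0.0), ...
    return param_key, params
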
def run_model(modelStr, runMode, ckptDir, dataDir, sampleDir, overrideCkpt,
              numEpochs):
    print('Running model...')

    # choose the dataset that matches the run mode
    if dataDir == '':
        if runMode == 'train':
            dataDir = TRAIN_DATA
        elif runMode == 'test':
            dataDir = TEST_DATA
        elif runMode == 'val':
            dataDir = VALIDATION_DATA
    if not os.path.exists(dataDir):
        print('Please specify a valid data directory; "%s" is not a valid '
              'directory. Exiting...' % dataDir)
        return
    else:
        print('Using dataset %s' % dataDir)
    print("Using checkpoint directory: {0}".format(ckptDir))

    is_training = (runMode == 'train')
    batch_size = 1 if runMode == 'sample' else BATCH_SIZE
    numEpochs = numEpochs if is_training else 1
    overrideCkpt = overrideCkpt if is_training else False

    printSeparator('Initializing %s/reading constants.py' % modelStr)
    input_sz = [IMG_DIM, IMG_DIM, 1]
    if modelStr == 'unet':
        output_sz = [IMG_DIM, IMG_DIM, 3]
        curModel = Unet(input_sz, output_sz)
    if modelStr == 'zhangnet':
        # integer division keeps the output size an int under Python 3
        output_sz = [(IMG_DIM // 4)**2, 512]
        curModel = ZhangNet(input_sz, output_sz)
    printSeparator('Building ' + modelStr)
    curModel.create_model()
    curModel.metrics()

    print("Running {0} model for {1} epochs.".format(modelStr, numEpochs))
    print("Reading in {0}-set filenames.".format(runMode))

    global_step = tf.Variable(0, trainable=False, name='global_step')
    # (alternative: tf.contrib.framework.get_or_create_global_step())
    saver = tf.train.Saver(max_to_keep=numEpochs)
    step = 0
    counter = 0
    if runMode == 'sample':
        logDir, logName = None, None
    else:
        logDir, logName = createLog(runMode)

    # get the data file names and check whether @dataDir holds raw images
    # instead of HDF5 files
    if is_training:
        dataset_filenames = getDataFileNames(
            dataDir, excludeFnames=['.filepart', 'test'])
    else:
        dataset_filenames = getDataFileNames(dataDir,
                                             excludeFnames=['.filepart'])
    if ('.jpg' in dataset_filenames[0]) or ('.png' in dataset_filenames[0]):
        print('The input data is detected to be raw images')
        # NOTE: shadows the NUM_SAMPLES constant; sample mode relies on this
        # branch having run before the (NUM_SAMPLES - 1) == step check below
        NUM_SAMPLES = len(dataset_filenames)
        dataset_filenames = [dataset_filenames]

    printSeparator('Starting TF session')
    with tf.Session(config=GPU_CONFIG) as sess:
        print("Initialized TF session!")
        # load a checkpoint if one exists
        i_stopped, found_ckpt = get_checkpoint(overrideCkpt, ckptDir, sess,
                                               saver)
        # save weights (disabled debug helpers):
        # printVars()
        # show_weights(getVar(sess, 'combine_3/kernel:0'))
        # exit(0)
        if runMode != 'sample':
            file_writer = tf.summary.FileWriter(logDir,
                                                graph=sess.graph,
                                                max_queue=10,
                                                flush_secs=30)
        if not found_ckpt:
            if is_training:
                init_op = tf.global_variables_initializer()
                init_op.run()
            else:
                # exit if there is no checkpoint to evaluate
                print('Valid checkpoint not found under %s, exiting...' %
                      ckptDir)
                return
        if not is_training:
            numEpochs = i_stopped + 1

        # run the network
        for epochCounter in range(i_stopped, numEpochs):
            batch_loss = []
            printSeparator("Running epoch %d" % epochCounter)
            random.shuffle(dataset_filenames)
            for j, data_file in enumerate(dataset_filenames):
                mini_loss = []
                for iter_val in range(DATA_LOAD_PARTITION):
                    # get the next chunk of data
                    print('Reading data in %s, iter_val: %d...' %
                          (data_file, iter_val))
                    # try:  (disabled: skip files that fail to read)
                    if runMode == 'sample' and PAPER_IMG_NAMES is not None:
                        input_batches, output_batches, imgNames = h52numpy(
                            data_file, batch_sz=batch_size, iter_val=iter_val,
                            mod_output=(modelStr == 'zhangnet'),
                            fileNames=PAPER_IMG_NAMES)
                        print(input_batches.shape)
                    else:
                        input_batches, output_batches, imgNames = h52numpy(
                            data_file, batch_sz=batch_size, iter_val=iter_val,
                            mod_output=(modelStr == 'zhangnet'))
                    # except:
                    #     logToFile(logName, "File reading failed...")
                    #     continue
                    print('Done reading, running the network (%d of %d)' %
                          (j + 1, len(dataset_filenames)))

                    bar = progressbar.ProgressBar(
                        maxval=int(len(input_batches) / batch_size))
                    bar.start()
                    count = 0
                    for dataIndx in range(0, len(imgNames), batch_size):
                        in_batch = input_batches[dataIndx:dataIndx + batch_size]
                        if output_batches is None:
                            out_batch = None
                        else:
                            out_batch = output_batches[dataIndx:dataIndx +
                                                       batch_size]
                        imgName = imgNames[dataIndx:dataIndx + batch_size]

                        # look at the images in the dataset (for debug usage):
                        # for kk in range(batch_size):
                        #     numpy2jpg('tmp' + str(kk + dataIndx) + '.jpg',
                        #               in_batch[kk, :, :, 0], overlay=None,
                        #               meanVal=LINE_MEAN, verbose=False)
                        #     numpy2jpg('KAK' + str(kk + dataIndx) + '.jpg',
                        #               out_batch[kk, :, :], overlay=None,
                        #               meanVal=1, verbose=False)
                        # if dataIndx > batch_size * 2:
                        #     exit(0)

                        if runMode == 'sample':
                            curModel.sample(
                                sess, in_batch, out_batch,
                                imgName=[os.path.join(sampleDir, imgName[0])])
                            if (NUM_SAMPLES - 1) == step:
                                exit(0)
                        else:
                            summary_loss, loss = curModel.run(
                                sess, in_batch, out_batch, is_training,
                                imgName=os.path.join(sampleDir, imgName[0]))
                            file_writer.add_summary(summary_loss, step)
                            batch_loss.append(loss)
                            mini_loss.append(loss)

                        # processed another batch
                        step += 1
                        count += 1
                        bar.update(count)
                    bar.finish()
                    input_batches = None
                    output_batches = None

                logToFile(logName, "Epoch %d Dataset #%d loss: %f" %
                          (epochCounter, j, np.mean(mini_loss)))
                counter += 1

                # run the sample images through the net to record the results
                # to TensorBoard (also stored locally)
                if is_training:
                    img_summary = curModel.sample(sess, out2board=True,
                                                  imgName=logDir + '/imgs')
                    file_writer.add_summary(img_summary, counter)
                    if counter % SAVE_CKPT_COUNTER == 0:
                        save_checkpoint(
                            ckptDir, sess, saver,
                            i_stopped + int(counter / SAVE_CKPT_COUNTER))

            test_loss = np.mean(batch_loss)
            logToFile(logName, "Epoch %d loss: %f" % (epochCounter, test_loss))
            if is_training:
                # checkpoint the model every epoch (disabled; checkpoints are
                # saved every SAVE_CKPT_COUNTER datasets above)
                # save_checkpoint(ckptDir, sess, saver, epochCounter)
                pass
            elif runMode != 'sample':
                if runMode == 'val':
                    # update the file used for choosing the best hyperparameters
                    curFile = open(curModel.config.val_filename, 'a')
                    curFile.write("Validation set loss: {0}".format(test_loss))
                    curFile.write('\n')
                    curFile.close()
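
# A minimal command-line entry point, sketched here to show how run_model()
# and run_hyperparam() could be wired together; the flag names and defaults
# below are assumptions, not part of the original interface.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(
        description='Train, evaluate, or sample the networks')
    parser.add_argument('--model', default='unet',
                        choices=['unet', 'zhangnet'])
    parser.add_argument('--mode', default='train',
                        choices=['train', 'test', 'val', 'sample',
                                 'hyperparam'])
    parser.add_argument('--ckpt-dir', default='ckpt')
    parser.add_argument('--data-dir', default='')
    parser.add_argument('--sample-dir', default='samples')
    parser.add_argument('--override-ckpt', action='store_true')
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS)
    args = parser.parse_args()

    if args.mode == 'hyperparam':
        run_hyperparam(args.model, args.epochs)
    else:
        run_model(args.model, args.mode, args.ckpt_dir, args.data_dir,
                  args.sample_dir, args.override_ckpt, args.epochs)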