Example #1
def run_hyperparam(modelStr, numEpochs):
    # Hyperparameter search
    import itertools

    #---------------------------------------------
    param_key = ['lr']
    lr = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
    #---------------------------------------------

    #----set variables to emulate run_model()-----
    runMode = 'train'
    ckptDir = 'hyperparam_tmpckpt'
    if numEpochs == NUM_EPOCHS:
        numEpochs = NUM_EPOCHS_HYPERPARAM
    #---------------------------------------------

    ps = [lr]
    params = list(itertools.product(*ps))

    logDir, logName = createLog('_hyperparam')

    count = 0
    best_val = 0
    best_param = None
    for param in params:
        # build a new computational graph
        tf.reset_default_graph()

        input_sz = [IMG_DIM, IMG_DIM, 1]
        output_sz = [IMG_DIM, IMG_DIM, 3]

        curModel = Unet(input_sz, output_sz, verbose=False)
        curModel.create_model()
        curModel.metrics()

        count += 1
        printSeparator('Running #%d of %d runs...' % (count, len(params)))
        print(getParamStr(param_key, param))

        with tf.Session(config=GPU_CONFIG) as sess:
            # train the network
            dataset_filenames = getDataFileNames(
                TRAIN_DATA, excludeFnames=['.filepart', 'test'])
            for i in range(numEpochs):
                random.shuffle(dataset_filenames)
                train_loss = run_one_epoch(sess,
                                           curModel,
                                           dataset_filenames,
                                           modelStr,
                                           is_training=True)
                print('#%d training loss: %f' % (i, train_loss))

            # run the trained network on the validation set
            dataset_filenames = getDataFileNames(VALIDATION_DATA)
            val_loss = run_one_epoch(sess,
                                     curModel,
                                     dataset_filenames,
                                     modelStr,
                                     is_training=False)

            logToFile(
                logName,
                'train loss: %f, val loss: %f\n' % (train_loss, val_loss))
            # keep the run with the highest validation value (the final log
            # message reports it as an accuracy, although the variable is
            # named val_loss)
            if best_val < val_loss:
                best_val = val_loss
                best_param = param

    logToFile(logName, 'Best validation accuracy: %f' % best_val)
    logToFile(logName, getParamStr(param_key, best_param))
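
The grid in run_hyperparam() is simply the Cartesian product of the per-parameter value lists; with the single key 'lr' the product degenerates to one tuple per learning rate. Below is a minimal sketch of how the same pattern extends to a second hyperparameter. The weight_decay values and the formatting loop are illustrative assumptions, not part of the original; getParamStr() in the original is assumed to do something similar.

import itertools

param_key = ['lr', 'weight_decay']   # hypothetical second hyperparameter
lr = [1e-2, 1e-3]
weight_decay = [0.0, 1e-4]

ps = [lr, weight_decay]
params = list(itertools.product(*ps))
# params == [(0.01, 0.0), (0.01, 0.0001), (0.001, 0.0), (0.001, 0.0001)]

for param in params:
    # pair each value with its key name, roughly what getParamStr() is assumed to do
    print(', '.join('%s=%g' % (k, v) for k, v in zip(param_key, param)))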
Example #2
def run_model(modelStr, runMode, ckptDir, dataDir, sampleDir, overrideCkpt,
              numEpochs):
    print('Running model...')

    # choose the correct dataset
    if dataDir == '':
        if runMode == 'train':
            dataDir = TRAIN_DATA
        elif runMode == 'test':
            dataDir = TEST_DATA
        elif runMode == 'val':
            dataDir = VALIDATION_DATA

    if not os.path.exists(dataDir):
        print(
            'Please specify a valid data directory; "%s" is not a valid directory. Exiting...'
            % dataDir)
        return
    else:
        print('Using dataset %s' % dataDir)

    print("Using checkpoint directory: {0}".format(ckptDir))

    is_training = (runMode == 'train')
    batch_size = 1 if runMode == 'sample' else BATCH_SIZE
    numEpochs = numEpochs if is_training else 1
    overrideCkpt = overrideCkpt if is_training else False

    printSeparator('Initializing %s/reading constants.py' % modelStr)
    input_sz = [IMG_DIM, IMG_DIM, 1]
    if modelStr == 'unet':
        output_sz = [IMG_DIM, IMG_DIM, 3]
        curModel = Unet(input_sz, output_sz)
    elif modelStr == 'zhangnet':
        # integer division keeps the shape an int under Python 3
        output_sz = [(IMG_DIM // 4)**2, 512]
        curModel = ZhangNet(input_sz, output_sz)

    printSeparator('Building ' + modelStr)
    curModel.create_model()
    curModel.metrics()

    print("Running {0} model for {1} epochs.".format(modelStr, numEpochs))

    print("Reading in {0}-set filenames.".format(runMode))

    global_step = tf.Variable(
        0, trainable=False,
        name='global_step')  #tf.contrib.framework.get_or_create_global_step()
    saver = tf.train.Saver(max_to_keep=numEpochs)
    step = 0
    counter = 0

    if runMode == 'sample':
        logDir, logName = None, None
    else:
        logDir, logName = createLog(runMode)

    # get the data file names and check whether @dataDir holds HDF5 files or raw images
    if is_training:
        dataset_filenames = getDataFileNames(
            dataDir, excludeFnames=['.filepart', 'test'])
    else:
        dataset_filenames = getDataFileNames(dataDir,
                                             excludeFnames=['.filepart'])
    if ('.jpg' in dataset_filenames[0]) or ('.png' in dataset_filenames[0]):
        print('The input data is detected to be raw images')
        NUM_SAMPLES = len(dataset_filenames)
        dataset_filenames = [dataset_filenames]

    printSeparator('Starting TF session')
    with tf.Session(config=GPU_CONFIG) as sess:
        print("Inititialized TF Session!")

        # load checkpoint if necessary
        i_stopped, found_ckpt = get_checkpoint(overrideCkpt, ckptDir, sess,
                                               saver)

        # save weights
        # printVars()
        # show_weights(getVar(sess, 'combine_3/kernel:0'))
        # exit(0)

        if runMode != 'sample':
            file_writer = tf.summary.FileWriter(logDir,
                                                graph=sess.graph,
                                                max_queue=10,
                                                flush_secs=30)

        if (not found_ckpt):
            if is_training:
                init_op = tf.global_variables_initializer(
                )  # tf.group(tf.initialize_all_variables(), tf.initialize_local_variables())
                init_op.run()
            else:
                # Exit if there is no checkpoint to test
                print('Valid checkpoint not found under %s, exiting...' %
                      ckptDir)
                return

        if not is_training:
            numEpochs = i_stopped + 1

        # run the network
        for epochCounter in range(i_stopped, numEpochs):
            batch_loss = []
            printSeparator("Running epoch %d" % epochCounter)
            random.shuffle(dataset_filenames)

            for j, data_file in enumerate(dataset_filenames):
                mini_loss = []
                for iter_val in range(DATA_LOAD_PARTITION):
                    # Get data
                    print('Reading data in %s, iter_val: %d...' %
                          (data_file, iter_val))
                    # try:
                    if runMode == 'sample' and PAPER_IMG_NAMES is not None:
                        input_batches, output_batches, imgNames = h52numpy(
                            data_file,
                            batch_sz=batch_size,
                            iter_val=iter_val,
                            mod_output=(modelStr == 'zhangnet'),
                            fileNames=PAPER_IMG_NAMES)
                        print(input_batches.shape)
                    else:
                        input_batches, output_batches, imgNames = h52numpy(
                            data_file,
                            batch_sz=batch_size,
                            iter_val=iter_val,
                            mod_output=(modelStr == 'zhangnet'))
                    # except:
                    #     logToFile(logName, "File reading failed...")
                    #     continue
                    print('Done reading, running the network (%d of %d)' %
                          (j + 1, len(dataset_filenames)))

                    bar = progressbar.ProgressBar(
                        maxval=int(len(input_batches) / batch_size))
                    bar.start()
                    count = 0
                    for dataIndx in range(0, len(imgNames), batch_size):
                        in_batch = input_batches[dataIndx:dataIndx +
                                                 batch_size]
                        if output_batches is None:
                            out_batch = None
                        else:
                            out_batch = output_batches[dataIndx:dataIndx +
                                                       batch_size]
                        imgName = imgNames[dataIndx:dataIndx + batch_size]

                        # look at the images in the dataset (for debug usage)
                        #for kk in range(batch_size):
                        #    numpy2jpg('tmp'+str(kk+dataIndx)+'.jpg', in_batch[kk,:,:,0], overlay=None, meanVal=LINE_MEAN, verbose=False)
                        #    numpy2jpg('KAK'+str(kk+dataIndx)+'.jpg', out_batch[kk,:,:], overlay=None, meanVal=1, verbose=False)
                        #if dataIndx>batch_size*2:
                        #    exit(0)

                        if runMode == 'sample':
                            curModel.sample(
                                sess,
                                in_batch,
                                out_batch,
                                imgName=[os.path.join(sampleDir, imgName[0])])
                            if (NUM_SAMPLES - 1) == step:
                                exit(0)
                        else:
                            summary_loss, loss = curModel.run(
                                sess,
                                in_batch,
                                out_batch,
                                is_training,
                                imgName=os.path.join(sampleDir, imgName[0]))

                            file_writer.add_summary(summary_loss, step)
                            batch_loss.append(loss)
                            mini_loss.append(loss)

                        # Processed another batch
                        step += 1
                        count += 1
                        bar.update(count)
                    bar.finish()

                    input_batches = None
                    output_batches = None

                logToFile(
                    logName, "Epoch %d Dataset #%d loss: %f" %
                    (epochCounter, j, np.mean(mini_loss)))

                counter += 1
                # run the sample images through the net to record the results to TensorBoard (also stored locally)
                if is_training:
                    img_summary = curModel.sample(sess,
                                                  out2board=True,
                                                  imgName=logDir + '/imgs')
                    file_writer.add_summary(img_summary, counter)

                    if counter % SAVE_CKPT_COUNTER == 0:
                        save_checkpoint(
                            ckptDir, sess, saver,
                            i_stopped + int(counter / SAVE_CKPT_COUNTER))

            test_loss = np.mean(batch_loss)
            logToFile(logName, "Epoch %d loss: %f" % (epochCounter, test_loss))

            if is_training:
                # Checkpoint model - every epoch
                #save_checkpoint(ckptDir, sess, saver, epochCounter)
                pass
            elif runMode != 'sample':
                if runMode == 'val':
                    # Update the file for choosing best hyperparameters
                    curFile = open(curModel.config.val_filename, 'a')
                    curFile.write("Validation set loss: {0}".format(test_loss))
                    curFile.write('\n')
                    curFile.close()
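
Neither example shows how these functions are called. The sketch below is a hypothetical command-line driver, not part of the original code: the flag names and defaults are assumptions, and NUM_EPOCHS, run_model() and run_hyperparam() are expected to be defined in the surrounding module.

import argparse

# hypothetical entry point; flag names and defaults are assumptions
parser = argparse.ArgumentParser(description='Run the models defined above')
parser.add_argument('--model', default='unet', choices=['unet', 'zhangnet'])
parser.add_argument('--mode', default='train',
                    choices=['train', 'val', 'test', 'sample'])
parser.add_argument('--ckpt_dir', default='ckpt')
parser.add_argument('--data_dir', default='')
parser.add_argument('--sample_dir', default='samples')
parser.add_argument('--override_ckpt', action='store_true')
parser.add_argument('--epochs', type=int, default=NUM_EPOCHS)
parser.add_argument('--hyperparam', action='store_true',
                    help='run the grid search from Example #1 instead')
args = parser.parse_args()

if args.hyperparam:
    run_hyperparam(args.model, args.epochs)
else:
    run_model(args.model, args.mode, args.ckpt_dir, args.data_dir,
              args.sample_dir, args.override_ckpt, args.epochs)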