示例#1
0
def main():
    """Train the model, then restore its latest checkpoint and evaluate it.

    All options come from ``getFlags()``. Training runs inside a
    ``tf.train.MonitoredTrainingSession`` that stops at ``training_steps``
    (optionally earlier through ``EarlyStoppingHook`` when
    ``FLAGS.early_stopping`` is set) and saves a checkpoint every
    ``checkpoint_step`` steps, keeping only the 3 most recent ones.

    A fresh ``tf.Session`` using the same device configuration then restores
    the latest checkpoint, evaluates the model, prints the final losses and
    writes them to ``<model_dir>/final_losses.csv``: space-delimited, one row
    per metric (loss, UQ, relative L1, relative L2), with the training value
    first and the validation value second.

    Raises:
        ValueError: if no checkpoint exists in the checkpoint directory once
            training has finished.
    """
    # Define model parameters and options in dictionary of flags
    FLAGS = getFlags()
    opts = FLAGS.__dict__

    # Initialize model
    model = Model(FLAGS)

    # Specify number of training steps and where checkpoints live
    training_steps = opts['training_steps']
    checkpoint_dir = os.path.join(opts['model_dir'], opts['checkpoint_dir'])

    # Loss tensor monitored by EarlyStoppingHook, plus its schedule/tolerance
    loss_name = "loss_stopping:0"
    start_step = opts['early_stopping_start']
    stopping_step = opts['early_stopping_step']
    tolerance = opts['early_stopping_tol']

    # Saver that only keeps the 3 most recent checkpoints (TF default is 5)
    scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=3))

    # Device configuration, shared by the training and evaluation sessions
    if opts['use_gpu']:
        config = tf.ConfigProto(device_count={'GPU': 1})
        # Allocate GPU memory on demand instead of reserving it all up front
        config.gpu_options.allow_growth = True
    else:
        config = tf.ConfigProto(device_count={'GPU': 0})

    hooks = [tf.train.StopAtStepHook(last_step=training_steps)]
    if FLAGS.early_stopping:
        hooks.append(
            EarlyStoppingHook(loss_name,
                              tolerance=tolerance,
                              stopping_step=stopping_step,
                              start_step=start_step))

    # Summaries are disabled; checkpoints are saved by step count only
    with tf.train.MonitoredTrainingSession(
            config=config,
            checkpoint_dir=checkpoint_dir,
            hooks=hooks,
            save_summaries_steps=None,
            save_summaries_secs=None,
            save_checkpoint_secs=None,
            save_checkpoint_steps=opts['checkpoint_step'],
            scaffold=scaffold) as sess:
        model.set_session(sess)
        model.train()

    print("\n[ TRAINING COMPLETE ]\n")

    # Create new session for model evaluation (same device settings as
    # training, so use_gpu=False is honoured here too)
    with tf.Session(config=config) as sess:

        # Restore network parameters from latest checkpoint
        saver = tf.train.Saver()
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_checkpoint is None:
            raise ValueError("No checkpoint found in '%s'" % checkpoint_dir)
        saver.restore(sess, latest_checkpoint)

        # Point the model at the restored session and reset its input pipeline
        model.set_session(sess)
        model.initialize_datasets()
        model.reinitialize_handles()

        # Evaluate model
        print("[ Evaluating Model ]")
        t_loss, v_loss, t_uq, v_uq, t_l1, v_l1, t_l2, v_l2 = model.evaluate()

        print("\n\n[ Final Evaluations ]")
        print("Training loss: %.7f  [ UQ = %.7f ]" % (t_loss, t_uq))
        print("Validation loss: %.7f  [ UQ = %.7f ]\n" % (v_loss, v_uq))

        print(" ")
        print("Training relative loss:  %.7f [L1]    %.7f [L2]" % (t_l1, t_l2))
        print("Validation relative loss:  %.7f [L1]    %.7f [L2]\n" %
              (v_l1, v_l2))

        # newline='' lets the csv writer control line endings (per csv docs)
        with open(os.path.join(opts['model_dir'], "final_losses.csv"),
                  "w", newline='') as csvfile:
            csvwriter = csv.writer(csvfile,
                                   delimiter=' ',
                                   quotechar='|',
                                   quoting=csv.QUOTE_MINIMAL)
            csvwriter.writerows([
                [t_loss, v_loss],
                [t_uq, v_uq],
                [t_l1, v_l1],
                [t_l2, v_l2],
            ])
示例#2
0

###############################################################################

timeStart = '2015-06-30 00:00:00'
timeEnd = '2015-07-30 00:00:00'
data = getFrameData(timeStart, timeEnd)

# Release time that opens the current 10-day period (None when there are no rows)
start = data.iloc[0].release_time if not data.empty else None
# Index of the current 10-day period; written per row into the '10_count' column
day10Times = 0

# Flag list used to filter rows: a row whose word maps (via getByMap) to a
# flag that inArray does not find in this list is dropped below
flagArr = flags.getFlags(conn)

idxToDrop = []
tenDayCounts = []
for idx, val in data.iterrows():
    # Set flag: mark rows whose mapped flag is not in flagArr for removal
    row = getByMap(val.word)
    if row and inArray(flagArr, row[3]) == False:
        idxToDrop.append(idx)
    # Set the time stage: open a new period once at least 10 days have passed
    # since the start of the current one (perhaps plotting could do this itself)
    if (val.release_time - start).days >= 10:
        day10Times = day10Times + 1
        start = val.release_time
    tenDayCounts.append(day10Times)

# Assign the whole column at once: DataFrame.set_value was removed in pandas 1.0,
# and positional assignment is also correct when the index has duplicate labels
data['10_count'] = tenDayCounts
data = data.drop(idxToDrop)
示例#3
0
            hires_data = np.load(data_dir + 'hires_data_' + str(i) + '.npy')
            hires_data[hires_out_of_domain] = 0.0
            #hires_soln = np.load(soln_dir + 'hires_solution_' + str(i) + '.npy')

            if RESCALE:
                ## Rescale data and solutions
                hires_scaling = np.max(np.abs(hires_data))
                hires_data = hires_data / hires_scaling
                #hires_soln = hires_soln/hires_scaling

            np.save(data_dir + 'hires_data_' + str(i) + '.npy', hires_data)
            #np.save(soln_dir + 'hires_solution_' + str(i) + '.npy', hires_soln)


if __name__ == '__main__':
    FLAGS = getFlags()

    # Divide tasks into smaller pieces
    # (each preprocess() call receives int(FLAGS.data_count / subdivision);
    # presumably that is the number of samples per chunk — confirm against
    # preprocess_data)
    subdivision = 5

    #hires_mesh = np.load(FLAGS.mesh_dir + 'hires_mesh_' + str(0) + '.npy')


    def preprocess(d):
        """Run ``preprocess_data`` for chunk ``d`` using the paths in FLAGS.

        This ``def`` sits at module level (an ``if`` does not open a new
        scope), so ``multiprocessing`` can pickle it by name.

        NOTE(review): it reads the global ``FLAGS``, which only exists when
        this file runs as ``__main__``. Under the ``spawn`` start method
        (default on Windows and on macOS since Python 3.8), workers re-import
        the module and will find neither ``preprocess`` nor ``FLAGS``.
        Confirm the target platform uses ``fork``.
        """
        preprocess_data(d, int(FLAGS.data_count / subdivision), FLAGS.data_dir,
                        FLAGS.mesh_dir, FLAGS.soln_dir)

    # Create multiprocessing pool
    # NOTE(review): the pool is not used, closed or joined in the lines shown
    # here; confirm the code that follows does so (e.g. close()/join()).
    NumProcesses = FLAGS.cpu_count
    pool = multiprocessing.Pool(processes=NumProcesses)
示例#4
0
# System #################################################################################################################################################

# On Windows, colorama's init() makes ANSI colour codes (used by termcolor) render in the console
if os.name == 'nt':
	from colorama import init
	init()
from termcolor import colored

# Flags ##################################################################################################################################################

# Everything after the match argument is handed to `fg` for option/file parsing.
#   blockFlag - what a result block is:   -f phrase   -p paragraph          Default: everything
#   infoFlag  - additional info:          -l line     -w words              Default: no info
#   readFlag  - how to read input:        -F file     -S stdin after EOF    Default: stdin, each line
fg.argv = sys.argv[2:]
blockFlag = fg.getFlags('-f', '-p')
infoFlag = fg.getFlags('-l', '-w')
readFlag = fg.getFlags('-F', '-S')

filePaths = fg.getFiles()

# Checks for correct syntax
checkSyntax(blockFlag, infoFlag, readFlag)

# Match ##################################################################################################################################################

# First positional argument is the text to search for ('' if absent).
# NOTE: it is concatenated into the regex below unescaped, so regex
# metacharacters in it are interpreted as regex syntax.
match = sys.argv[1] if len(sys.argv) >= 2 else ''

if (blockFlag == '-f'):
	blockMatch = r'[^\.:!?]*' + match + r'.*?(\.\.\.|[\.:!?])'		#matches all phrases with match
elif(blockFlag == '-p'):