Example #1
def getInput(start, args):
    ds = args.ds.lower()
    print('{} using data source: {}'.format(strftime("%H:%M:%S"), args.ds))
    if ds == 'db':
        return input_fn.getInputs(
            start, TIME_SHIFT, feat_cols, MAX_STEP, args.parallel,
            args.prefetch, args.db_pool, args.db_host, args.db_port, args.db_pwd, args.vset or VSET)
    elif ds == 'bigquery':
        return input_bq.getInputs(start, TIME_SHIFT, feat_cols, MAX_STEP, TEST_BATCH_SIZE,
                                  vset=args.vset or VSET)
    elif ds == 'file':
        return input_file2.getInputs(
            args.dir, start, args.prefetch, args.vset or VSET, args.vol_size)
    return None
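A minimal, hypothetical call sketch for getInput. The input_fn/input_bq/input_file2 modules and the constants TIME_SHIFT, feat_cols, MAX_STEP, TEST_BATCH_SIZE and VSET are assumed to come from the surrounding project; the flag names below simply mirror the attributes read from args above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ds', default='db')        # data source: db, bigquery or file
parser.add_argument('--dir')                     # input folder, used when --ds=file
parser.add_argument('--parallel', type=int, default=2)
parser.add_argument('--prefetch', type=int, default=2)
parser.add_argument('--db_pool', type=int, default=2)
parser.add_argument('--db_host')
parser.add_argument('--db_port', type=int)
parser.add_argument('--db_pwd')
parser.add_argument('--vset', type=int)
parser.add_argument('--vol_size', type=int)
args = parser.parse_args()

d = getInput(1, args)   # dict of dataset tensors, or None for an unknown source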
Example #2
File: test10.py Project: carusyte/tflab
def run(args):
    global bst_saver, bst_score, bst_file, bst_ckpt
    tf.logging.set_verbosity(tf.logging.INFO)
    keep_prob = tf.placeholder(tf.float32, [], name="keep_prob")
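    # keep_prob is fed per run: KEEP_PROB while training, 1.0 during validation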
    with tf.Session() as sess:
        model = drnn.DRnnRegressorV4(dim=DIM,
                                     keep_prob=keep_prob,
                                     layer_width=LAYER_WIDTH,
                                     learning_rate=LEARNING_RATE)
        model_name = model.getName()
        print('{} using model: {}'.format(strftime("%H:%M:%S"), model_name))
        f = __file__
        testn = f[f.rfind('/') + 1:f.rindex('.py')]
        base_dir = '{}/{}_{}'.format(LOG_DIR, testn, model_name)
        training_dir = os.path.join(base_dir, 'training')
        summary_dir = os.path.join(training_dir, 'summary')
        checkpoint_file = os.path.join(training_dir, 'model.ckpt')
        bst_ckpt = os.path.join(base_dir, 'best', 'model.ckpt')
        saver = None
        summary_str = None
        d = None
        restored = False
        bno, epoch, bst_score = 0, 0, sys.maxsize
        ckpt = tf.train.get_checkpoint_state(training_dir)
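        # ckpt points at the latest checkpoint, if any; the block below resumes
        # from it when the training folder exists, or wipes the stale folder
        # and starts from scratch otherwise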

        if tf.gfile.Exists(training_dir):
            print("{} training folder exists".format(strftime("%H:%M:%S")))
            bst_file = open(os.path.join(base_dir, 'best_score'), 'r+')
            bst_file.seek(0)
            if ckpt and ckpt.model_checkpoint_path:
                print("{} found model checkpoint path: {}".format(
                    strftime("%H:%M:%S"), ckpt.model_checkpoint_path))
                # Extract the batch number from the checkpoint filename (model.ckpt-<step>)
                bno = int(
                    os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
                print('{} resuming from last training, bno = {}'.format(
                    strftime("%H:%M:%S"), bno))
                d = input_fn.getInputs(bno + 1, TIME_SHIFT, feat_cols,
                                       MAX_STEP, args.parallel, args.prefetch,
                                       args.db_pool, args.db_host,
                                       args.db_port, args.db_pwd, args.vset
                                       or VSET)
                model.setNodes(d['uuids'], d['features'], d['labels'],
                               d['seqlens'])
                saver = tf.train.Saver(name="reg_saver")
                saver.restore(sess, ckpt.model_checkpoint_path)
                restored = True
                try:
                    bst_score = float(bst_file.readline().rstrip())
                    print('{} previous best score: {}'.format(
                        strftime("%H:%M:%S"), bst_score))
                except Exception:
                    print(
                        '{} not able to read best score. best_score file is invalid.'
                        .format(strftime("%H:%M:%S")))
                bst_file.seek(0)
                rbno = sess.run(tf.train.get_global_step())
                print(
                    '{} check restored global step: {}, previous batch no: {}'.
                    format(strftime("%H:%M:%S"), rbno, bno))
                if bno != rbno:
                    print(
                        '{} bno({}) inconsistent with global step({}). reset global step with bno.'
                        .format(strftime("%H:%M:%S"), bno, rbno))
                    gstep = tf.train.get_global_step(sess.graph)
                    sess.run(tf.assign(gstep, bno))
            else:
                print(
                    "{} model checkpoint path not found, cleaning training folder"
                    .format(strftime("%H:%M:%S")))
                tf.gfile.DeleteRecursively(training_dir)

        if not restored:
            d = input_fn.getInputs(bno + 1, TIME_SHIFT, feat_cols, MAX_STEP,
                                   args.parallel, args.prefetch, args.db_pool,
                                   args.db_host, args.db_port, args.db_pwd,
                                   args.vset or VSET)
            model.setNodes(d['uuids'], d['features'], d['labels'],
                           d['seqlens'])
            saver = tf.train.Saver(name="reg_saver")
            sess.run(tf.global_variables_initializer())
            tf.gfile.MakeDirs(training_dir)
            bst_file = open(os.path.join(base_dir, 'best_score'), 'w+')
        bst_saver = tf.train.Saver(name="bst_saver")
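        # second Saver, kept global so other routines (e.g. validate) can
        # snapshot the best-scoring model to bst_ckpt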

        train_handle, test_handle = sess.run(
            [d['train_iter'].string_handle(), d['test_iter'].string_handle()])
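        # string handles let the feedable d['handle'] placeholder switch the
        # same graph between the training and test tf.data iterators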

        summary, train_writer, test_writer = collect_summary(
            sess, model, summary_dir)
        test_summary_str = None
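        # main loop: validate every TEST_INTERVAL batches, train one batch per
        # iteration, and checkpoint every SAVE_INTERVAL batches until the
        # training dataset is exhausted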
        while True:
            # derive the current epoch from the batch counter
            epoch = bno // TEST_INTERVAL
            if restored or bno % TEST_INTERVAL == 0:
                test_summary_str = validate(sess, model, summary, {
                    d['handle']: test_handle,
                    keep_prob: 1
                }, bno, epoch)
                restored = False
            try:
                print('{} training batch {}'.format(strftime("%H:%M:%S"),
                                                    bno + 1))
                summary_str, worst = sess.run(
                    [summary, model.worst, model.optimize], {
                        d['handle']: train_handle,
                        keep_prob: KEEP_PROB
                    })[:-1]
            except tf.errors.OutOfRangeError:
                print("End of Dataset.")
                break
            bno = bno + 1
            _, max_diff, predict, actual = worst[0], worst[1], worst[2], worst[3]
            print('{} bno {} max_diff {:3.4f} predict {} actual {}'.format(
                strftime("%H:%M:%S"), bno, max_diff, predict, actual))
            train_writer.add_summary(summary_str, bno)
            test_writer.add_summary(test_summary_str, bno)
            train_writer.flush()
            test_writer.flush()
            if bno == 1 or bno % SAVE_INTERVAL == 0:
                saver.save(sess,
                           checkpoint_file,
                           global_step=tf.train.get_global_step())
        # test last epoch
        test_summary_str = validate(sess, model, summary, {
            d['handle']: test_handle,
            keep_prob: 1
        }, bno, epoch)
        train_writer.add_summary(summary_str, bno)
        test_writer.add_summary(test_summary_str, bno)
        train_writer.flush()
        test_writer.flush()
        saver.save(sess,
                   checkpoint_file,
                   global_step=tf.train.get_global_step())
        # training finished, move to 'trained' folder
        trained = os.path.join(base_dir, 'trained')
        tf.gfile.MakeDirs(trained)
        tmp_dir = os.path.join(base_dir, strftime("%Y%m%d_%H%M%S"))
        os.rename(training_dir, tmp_dir)
        shutil.move(tmp_dir, trained)
        print('{} model is saved to {}'.format(strftime("%H:%M:%S"), trained))
        bst_file.close()
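Both test10.py and test4.py call a collect_summary helper that is not shown here. A minimal sketch of what such a helper might look like, assuming it only attaches a scalar summary to the model's cost, merges all summaries, and returns separate FileWriters for the training and test curves:

import os
import tensorflow as tf

def collect_summary(sess, model, summary_dir):
    # attach a scalar summary to the cost and merge everything defined so far
    tf.summary.scalar('cost', model.cost)
    summary = tf.summary.merge_all()
    # one writer per curve so TensorBoard overlays train vs. test
    train_writer = tf.summary.FileWriter(os.path.join(summary_dir, 'train'), sess.graph)
    test_writer = tf.summary.FileWriter(os.path.join(summary_dir, 'test'))
    return summary, train_writer, test_writer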
Example #3
File: test4.py Project: carusyte/tflab
def run():
    global args
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Session() as sess:
        model = base_model.SRnnRegressorV3(dim=DIM,
                                           layer_width=LAYER_WIDTH,
                                           learning_rate=LEARNING_RATE)
        model_name = model.getName()
        print('{} using model: {}'.format(strftime("%H:%M:%S"), model_name))
        f = __file__
        testn = f[f.rfind('/') + 1:f.rindex('.py')]
        base_dir = '{}/{}_{}'.format(LOG_DIR, testn, model_name)
        training_dir = os.path.join(base_dir, 'training')
        summary_dir = os.path.join(training_dir, 'summary')
        checkpoint_file = os.path.join(training_dir, 'model.ckpt')
        saver = None

        summary_str = None
        d = None
        bno = 0
        epoch = 0
        restored = False
        ckpt = tf.train.get_checkpoint_state(training_dir)

        if tf.gfile.Exists(training_dir):
            print("{} training folder exists".format(strftime("%H:%M:%S")))
            if ckpt and ckpt.model_checkpoint_path:
                print("{} found model checkpoint path: {}".format(
                    strftime("%H:%M:%S"), ckpt.model_checkpoint_path))
                # Extract the batch number from the checkpoint filename (model.ckpt-<step>)
                bno = int(
                    os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
                print('{} resuming from last training, bno = {}'.format(
                    strftime("%H:%M:%S"), bno))
                d = input_fn.getInputs(bno + 1, TIME_SHIFT, k_cols, MAX_STEP,
                                       args.parallel, args.prefetch,
                                       args.db_pool, args.db_host,
                                       args.db_port, args.db_pwd, args.vset)
                model.setNodes(d['uuids'], d['features'], d['labels'],
                               d['seqlens'])
                saver = tf.train.Saver()
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('{} check restored global step: {}'.format(
                    strftime("%H:%M:%S"),
                    sess.run(tf.train.get_global_step())))
                restored = True
            else:
                print(
                    "{} model checkpoint path not found, cleaning training folder"
                    .format(strftime("%H:%M:%S")))
                tf.gfile.DeleteRecursively(training_dir)

        if not restored:
            d = input_fn.getInputs(bno + 1, TIME_SHIFT, k_cols, MAX_STEP,
                                   args.parallel, args.prefetch, args.db_pool,
                                   args.db_host, args.db_port, args.db_pwd,
                                   args.vset)
            model.setNodes(d['uuids'], d['features'], d['labels'],
                           d['seqlens'])
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            tf.gfile.MakeDirs(training_dir)

        train_handle, test_handle = sess.run(
            [d['train_iter'].string_handle(), d['test_iter'].string_handle()])

        train_feed = {d['handle']: train_handle}
        test_feed = {d['handle']: test_handle}

        summary, train_writer, test_writer = collect_summary(
            sess, model, summary_dir)
        test_summary_str = None
        while True:
            # derive the current epoch from the batch counter
            epoch = bno // TEST_INTERVAL
            if restored or bno % TEST_INTERVAL == 0:
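                # evaluate cost and the worst prediction on the test iterator
                # without updating weights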
                print('{} running on test set...'.format(strftime("%H:%M:%S")))
                mse, worst, test_summary_str = sess.run(
                    [model.cost, model.worst, summary], test_feed)
                uuid, max_diff, predict, actual = worst[0], worst[1], worst[2], worst[3]
                print(
                    '{} Epoch {} diff {:3.5f} max_diff {:3.4f} predict {} actual {} uuid {}'
                    .format(strftime("%H:%M:%S"), epoch, math.sqrt(mse),
                            max_diff, predict, actual, uuid))
                restored = False
            try:
                print('{} training batch {}'.format(strftime("%H:%M:%S"),
                                                    bno + 1))
                summary_str, worst = sess.run(
                    [summary, model.worst, model.optimize], train_feed)[:-1]
            except tf.errors.OutOfRangeError:
                print("End of Dataset.")
                break
            bno = bno + 1
            _, max_diff, predict, actual = worst[0], worst[1], worst[2], worst[3]
            print('{} bno {} max_diff {:3.4f} predict {} actual {}'.format(
                strftime("%H:%M:%S"), bno, max_diff, predict, actual))
            train_writer.add_summary(summary_str, bno)
            test_writer.add_summary(test_summary_str, bno)
            train_writer.flush()
            test_writer.flush()
            if bno == 1 or bno % SAVE_INTERVAL == 0:
                saver.save(sess,
                           checkpoint_file,
                           global_step=tf.train.get_global_step())
        # test last epoch
        print('{} running on test set...'.format(strftime("%H:%M:%S")))
        mse, worst, test_summary_str = sess.run(
            [model.cost, model.worst, summary], test_feed)
        uuid, max_diff, predict, actual = worst[0], worst[1], worst[2], worst[3]
        print(
            '{} Epoch {} diff {:3.5f} max_diff {:3.4f} predict {} actual {} uuid {}'
            .format(strftime("%H:%M:%S"), epoch, math.sqrt(mse), max_diff,
                    predict, actual, uuid))
        train_writer.add_summary(summary_str, bno)
        test_writer.add_summary(test_summary_str, bno)
        train_writer.flush()
        test_writer.flush()
        saver.save(sess,
                   checkpoint_file,
                   global_step=tf.train.get_global_step())
        # training finished, move to 'trained' folder
        trained = os.path.join(base_dir, 'trained')
        tf.gfile.MakeDirs(trained)
        tmp_dir = os.path.join(base_dir, strftime("%Y%m%d_%H%M%S"))
        os.rename(training_dir, tmp_dir)
        shutil.move(tmp_dir, trained)
        print('{} model is saved to {}'.format(strftime("%H:%M:%S"), trained))