Exemplo n.º 1
0
def train_and_model(model):
    """Train ``model`` in a fresh TF session, then evaluate it batch by batch.

    Args:
        model: mapping of graph handles. The keys read here are
            ``"init_op"``, ``"train_op"``, ``"X"``, ``"Y"``, ``"Y_label"``,
            ``"summary"``, ``"batch_acc"`` and ``"cost"``.

    Side effects:
        * prints progress and accuracy to stdout,
        * saves checkpoints to ``FLAGS.dir_train_checkpoint``,
        * writes TensorBoard events to ``FLAGS.dir_train_tensorboard`` and
          ``FLAGS.dir_test_tensorboard``.

    Returns:
        None.
    """
    # The set of batch fields is the same for every step; build it once.
    key_list = [Batch.INPUT_DATA, Batch.OUTPUT_LABEL, Batch.OUTPUT_DATA]

    def _feed_from(data):
        # Map one fetched batch onto the model's placeholders.
        return {
            model["X"]: data[Batch.INPUT_DATA],
            model["Y"]: data[Batch.OUTPUT_DATA],
            model["Y_label"]: data[Batch.OUTPUT_LABEL]
        }

    # TODO: to log whether each tensor is placed on CPU or GPU, open the
    # session with the commented-out line below instead.
    # with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess, tf.device("/CPU:0"):
    with tf.Session() as sess:

        # checkpoint saver
        saver = tf.train.Saver()

        # tensorboard writers
        train_writer = tf.summary.FileWriter(FLAGS.dir_train_tensorboard,
                                             sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.dir_test_tensorboard)

        try:
            # train step
            print("Train Start...")
            train_batch_config = Batch.Config()
            train_batch = Batch.Batch(train_batch_config)
            sess.run(model["init_op"])

            # range(max + 1): both step 0 and step max_train_step are run.
            for step in range(FLAGS.max_train_step + 1):
                data = train_batch.next_batch(FLAGS.batch_size, key_list)
                feed_dict = _feed_from(data)
                sess.run(model["train_op"], feed_dict)

                log_now = step % FLAGS.print_log_step_size == 0
                # Was `if step % summary_step_size:` which fired on every
                # step that is NOT a multiple of the interval.
                summarize_now = step % FLAGS.summary_step_size == 0

                # Evaluate once, only when the values are actually needed,
                # so the summary written always belongs to the current step.
                if log_now or summarize_now:
                    summary_train, _acc, _cost = sess.run(
                        [model["summary"], model["batch_acc"], model["cost"]],
                        feed_dict=feed_dict)

                # print log
                if log_now:
                    print(datetime.datetime.utcnow(), "train step: %d" % step,
                          "batch_acc:", _acc, "cost:", _cost)

                # checkpoint
                if step % FLAGS.checkpoint_step_size == 0:
                    saver.save(sess, FLAGS.dir_train_checkpoint,
                               global_step=step)

                # summary tensorboard
                if summarize_now:
                    train_writer.add_summary(summary=summary_train,
                                             global_step=step)

            # test step
            print("Test Start...")
            # NOTE(review): the test batch is built from the same default
            # Batch.Config() as the training batch - confirm it yields
            # held-out data.
            test_batch_config = Batch.Config()
            test_batch = Batch.Batch(test_batch_config)

            total_acc = 0.
            for step in range(FLAGS.max_test_step + 1):
                data = test_batch.next_batch(FLAGS.batch_size, key_list)
                feed_dict = _feed_from(data)

                # A single evaluation per batch; the former second run with
                # the same feed dict only recomputed these values.
                summary_test, _acc = sess.run(
                    [model["summary"], model["batch_acc"]],
                    feed_dict=feed_dict)
                print(datetime.datetime.utcnow(), "test step: %d" % step,
                      "batch_acc: ", _acc)
                total_acc += _acc

                test_writer.add_summary(summary=summary_test,
                                        global_step=step)

            # Mean of the per-batch accuracies over all test steps.
            print("test complete: total acc =",
                  total_acc / (FLAGS.max_test_step + 1))
        finally:
            # Flush pending events to disk even if training/testing raised.
            train_writer.close()
            test_writer.close()