def __init__(self, mark, epsilon=0.1, step_size=0.1, td_lambda=0.6):
        '''
        Initialization.
        arguments:
            mark: which player's mark this agent acts as
            epsilon: epsilon for the soft policy
                     (defined statically in the previous implementation)
            step_size: step size for TD(λ)
            td_lambda: λ for TD(λ)
        '''
        self.mark = mark
        self.epsilon = epsilon
        self.step_size = step_size
        self.td_lambda = td_lambda

        # tensorflow Session
        self.sess = tf.InteractiveSession()
        # input placeholder
        self.x = tf.placeholder(tf.float32, shape=[None, 9])
        self.logits = model.inference(self.x, str(mark.to_int()))
   
        # initialize
        init_op = tf.initialize_all_variables()
        self.sess.run(init_op)
 
        # Keep a value table (so that the Com player can be saved);
        # states are made hashable and used as keys into value
        self.value = defaultdict(lambda: 0.0)

        # previous state used by TD(λ)
        self.previous_state = None
        # current state used by TD(λ)
        self.current_state = None
        # eligibility (trace) for each state
        self.accumulated_weights = defaultdict(lambda: 0.0)

        self.training = True
        self.verbose = False
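
# A minimal sketch of the TD(λ) backup these fields support; the method name
# `learn`, the reward argument, and gamma are assumptions made for
# illustration, not part of the original class.
def learn(self, reward, gamma=1.0):
        # TD error between the current and the previous state's value
        td_error = reward + gamma * self.value[self.current_state] \
            - self.value[self.previous_state]
        # bump the eligibility of the state just left (accumulating trace)
        self.accumulated_weights[self.previous_state] += 1.0
        # update every traced state in proportion to its eligibility,
        # then decay the traces by gamma * λ
        for state in list(self.accumulated_weights):
            self.value[state] += self.step_size * td_error \
                * self.accumulated_weights[state]
            self.accumulated_weights[state] *= gamma * self.td_lambda
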
import tensorflow as tf

import load  # project-local data loader providing mini_batch()
import model_mlp

NUM_CLASSES = 1
IMAGE_SIZE = 8
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE * 3

# batch size at training time
BATCH_SIZE = 1

if __name__ == '__main__':
    filename_queue = tf.train.string_input_producer(["data/airquality.csv"])
    feature, data = load.mini_batch(filename_queue, 1)
    data_placeholder = tf.placeholder("float", shape=(None, 5))
    
    logits_maru = model_mlp.inference(data_placeholder, 'maru')
    logits_batsu = model_mlp.inference(data_placeholder, 'batsu')
    sess = tf.InteractiveSession()

    saver = tf.train.Saver()
    sess.run(tf.initialize_all_variables())

    # restore trained model
    ckpt = tf.train.get_checkpoint_state('train')
    if ckpt and ckpt.model_checkpoint_path:
        print(ckpt.model_checkpoint_path)
        print("Load checkpoint.")
        saver.restore(sess, ckpt.model_checkpoint_path)
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    else:
        print('No checkpoint file found.')
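
    # A minimal sketch of querying the restored model; pulling one batch from
    # the queue-fed `feature` tensor and feeding it to `data_placeholder` is
    # an assumption about the intended use of this script.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # fetch one mini-batch of features and run both heads on it
    feature_value = sess.run(feature)
    print(sess.run(logits_maru, feed_dict={data_placeholder: feature_value}))
    print(sess.run(logits_batsu, feed_dict={data_placeholder: feature_value}))
    coord.request_stop()
    coord.join(threads)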
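
# Module-level setup that train() below assumes; the imports mirror the calls
# inside the function, while the constant values here are placeholder
# assumptions, not the original configuration.
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

import load  # project-local data loader
import model  # project-local model definition (inference / loss / debug)
import op  # project-local training-op builder

TRAIN_DIR = 'train'  # assumed checkpoint / summary directory
MAX_STEPS = 10000  # assumed number of training steps
LOG_DEVICE_PLACEMENT = False
BATCH_SIZE = 1
tdatetime = datetime.now()  # timestamp used in the checkpoint file name
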
def train():
    '''
    Train CNN_tiny for a number of steps.
    '''
    with tf.Graph().as_default():
        # global step counter
        global_step = tf.Variable(0, trainable=False)

        # training data (features and targets)
        filename_queue = tf.train.string_input_producer(
            ["data/airquality.csv"])
        datas, targets = load.mini_batch(filename_queue, BATCH_SIZE)

        # placeholder
        x = tf.placeholder(tf.float32, shape=[None, 5])
        y = tf.placeholder(tf.float32, shape=[None, 1])

        # output of the graph
        logits = model.inference(x)

        debug_value = model.debug(logits)

        # loss: uses the graph output and the labels
        loss = model.loss(logits, y)

        # training operation
        train_op = op.train(loss, global_step)

        # saver
        saver = tf.train.Saver(tf.all_variables())

        # summaries
        summary_op = tf.merge_all_summaries()

        # initialization operation
        init_op = tf.initialize_all_variables()

        # Session
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=LOG_DEVICE_PLACEMENT))
        sess.run(init_op)

        print("settion start.")

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # set up the summary writer
        summary_writer = tf.train.SummaryWriter(TRAIN_DIR,
                                                graph_def=sess.graph_def)

        # model name
        model_name = '/model%s.ckpt' % (tdatetime.strftime('%Y%m%d%H%M%S'))

        # train repeatedly up to MAX_STEPS
        for step in xrange(MAX_STEPS):
            start_time = time.time()
            a, b = sess.run([datas, targets])
            _, loss_value, predict_value = sess.run(
                [train_op, loss, debug_value], feed_dict={
                    x: a,
                    y: b
                })

            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            # every 100 steps
            if step % 100 == 0:
                # examples per step = mini-batch size
                num_examples_per_step = BATCH_SIZE

                # examples per second
                examples_per_sec = num_examples_per_step / duration

                # time per batch
                sec_per_batch = float(duration)

                # time, step count, loss, examples per second, time per batch
                format_str = '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
                print format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch)

                print "x", a
                print "ground truth:", b
                print "predict: ", predict_value

            # every 100 steps
            if step % 100 == 0:
                pass
                #summary_str = sess.run(summary_op)
                # write the summary
                #summary_writer.add_summary(summary_str, step)

            if step % 1000 == 0 or (step + 1) == MAX_STEPS:
                checkpoint_path = TRAIN_DIR + model_name
                saver.save(sess, checkpoint_path, global_step=step)

        coord.request_stop()
        coord.join(threads)
        sess.close()
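
# Minimal sketches of the project-local helpers train() relies on; these
# bodies are illustrative assumptions, not the original model / op modules.
def inference_sketch(x):
    # one hidden layer mapping the 5 input features to a single output
    w1 = tf.Variable(tf.truncated_normal([5, 16], stddev=0.1))
    b1 = tf.Variable(tf.zeros([16]))
    hidden = tf.nn.relu(tf.matmul(x, w1) + b1)
    w2 = tf.Variable(tf.truncated_normal([16, 1], stddev=0.1))
    b2 = tf.Variable(tf.zeros([1]))
    return tf.matmul(hidden, w2) + b2

def loss_sketch(logits, y):
    # mean squared error, matching the (None, 1) regression target
    return tf.reduce_mean(tf.square(logits - y))

def train_op_sketch(loss, global_step):
    # plain gradient descent; the original op.train may use something else
    return tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=global_step)
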
import tensorflow as tf

import load  # project-local data loader providing mini_batch()
import model_mlp

NUM_CLASSES = 1
IMAGE_SIZE = 8
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE * 3

# batch size at training time
BATCHSIZE = 1

if __name__ == '__main__':
    filename_queue = tf.train.string_input_producer(["data/airquality.csv"])
    feature, data = load.mini_batch(filename_queue, 1)
    data_placeholder = tf.placeholder("float", shape=(None, 5))
    
    logits = model_mlp.inference(data_placeholder)
    sess = tf.InteractiveSession()

    saver = tf.train.Saver()
    sess.run(tf.initialize_all_variables())

    # restore trained model
    ckpt = tf.train.get_checkpoint_state('train')
    if ckpt and ckpt.model_checkpoint_path:
        print(ckpt.model_checkpoint_path)
        print("Load checkpoint.")
        saver.restore(sess, ckpt.model_checkpoint_path)
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    else:
        print('No checkpoint file found.')
        quit()