Exemple #1
0
def run_training(train_dir):
    """Train a FeedModel, resuming from the newest checkpoint if one exists.

    A FeedModel is built in a fresh graph and trained on the batches yielded
    by ExperienceBatcher.get_batches_stepwise(). Every 1000 global steps a
    checkpoint is saved under train_dir, summaries are written through
    write_summaries and the loss it returns is printed.

    Args:
        train_dir: Directory used for checkpoints and summary event files.
    """

    # tf.train.latest_checkpoint returns None when the directory holds no
    # checkpoint (for example a previous run that stopped before its first
    # save). Restoring from None would fail, so such a directory is treated
    # as a fresh start instead of a resume.
    checkpoint_path = None
    if os.path.exists(train_dir):
        checkpoint_path = tf.train.latest_checkpoint(train_dir)

    # Integer interval: avoids the float modulo of the former `% 1e3`.
    checkpoint_interval = 1000

    with tf.Graph().as_default():
        model = FeedModel()
        saver = tf.train.Saver()
        session = tf.Session()
        summary_writer = None
        try:
            # `graph=` replaces the deprecated `graph_def=` keyword.
            summary_writer = tf.summary.FileWriter(train_dir,
                                                   graph=session.graph,
                                                   flush_secs=10)

            if checkpoint_path is not None:
                print("Resuming: ", train_dir)
                saver.restore(session, checkpoint_path)
            else:
                print("Starting new training: ", train_dir)
                session.run(model.init)

            run_inference = make_run_inference(session, model)
            get_q_values = make_get_q_values(session, model)

            experience_collector = ExperienceCollector()
            batcher = ExperienceBatcher(experience_collector, run_inference,
                                        get_q_values, STATE_NORMALIZE_FACTOR)

            # Fixed set of experiences handed to write_summaries at every
            # checkpoint.
            test_experiences = experience_collector.collect(
                play.random_strategy, 100)

            for state_batch, targets, actions in batcher.get_batches_stepwise():

                global_step, _ = session.run(
                    [model.global_step, model.train_op],
                    feed_dict={
                        model.state_batch_placeholder: state_batch,
                        model.targets_placeholder: targets,
                        model.actions_placeholder: actions,
                    })

                if global_step != 0 and global_step % checkpoint_interval == 0:
                    saver.save(session,
                               os.path.join(train_dir, "checkpoint"),
                               global_step=global_step)
                    loss = write_summaries(session, batcher, model,
                                           test_experiences, summary_writer)
                    print("Step:", global_step, "Loss:", loss)
        finally:
            # Release the event file and session resources however the loop
            # ends (exhausted generator, exception or KeyboardInterrupt).
            if summary_writer is not None:
                summary_writer.close()
            session.close()
Exemple #2
0
def run_training(train_dir):
  """Train a FeedModel, resuming from the newest checkpoint if one exists.

  A FeedModel is built in a fresh graph and trained on the batches yielded
  by ExperienceBatcher.get_batches_stepwise(). Every 10000 global steps a
  checkpoint is saved under train_dir, summaries are written through
  write_summaries and the loss it returns is printed.

  Args:
    train_dir: Directory used for checkpoints and summary event files.
  """

  # tf.train.latest_checkpoint returns None when the directory holds no
  # checkpoint (for example a previous run that stopped before its first
  # save). Restoring from None would fail, so such a directory is treated
  # as a fresh start instead of a resume.
  checkpoint_path = None
  if os.path.exists(train_dir):
    checkpoint_path = tf.train.latest_checkpoint(train_dir)

  checkpoint_interval = 10000

  with tf.Graph().as_default():
    model = FeedModel()
    saver = tf.train.Saver()
    session = tf.Session()
    summary_writer = None
    try:
      # NOTE(review): tf.train.SummaryWriter is the pre-1.0 TensorFlow API;
      # newer releases expose it as tf.summary.FileWriter - confirm the
      # TensorFlow version this file targets before migrating.
      summary_writer = tf.train.SummaryWriter(train_dir,
                                              graph_def=session.graph_def,
                                              flush_secs=10)

      if checkpoint_path is not None:
        print("Resuming: ", train_dir)
        saver.restore(session, checkpoint_path)
      else:
        print("Starting new training: ", train_dir)
        session.run(model.init)

      run_inference = make_run_inference(session, model)
      get_q_values = make_get_q_values(session, model)

      experience_collector = ExperienceCollector()
      batcher = ExperienceBatcher(experience_collector, run_inference,
                                  get_q_values, STATE_NORMALIZE_FACTOR)

      # Fixed set of experiences handed to write_summaries at every
      # checkpoint.
      test_experiences = experience_collector.collect(play.random_strategy,
                                                      100)

      for state_batch, targets, actions in batcher.get_batches_stepwise():

        global_step, _ = session.run(
            [model.global_step, model.train_op],
            feed_dict={
                model.state_batch_placeholder: state_batch,
                model.targets_placeholder: targets,
                model.actions_placeholder: actions,
            })

        if global_step != 0 and global_step % checkpoint_interval == 0:
          saver.save(session, os.path.join(train_dir, "checkpoint"),
                     global_step=global_step)
          loss = write_summaries(session, batcher, model, test_experiences,
                                 summary_writer)
          print("Step:", global_step, "Loss:", loss)
    finally:
      # Release the event file and session resources however the loop ends
      # (exhausted generator, exception or KeyboardInterrupt).
      if summary_writer is not None:
        summary_writer.close()
      session.close()