コード例 #1
0
ファイル: train.py プロジェクト: Gs-001/quad
def main(_):
    """Resolve the run configuration, then train and log every score.

    Steps, in order:
      1. Configure logging via ``utility.set_up_logging``.
      2. Refuse to run when ``FLAGS.config`` is empty.
      3. Derive the run directory ``<FLAGS.logdir>/<FLAGS.timestamp>-<config>``
         with ``~`` expanded. A falsy ``FLAGS.logdir`` is passed through
         unchanged.
      4. Try ``utility.load_config`` on that directory. If it raises
         ``IOError``, build a fresh configuration from the ``configs``
         attribute named by ``FLAGS.config``, wrap it in ``tools.AttrDict``
         and store it with ``utility.save_config``.
      5. Iterate ``train(config, FLAGS.env_processes)`` and log each score
         it yields.

    Args:
      _: Ignored positional argument (presumably the argv list handed over
        by the app runner -- TODO confirm).

    Raises:
      KeyError: If no configuration name was supplied.
    """
    utility.set_up_logging()
    config_name = FLAGS.config
    if not config_name:
        raise KeyError('You must specify a configuration.')

    # Only build a path when a base log directory was given; otherwise keep
    # the original falsy value, mirroring `FLAGS.logdir and ...`.
    logdir = FLAGS.logdir
    if logdir:
        run_name = '{}-{}'.format(FLAGS.timestamp, config_name)
        logdir = os.path.expanduser(os.path.join(logdir, run_name))

    try:
        config = utility.load_config(logdir)
    except IOError:
        # load_config signalled IOError: create the named config and save it.
        factory = getattr(configs, config_name)
        config = utility.save_config(tools.AttrDict(factory()), logdir)

    scores = train(config, FLAGS.env_processes)
    for score in scores:
        tf.logging.info('Score {}.'.format(score))
コード例 #2
0
def define_simulation_graph(batch_env, algo_cls, config):
    """Define the algorithm and environment interaction.

    Creates a global step counter and boolean control placeholders,
    instantiates the algorithm, and connects it to the environments via
    ``tools.simulate``. It also logs the number of trainable variables
    reported by ``tools.count_weights()``.

    Args:
      batch_env: In-graph environments object.
      algo_cls: Constructor of a batch algorithm; called as
        ``algo_cls(batch_env, step, is_training, should_log, config)``.
      config: Configuration object for the algorithm.

    Returns:
      A ``tools.AttrDict`` built from ``locals()``, so every local name of
      this function becomes an attribute: ``batch_env``, ``algo_cls``,
      ``config``, ``step``, ``is_training``, ``should_log``, ``do_report``,
      ``force_reset``, ``algo``, ``done``, ``score``, ``summary`` and
      ``message``. Renaming or adding a local changes the returned object.
    """
    # These locals look unused to pylint, but all of them are handed back to
    # the caller through locals() in the return statement.
    # pylint: disable=unused-variable
    # Global step counter. The positional False is the TF1 `trainable`
    # argument, so this variable is excluded from training.
    step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
    # Boolean control placeholders; they must be fed when running ops that
    # depend on them.
    is_training = tf.placeholder(tf.bool, name='is_training')
    should_log = tf.placeholder(tf.bool, name='should_log')
    # Not consumed inside this function; only exposed via the return value.
    do_report = tf.placeholder(tf.bool, name='do_report')
    force_reset = tf.placeholder(tf.bool, name='force_reset')
    algo = algo_cls(batch_env, step, is_training, should_log, config)
    # Outputs of the simulation graph; their exact meaning is defined by
    # tools.simulate (presumably done flags, scores and a summary -- TODO
    # confirm).
    done, score, summary = tools.simulate(batch_env, algo, should_log,
                                          force_reset)
    message = 'Graph contains {} trainable variables.'
    tf.logging.info(message.format(tools.count_weights()))
    # pylint: enable=unused-variable
    return tools.AttrDict(locals())