예제 #1
0
 def test_exclude_by_regex(self):
   """Check `count_weights` results for several `exclude` patterns."""
   # Element counts: 6 at the top level, 10 under foo/, 2 under foo/bar/.
   tf.Variable(tf.zeros((3, 2)), trainable=True)
   with tf.variable_scope('foo'):
     tf.Variable(tf.zeros((5, 2)), trainable=True)
     with tf.variable_scope('bar'):
       tf.Variable(tf.zeros((1, 2)), trainable=True)
   expected_by_pattern = (
       (r'.*', 0),
       (r'(^|/)foo/.*', 6),
       (r'.*/bar/.*', 16),
   )
   for pattern, expected in expected_by_pattern:
     self.assertEqual(expected, count_weights(exclude=pattern))
 def test_exclude_by_regex(self):
     """Check `count_weights` results for several `exclude` patterns."""
     # Element counts: 6 at the top level, 10 under foo/, 2 under foo/bar/.
     tf.Variable(tf.zeros((3, 2)), trainable=True)
     with tf.variable_scope('foo'):
         tf.Variable(tf.zeros((5, 2)), trainable=True)
         with tf.variable_scope('bar'):
             tf.Variable(tf.zeros((1, 2)), trainable=True)
     without_everything = count_weights(exclude=r'.*')
     self.assertEqual(0, without_everything)
     without_foo = count_weights(exclude=r'(^|/)foo/.*')
     self.assertEqual(6, without_foo)
     without_bar = count_weights(exclude=r'.*/bar/.*')
     self.assertEqual(16, without_bar)
예제 #3
0
 def test_non_default_graph(self):
   """Count weights in an explicitly passed graph that is not the default.

   One trainable (5, 3) variable and one non-trainable (8, 2) variable are
   created inside a fresh graph; the expected count for that graph is 15.
   """
   graph = tf.Graph()
   with graph.as_default():
     tf.Variable(tf.zeros((5, 3)), trainable=True)
     tf.Variable(tf.zeros((8, 2)), trainable=False)
   # `tf.get_default_graph` must be called: comparing against the function
   # object itself would make this sanity check pass unconditionally.
   self.assertNotEqual(graph, tf.get_default_graph())
   self.assertEqual(15, count_weights(graph=graph))
예제 #4
0
 def test_restrict_invalid_scope(self):
   """Expect a zero count when `count_weights` is called with 'bar'."""
   outer_shape, foo_shape, bar_shape = (3, 2), (5, 2), (1, 2)
   tf.Variable(tf.zeros(outer_shape), trainable=True)
   with tf.variable_scope('foo'):
     tf.Variable(tf.zeros(foo_shape), trainable=True)
     # 'bar' only exists nested inside 'foo', never at the top level.
     with tf.variable_scope('bar'):
       tf.Variable(tf.zeros(bar_shape), trainable=True)
   num_weights = count_weights('bar')
   self.assertEqual(0, num_weights)
예제 #5
0
 def test_trainable_and_non_trainable(self):
   """Mix trainable and non-trainable variables and check the count."""
   shapes_and_flags = (
       ((5, 3), True),
       ((8, 2), False),
       ((1, 1), True),
       ((5,), True),
       ((3, 1), False),
   )
   for shape, is_trainable in shapes_and_flags:
     tf.Variable(tf.zeros(shape), trainable=is_trainable)
   # Expected total is the element count of the trainable shapes only.
   self.assertEqual(15 + 1 + 5, count_weights())
예제 #6
0
def define_simulation_graph(batch_env, algo_cls, config):
    """Define the algorithm and environment interaction.

    Creates the global step variable and the control placeholders, constructs
    the algorithm, connects it to the environments through `tools.simulate`,
    and logs the value returned by `tools.count_weights()`.

    Args:
      batch_env: In-graph environments object.
      algo_cls: Constructor of a batch algorithm.
      config: Configuration object for the algorithm.

    Returns:
      Object providing graph elements via attributes. It is a
      `tools.AttrDict` built from `locals()`, so every argument and every
      local variable of this function is included under its own name.
    """
    # NOTE: the return value is built from `locals()`; renaming, adding or
    # removing a local variable below changes what callers receive.
    # pylint: disable=unused-variable
    step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')
    should_log = tf.placeholder(tf.bool, name='should_log')
    # `do_report` is not consumed in this function; it is only exposed to
    # callers through the returned object.
    do_report = tf.placeholder(tf.bool, name='do_report')
    force_reset = tf.placeholder(tf.bool, name='force_reset')
    # Extended: additional flag forwarded to the algorithm constructor.
    is_optimizing_offense = tf.placeholder(tf.bool,
                                           name='is_optimizing_offense')
    algo = algo_cls(batch_env,
                    step,
                    is_training,
                    should_log,
                    config,
                    is_optimizing_offense=is_optimizing_offense)
    # This variant of `tools.simulate` returns four values.
    done, score, summary, gail_summary = tools.simulate(
        batch_env, algo, should_log, force_reset)
    message = 'Graph contains {} trainable variables.'
    tf.logging.info(message.format(tools.count_weights()))
    # pylint: enable=unused-variable
    return tools.AttrDict(locals())
 def test_restrict_invalid_scope(self):
     """Expect a zero count when `count_weights` is called with 'bar'."""
     top_level_init = tf.zeros((3, 2))
     tf.Variable(top_level_init, trainable=True)
     with tf.variable_scope('foo'):
         foo_init = tf.zeros((5, 2))
         tf.Variable(foo_init, trainable=True)
         # 'bar' only exists nested inside 'foo', never at the top level.
         with tf.variable_scope('bar'):
             bar_init = tf.zeros((1, 2))
             tf.Variable(bar_init, trainable=True)
     expected = 0
     self.assertEqual(expected, count_weights('bar'))
 def test_non_default_graph(self):
     """Count weights in an explicitly passed graph that is not the default.

     One trainable (5, 3) variable and one non-trainable (8, 2) variable are
     created inside a fresh graph; the expected count for that graph is 15.
     """
     graph = tf.Graph()
     with graph.as_default():
         tf.Variable(tf.zeros((5, 3)), trainable=True)
         tf.Variable(tf.zeros((8, 2)), trainable=False)
     # `tf.get_default_graph` must be called: comparing against the function
     # object itself would make this sanity check pass unconditionally.
     self.assertNotEqual(graph, tf.get_default_graph())
     self.assertEqual(15, count_weights(graph=graph))
예제 #9
0
def define_simulation_graph(batch_env, algo_cls, config):
    """Define the algorithm and environment interaction.

    Simulate a single step in all environments. See more details in
    `tools.simulate`.

    Args:
      batch_env: In-graph environments object. Its `action_info` attribute is
        indexed at 0 and 1 to build the `external_action` placeholder.
      algo_cls: Constructor of a batch algorithm.
      config: Configuration object for the algorithm.

    Returns:
      Object providing graph elements via attributes. It is a
      `tools.AttrDict` built from `locals()`, so every argument and every
      local variable of this function is included under its own name.
    """
    # NOTE: the return value is built from `locals()`; renaming, adding or
    # removing a local variable below changes what callers receive.
    # pylint: disable=unused-variable
    step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')
    should_log = tf.placeholder(tf.bool, name='should_log')
    # `do_report` is not consumed in this function; it is only exposed to
    # callers through the returned object.
    do_report = tf.placeholder(tf.bool, name='do_report')
    force_reset = tf.placeholder(tf.bool, name='force_reset')
    algo = algo_cls(batch_env, step, is_training, should_log, config)
    should_step = tf.placeholder(tf.bool, name='should_step')
    use_external_action = tf.placeholder(tf.bool, name='use_external_action')
    # `action_info[0]` is passed as the placeholder's first argument and
    # `action_info[1]` as its shape.
    external_action = tf.placeholder(batch_env.action_info[0],
                                     shape=batch_env.action_info[1],
                                     name='external_action')
    done, score, summary = tools.simulate(batch_env, algo, should_step,
                                          use_external_action, external_action,
                                          should_log, force_reset)
    message = 'Graph contains {} trainable variables.'
    tf.logging.info(message.format(tools.count_weights()))
    # pylint: enable=unused-variable
    return tools.AttrDict(locals())
 def test_trainable_and_non_trainable(self):
     """Mix trainable and non-trainable variables and check the count."""
     variable_specs = [
         ((5, 3), True),
         ((8, 2), False),
         ((1, 1), True),
         ((5, ), True),
         ((3, 1), False),
     ]
     for shape, is_trainable in variable_specs:
         tf.Variable(tf.zeros(shape), trainable=is_trainable)
     # Expected total is the element count of the trainable shapes only.
     expected = 15 + 1 + 5
     self.assertEqual(expected, count_weights())
예제 #11
0
def testing(config, off_data, off_label, def_data, def_label, outdir):
    """Restore a pretrained model and run `vis_result` on both data splits.

    Each offense/defense array is split along axis 0 at 90% of the length of
    the corresponding data array: the leading part is the training split and
    the remainder the evaluation split. The shapes of all eight resulting
    arrays are printed. A model chosen by `FLAGS.config` is then built in a
    fresh default graph, variables are initialized with `resume=True`, and
    `vis_result` is invoked once per split.

    Args
    ----
    config : Object providing configurations via attributes
        (`log_device_placement` and `logdir` are read here).
    off_data, off_label : Offense data and labels; both are split at the
        index derived from `off_data.shape[0]`.
    def_data, def_label : Defense data and labels; both are split at the
        index derived from `def_data.shape[0]`.
    outdir : Base path; `outdir/train` and `outdir/eval` are passed to
        `vis_result`.

    Raises
    ------
    ValueError : If `FLAGS.config` is neither 'offense' nor 'defense'.
    """
    # Hold out the trailing 10% of each role's samples for evaluation.
    off_boundary = [off_data.shape[0] * 9 // 10]
    off_train_data, off_eval_data = np.split(off_data, off_boundary)
    off_train_label, off_eval_label = np.split(off_label, off_boundary)
    def_boundary = [def_data.shape[0] * 9 // 10]
    def_train_data, def_eval_data = np.split(def_data, def_boundary)
    def_train_label, def_eval_label = np.split(def_label, def_boundary)
    for part in (off_train_data, off_eval_data,
                 off_train_label, off_eval_label,
                 def_train_data, def_eval_data,
                 def_train_label, def_eval_label):
        print(part.shape)

    # Build the requested model in a fresh default graph.
    tf.reset_default_graph()
    if FLAGS.config == 'offense':
        model = pretrain_model.PretrainOffense(config)
    elif FLAGS.config == 'defense':
        model = pretrain_model.PretrainDefense(config)
    else:
        raise ValueError('{} is not an available config'.format(FLAGS.config))

    tf.logging.info(
        'Graph contains {} trainable variables.'.format(tools.count_weights()))
    saver = utility.define_saver()
    sess_config = tf.ConfigProto(
        log_device_placement=config.log_device_placement,
        allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    with tf.Session(config=sess_config) as sess:
        if FLAGS.debug:
            sess = tf_debug.LocalCLIDebugWrapperSession(
                sess, ui_type=FLAGS.ui_type)
        utility.initialize_variables(sess, saver, config.logdir, resume=True)
        # Training split first, then the evaluation split.
        for tag, off_x, off_y, def_x, def_y in (
                ('train', off_train_data, off_train_label,
                 def_train_data, def_train_label),
                ('eval', off_eval_data, off_eval_label,
                 def_eval_data, def_eval_label)):
            vis_result(sess, model, off_x, off_y, def_x, def_y,
                       os.path.join(outdir, tag), 3)
예제 #12
0
파일: utility.py 프로젝트: shamanez/agents
def define_simulation_graph(batch_env, algo_cls, config):
  """Define the algorithm and environment interaction.

  Creates the global step variable and the control placeholders, constructs
  the algorithm, connects it to the environments through `tools.simulate`,
  and logs the value returned by `tools.count_weights()`.

  Args:
    batch_env: In-graph environments object.
    algo_cls: Constructor of a batch algorithm.
    config: Configuration object for the algorithm.

  Returns:
    Object providing graph elements via attributes. It is a `tools.AttrDict`
    built from `locals()`, so every argument and every local variable of this
    function is included under its own name.
  """
  # NOTE: the return value is built from `locals()`; renaming, adding or
  # removing a local variable below changes what callers receive.
  # pylint: disable=unused-variable
  step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
  is_training = tf.placeholder(tf.bool, name='is_training')
  should_log = tf.placeholder(tf.bool, name='should_log')
  # `do_report` is not consumed in this function; it is only exposed to
  # callers through the returned object.
  do_report = tf.placeholder(tf.bool, name='do_report')
  force_reset = tf.placeholder(tf.bool, name='force_reset')
  algo = algo_cls(batch_env, step, is_training, should_log, config)
  done, score, summary = tools.simulate(
      batch_env, algo, should_log, force_reset)
  message = 'Graph contains {} trainable variables.'
  tf.logging.info(message.format(tools.count_weights()))
  # pylint: enable=unused-variable
  return tools.AttrDict(locals())
 def test_include_scopes(self):
     """Variables inside a variable scope are part of the default count."""
     top_level_init = tf.zeros((3, 2))
     tf.Variable(top_level_init, trainable=True)
     with tf.variable_scope('foo'):
         scoped_init = tf.zeros((5, 2))
         tf.Variable(scoped_init, trainable=True)
     # 6 elements at the top level plus 10 inside 'foo'.
     expected = 6 + 10
     self.assertEqual(expected, count_weights())
 def test_ignore_non_trainable(self):
     """Only non-trainable variables exist, so the expected count is zero."""
     for shape in ((5, 3), (1, 1), (5, )):
         tf.Variable(tf.zeros(shape), trainable=False)
     num_weights = count_weights()
     self.assertEqual(0, num_weights)
 def test_count_trainable(self):
     """All variables are trainable; expect the sum of their sizes."""
     for shape in ((5, 3), (1, 1), (5, )):
         tf.Variable(tf.zeros(shape), trainable=True)
     # 15 + 1 + 5 elements for the three shapes above.
     expected = 15 + 1 + 5
     self.assertEqual(expected, count_weights())
예제 #16
0
 def test_ignore_non_trainable(self):
   """Only non-trainable variables exist, so the expected count is zero."""
   frozen_shapes = [(5, 3), (1, 1), (5,)]
   for shape in frozen_shapes:
     tf.Variable(tf.zeros(shape), trainable=False)
   self.assertEqual(0, count_weights())
예제 #17
0
 def test_include_scopes(self):
   """Variables inside a variable scope are part of the default count."""
   outer_shape, scoped_shape = (3, 2), (5, 2)
   tf.Variable(tf.zeros(outer_shape), trainable=True)
   with tf.variable_scope('foo'):
     tf.Variable(tf.zeros(scoped_shape), trainable=True)
   # 6 elements at the top level plus 10 inside 'foo'.
   total = count_weights()
   self.assertEqual(6 + 10, total)
예제 #18
0
def train(config, data, label, outdir):
    """Train the pretraining model, evaluating and checkpointing per epoch.

    The data is normalized with the observation-space bounds of
    `BBallPretrainEnv` (mapping the range [low, high] onto [-1, 1]) and then
    split along axis 0 at 90% of its length into a training part and an
    evaluation part. A model chosen by `FLAGS.config` is built in a fresh
    default graph. For every epoch `training` and `evaluating` are called,
    and a checkpoint is written every `config.checkpoint_every` epochs.

    Args
    ----
    config : Object providing configurations via attributes. This function
        reads `log_device_placement`, `logdir`, `num_epochs`,
        `checkpoint_every` and `batch_size`.
    data : Array of observations; must support `.shape` and arithmetic with
        the observation-space bounds.
    label : Labels, split at the same index as `data`.
    outdir : Not used by this function; kept for interface compatibility.

    Returns
    -------
    None. Nothing is yielded or returned; results are written through the
    summary writers and checkpoints under `config.logdir`.

    Raises
    ------
    ValueError : If `FLAGS.config` is neither 'offense' nor 'defense'.
    """
    # normalization
    env = BBallPretrainEnv()
    min_ = env.observation_space.low
    max_ = env.observation_space.high
    data = 2 * (data - min_) / (max_ - min_) - 1
    # split into train and eval (leading 90% / trailing 10%)
    split_at = [data.shape[0] * 9 // 10]
    train_data, eval_data = np.split(data, split_at)
    train_label, eval_label = np.split(label, split_at)
    print(train_data.shape)
    print(train_label.shape)
    print(eval_data.shape)
    print(eval_label.shape)

    # graph
    tf.reset_default_graph()
    if FLAGS.config == 'offense':
        model = pretrain_model.PretrainOffense(config)
    elif FLAGS.config == 'defense':
        model = pretrain_model.PretrainDefense(config)
    else:
        raise ValueError('{} is not an available config'.format(FLAGS.config))
    message = 'Graph contains {} trainable variables.'
    tf.logging.info(message.format(tools.count_weights()))
    saver = utility.define_saver()
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=config.log_device_placement)
    sess_config.gpu_options.allow_growth = True
    # One summary writer per split, each in its own sub-directory.
    train_writer = tf.summary.FileWriter(os.path.join(config.logdir, 'train'),
                                         tf.get_default_graph())
    eval_writer = tf.summary.FileWriter(os.path.join(config.logdir, 'eval'),
                                        tf.get_default_graph())
    try:
        with tf.Session(config=sess_config) as sess:
            if FLAGS.debug:
                sess = tf_debug.LocalCLIDebugWrapperSession(
                    sess, ui_type=FLAGS.ui_type)
            utility.initialize_variables(sess,
                                         saver,
                                         config.logdir,
                                         resume=FLAGS.resume)
            for epoch_idx in range(config.num_epochs):
                tf.logging.info('Number of epochs: {}'.format(epoch_idx))
                training(sess, model, train_data, train_label, config,
                         train_writer)
                evaluating(sess, model, eval_data, eval_label, config,
                           eval_writer)
                if (epoch_idx + 1) % config.checkpoint_every == 0:
                    tf.gfile.MakeDirs(config.logdir)
                    filename = os.path.join(config.logdir, 'model.ckpt')
                    saver.save(sess, filename,
                               (epoch_idx + 1) * config.batch_size)
    finally:
        # Flush pending summaries and release the event files even when
        # training is interrupted by an exception.
        train_writer.close()
        eval_writer.close()
예제 #19
0
 def test_count_trainable(self):
   """All variables are trainable; expect the sum of their sizes."""
   trainable_shapes = [(5, 3), (1, 1), (5,)]
   for shape in trainable_shapes:
     tf.Variable(tf.zeros(shape), trainable=True)
   # 15 + 1 + 5 elements for the three shapes above.
   self.assertEqual(15 + 1 + 5, count_weights())