def test_exclude_by_regex(self):
    """Check the counts `count_weights` reports for several `exclude` patterns.

    Three trainable variables are created: one at the top level (6 elements),
    one inside scope `foo` (10 elements) and one inside `foo/bar`
    (2 elements).
    """
    tf.Variable(tf.zeros((3, 2)), trainable=True)
    with tf.variable_scope('foo'):
        tf.Variable(tf.zeros((5, 2)), trainable=True)
        with tf.variable_scope('bar'):
            tf.Variable(tf.zeros((1, 2)), trainable=True)
    # (expected count, exclude pattern), checked in this order.
    cases = (
        (0, r'.*'),
        (6, r'(^|/)foo/.*'),
        (16, r'.*/bar/.*'),
    )
    for expected, pattern in cases:
        self.assertEqual(expected, count_weights(exclude=pattern))
def test_non_default_graph(self):
    """Check that `count_weights` counts variables of an explicitly given graph.

    A separate graph receives one trainable variable (15 elements) and one
    non-trainable variable (16 elements); the expected count is 15.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.Variable(tf.zeros((5, 3)), trainable=True)
        tf.Variable(tf.zeros((8, 2)), trainable=False)
    # Call get_default_graph(): comparing against the function object itself
    # would always be unequal and make this sanity check vacuous.
    self.assertNotEqual(graph, tf.get_default_graph())
    self.assertEqual(15, count_weights(graph=graph))
def test_restrict_invalid_scope(self):
    """Expect a count of 0 when `count_weights` is called with `'bar'`.

    The only `bar` scope in the graph is nested inside `foo`; there is no
    top-level scope of that name.
    """
    tf.Variable(tf.zeros((3, 2)), trainable=True)
    with tf.variable_scope('foo'):
        tf.Variable(tf.zeros((5, 2)), trainable=True)
        with tf.variable_scope('bar'):
            tf.Variable(tf.zeros((1, 2)), trainable=True)
    counted = count_weights('bar')
    self.assertEqual(0, counted)
def test_trainable_and_non_trainable(self):
    """Expect only the trainable variables to contribute to the count."""
    # (shape, trainable) pairs, created in this order.
    specs = (
        ((5, 3), True),
        ((8, 2), False),
        ((1, 1), True),
        ((5,), True),
        ((3, 1), False),
    )
    for shape, is_trainable in specs:
        tf.Variable(tf.zeros(shape), trainable=is_trainable)
    self.assertEqual(15 + 1 + 5, count_weights())
def define_simulation_graph(batch_env, algo_cls, config): """Define the algortihm and environment interaction. Args: batch_env: In-graph environments object. algo_cls: Constructor of a batch algorithm. config: Configuration object for the algorithm. Returns: Object providing graph elements via attributes. """ # pylint: disable=unused-variable step = tf.Variable(0, False, dtype=tf.int32, name='global_step') is_training = tf.placeholder(tf.bool, name='is_training') should_log = tf.placeholder(tf.bool, name='should_log') do_report = tf.placeholder(tf.bool, name='do_report') force_reset = tf.placeholder(tf.bool, name='force_reset') # Extended is_optimizing_offense = tf.placeholder(tf.bool, name='is_optimizing_offense') algo = algo_cls(batch_env, step, is_training, should_log, config, is_optimizing_offense=is_optimizing_offense) done, score, summary, gail_summary = tools.simulate( batch_env, algo, should_log, force_reset) message = 'Graph contains {} trainable variables.' tf.logging.info(message.format(tools.count_weights())) # pylint: enable=unused-variable return tools.AttrDict(locals())
def define_simulation_graph(batch_env, algo_cls, config): """Define the algortihm and environment interaction. Simulate a single step in all environments. See more details in `tools.simulate`. Args: batch_env: In-graph environments object. algo_cls: Constructor of a batch algorithm. config: Configuration object for the algorithm. Returns: Object providing graph elements via attributes. """ # pylint: disable=unused-variable step = tf.Variable(0, False, dtype=tf.int32, name='global_step') is_training = tf.placeholder(tf.bool, name='is_training') should_log = tf.placeholder(tf.bool, name='should_log') do_report = tf.placeholder(tf.bool, name='do_report') force_reset = tf.placeholder(tf.bool, name='force_reset') algo = algo_cls(batch_env, step, is_training, should_log, config) should_step = tf.placeholder(tf.bool, name='should_step') use_external_action = tf.placeholder(tf.bool, name='use_external_action') external_action = tf.placeholder(batch_env.action_info[0], shape=batch_env.action_info[1], name='external_action') done, score, summary = tools.simulate(batch_env, algo, should_step, use_external_action, external_action, should_log, force_reset) message = 'Graph contains {} trainable variables.' tf.logging.info(message.format(tools.count_weights())) # pylint: enable=unused-variable return tools.AttrDict(locals())
def test_trainable_and_non_trainable(self):
    """Expect the count to equal the element total of the trainable variables.

    Trainable and non-trainable variables are interleaved; the expected count
    is the sum over the three trainable ones.
    """
    shapes_and_flags = [
        ((5, 3), True),
        ((8, 2), False),
        ((1, 1), True),
        ((5, ), True),
        ((3, 1), False),
    ]
    for dims, flag in shapes_and_flags:
        tf.Variable(tf.zeros(dims), trainable=flag)
    expected_total = 15 + 1 + 5
    self.assertEqual(expected_total, count_weights())
def testing(config, off_data, off_label, def_data, def_label, outdir):
    """Build the pretraining model and visualize results on train and eval data.

    The offense and defense arrays are each cut at 9/10 of the length of the
    corresponding data array; the first part is the train split and the rest
    is the eval split. `vis_result` is then called once per split, writing
    under `<outdir>/train` and `<outdir>/eval`.

    Args
    ----
    config : Object providing configurations via attributes.
    off_data : Offense data; split along its first axis.
    off_label : Offense labels; split at the same index as `off_data`.
    def_data : Defense data; split along its first axis.
    def_label : Defense labels; split at the same index as `def_data`.
    outdir : Directory under which the `train` and `eval` outputs go.

    Returns
    -------
    None

    Raises
    ------
    ValueError : If `FLAGS.config` is neither 'offense' nor 'defense'.
    """
    # Split into train and eval; labels reuse the index derived from the data.
    off_cut = [off_data.shape[0] * 9 // 10]
    off_train_data, off_eval_data = np.split(off_data, off_cut)
    off_train_label, off_eval_label = np.split(off_label, off_cut)
    def_cut = [def_data.shape[0] * 9 // 10]
    def_train_data, def_eval_data = np.split(def_data, def_cut)
    def_train_label, def_eval_label = np.split(def_label, def_cut)
    for array in (off_train_data, off_eval_data,
                  off_train_label, off_eval_label,
                  def_train_data, def_eval_data,
                  def_train_label, def_eval_label):
        print(array.shape)

    # Graph construction.
    tf.reset_default_graph()
    if FLAGS.config == 'offense':
        model_cls = pretrain_model.PretrainOffense
    elif FLAGS.config == 'defense':
        model_cls = pretrain_model.PretrainDefense
    else:
        raise ValueError('{} is not an available config'.format(FLAGS.config))
    model = model_cls(config)
    tf.logging.info(
        'Graph contains {} trainable variables.'.format(tools.count_weights()))
    saver = utility.define_saver()

    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=config.log_device_placement)
    sess_config.gpu_options.allow_growth = True

    with tf.Session(config=sess_config) as sess:
        if FLAGS.debug:
            sess = tf_debug.LocalCLIDebugWrapperSession(
                sess, ui_type=FLAGS.ui_type)
        utility.initialize_variables(sess, saver, config.logdir, resume=True)
        splits = (
            ('train', off_train_data, off_train_label,
             def_train_data, def_train_label),
            ('eval', off_eval_data, off_eval_label,
             def_eval_data, def_eval_label),
        )
        for split_name, o_data, o_label, d_data, d_label in splits:
            vis_result(sess, model, o_data, o_label, d_data, d_label,
                       os.path.join(outdir, split_name), 3)
def define_simulation_graph(batch_env, algo_cls, config): """Define the algorithm and environment interaction. Args: batch_env: In-graph environments object. algo_cls: Constructor of a batch algorithm. config: Configuration object for the algorithm. Returns: Object providing graph elements via attributes. """ # pylint: disable=unused-variable step = tf.Variable(0, False, dtype=tf.int32, name='global_step') is_training = tf.placeholder(tf.bool, name='is_training') should_log = tf.placeholder(tf.bool, name='should_log') do_report = tf.placeholder(tf.bool, name='do_report') force_reset = tf.placeholder(tf.bool, name='force_reset') algo = algo_cls(batch_env, step, is_training, should_log, config) done, score, summary = tools.simulate( batch_env, algo, should_log, force_reset) message = 'Graph contains {} trainable variables.' tf.logging.info(message.format(tools.count_weights())) # pylint: enable=unused-variable return tools.AttrDict(locals())
def test_include_scopes(self):
    """Expect variables inside a variable scope to be counted by default.

    One variable lives at the top level (6 elements) and one inside scope
    `foo` (10 elements); the expected count is their sum.
    """
    tf.Variable(tf.zeros((3, 2)), trainable=True)
    with tf.variable_scope('foo'):
        tf.Variable(tf.zeros((5, 2)), trainable=True)
    expected_total = 6 + 10
    self.assertEqual(expected_total, count_weights())
def test_ignore_non_trainable(self):
    """Expect a count of 0 when every variable is non-trainable."""
    for shape in ((5, 3), (1, 1), (5, )):
        tf.Variable(tf.zeros(shape), trainable=False)
    self.assertEqual(0, count_weights())
def test_count_trainable(self):
    """Expect the count to be the summed element count of all variables.

    All three variables are trainable, with 15, 1 and 5 elements.
    """
    for shape in ((5, 3), (1, 1), (5, )):
        tf.Variable(tf.zeros(shape), trainable=True)
    expected_total = 15 + 1 + 5
    self.assertEqual(expected_total, count_weights())
def test_ignore_non_trainable(self):
    """Expect `count_weights()` to return 0 for non-trainable variables only."""
    frozen_shapes = [(5, 3), (1, 1), (5,)]
    for dims in frozen_shapes:
        tf.Variable(tf.zeros(dims), trainable=False)
    counted = count_weights()
    self.assertEqual(0, counted)
def train(config, data, label, outdir):
    """Train the pretraining model, evaluating after every epoch.

    The data is rescaled to [-1, 1] using the bounds of the observation
    space of `BBallPretrainEnv`, then cut at 9/10 of its length into a train
    split and an eval split (labels are cut at the same index). For each of
    `config.num_epochs` epochs, `training` and `evaluating` are called, and
    a checkpoint is saved to `<config.logdir>/model.ckpt` every
    `config.checkpoint_every` epochs.

    Args
    ----
    config : Object providing configurations via attributes.
    data : Input data; presumably a NumPy array whose first axis indexes
        samples (TODO confirm against the caller).
    label : Labels for `data`; split at the same index as `data`.
    outdir : Not used by this function.

    Returns
    -------
    None

    Raises
    ------
    ValueError : If `FLAGS.config` is neither 'offense' nor 'defense'.
    """
    # Normalization: map [low, high] of the observation space onto [-1, 1].
    env = BBallPretrainEnv()
    min_ = env.observation_space.low
    max_ = env.observation_space.high
    data = 2 * (data - min_) / (max_ - min_) - 1

    # Split into train and eval; labels reuse the index derived from the data.
    split_index = [data.shape[0] * 9 // 10]
    train_data, eval_data = np.split(data, split_index)
    train_label, eval_label = np.split(label, split_index)
    print(train_data.shape)
    print(train_label.shape)
    print(eval_data.shape)
    print(eval_label.shape)

    # Graph construction.
    tf.reset_default_graph()
    if FLAGS.config == 'offense':
        model = pretrain_model.PretrainOffense(config)
    elif FLAGS.config == 'defense':
        model = pretrain_model.PretrainDefense(config)
    else:
        raise ValueError('{} is not an available config'.format(FLAGS.config))
    message = 'Graph contains {} trainable variables.'
    tf.logging.info(message.format(tools.count_weights()))
    saver = utility.define_saver()
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=config.log_device_placement)
    sess_config.gpu_options.allow_growth = True

    # Separate summary writers for the train and eval phases.
    train_writter = tf.summary.FileWriter(
        os.path.join(config.logdir, 'train'), tf.get_default_graph())
    eval_writter = tf.summary.FileWriter(
        os.path.join(config.logdir, 'eval'), tf.get_default_graph())
    try:
        with tf.Session(config=sess_config) as sess:
            if FLAGS.debug:
                sess = tf_debug.LocalCLIDebugWrapperSession(
                    sess, ui_type=FLAGS.ui_type)
            utility.initialize_variables(
                sess, saver, config.logdir, resume=FLAGS.resume)
            for epoch_idx in range(config.num_epochs):
                tf.logging.info('Number of epochs: {}'.format(epoch_idx))
                training(sess, model, train_data, train_label,
                         config, train_writter)
                evaluating(sess, model, eval_data, eval_label,
                           config, eval_writter)
                if (epoch_idx + 1) % config.checkpoint_every == 0:
                    tf.gfile.MakeDirs(config.logdir)
                    filename = os.path.join(config.logdir, 'model.ckpt')
                    saver.save(sess, filename,
                               (epoch_idx + 1) * config.batch_size)
    finally:
        # Close the writers even when training raises, so buffered summary
        # events are flushed and the event files are released.
        train_writter.close()
        eval_writter.close()
def test_count_trainable(self):
    """Expect `count_weights()` to total the elements of trainable variables."""
    trainable_shapes = [(5, 3), (1, 1), (5,)]
    for dims in trainable_shapes:
        tf.Variable(tf.zeros(dims), trainable=True)
    counted = count_weights()
    self.assertEqual(15 + 1 + 5, counted)