Example #1
def define_simulation_graph(batch_env, algo_cls, config):
    """Define the algortihm and environment interaction.
  Simulate a single step in all environments. See more details in `tools.simulate`.

  Args:
    batch_env: In-graph environments object.
    algo_cls: Constructor of a batch algorithm.
    config: Configuration object for the algorithm.

  Returns:
    Object providing graph elements via attributes.
  """
    # pylint: disable=unused-variable
    step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')
    should_log = tf.placeholder(tf.bool, name='should_log')
    do_report = tf.placeholder(tf.bool, name='do_report')
    force_reset = tf.placeholder(tf.bool, name='force_reset')
    algo = algo_cls(batch_env, step, is_training, should_log, config)
    should_step = tf.placeholder(tf.bool, name='should_step')
    use_external_action = tf.placeholder(tf.bool, name='use_external_action')
    external_action = tf.placeholder(batch_env.action_info[0],
                                     shape=batch_env.action_info[1],
                                     name='external_action')
    done, score, summary = tools.simulate(batch_env, algo, should_step,
                                          use_external_action, external_action,
                                          should_log, force_reset)
    message = 'Graph contains {} trainable variables.'
    tf.logging.info(message.format(tools.count_weights()))
    # pylint: enable=unused-variable
    return tools.AttrDict(locals())
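
Because the function returns tools.AttrDict(locals()), every local name above (the placeholders, the algo object, and the done/score/summary tensors) is reachable as an attribute of the result. A minimal usage sketch, assuming batch_env, algo_cls, and config are already constructed; the feed values are purely illustrative:

import numpy as np
import tensorflow as tf

graph = define_simulation_graph(batch_env, algo_cls, config)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Run one simulation step without logging or reporting.
    done, score = sess.run(
        [graph.done, graph.score],
        feed_dict={
            graph.is_training: True,
            graph.should_log: False,
            graph.do_report: False,
            graph.force_reset: True,
            graph.should_step: True,
            graph.use_external_action: False,
            # Illustrative zero action; dtype/shape must match batch_env.action_info.
            graph.external_action: np.zeros(batch_env.action_info[1]),
        })
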
Example #2
def main(_):
    """ Create or load configuration and launch the trainer.
    """
    utility.set_up_logging()
    if not FLAGS.resume:
        logdir = FLAGS.logdir and os.path.expanduser(
            os.path.join(FLAGS.logdir, '{}-{}'.format(FLAGS.timestamp,
                                                      FLAGS.config)))
    else:
        logdir = FLAGS.logdir
    if FLAGS.vis:
        outdir = os.path.join(logdir, 'train_output')
    else:
        outdir = None

    try:
        config = utility.load_config(logdir)
    except IOError:
        if not FLAGS.config:
            raise KeyError('You must specify a configuration.')
        config = tools.AttrDict(getattr(configs, FLAGS.config)())
        config = utility.save_config(config, logdir)

    for score in train(config, FLAGS.env_processes, outdir):
        tf.logging.info('Score {}.'.format(score))
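
This main() relies on several command-line flags (logdir, timestamp, config, resume, vis, plus env_processes passed to train). Their definitions are not shown in the snippet; a plausible sketch using the TF1 flags module, with illustrative defaults, would be:

import datetime
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
    'logdir', None, 'Base directory for logs and checkpoints.')
tf.app.flags.DEFINE_string(
    'timestamp', datetime.datetime.now().strftime('%Y%m%dT%H%M%S'),
    'Timestamp used to name the run directory.')
tf.app.flags.DEFINE_string(
    'config', None, 'Name of a configuration function defined in configs.')
tf.app.flags.DEFINE_boolean(
    'resume', False, 'Whether to resume from an existing log directory.')
tf.app.flags.DEFINE_boolean(
    'vis', False, 'Whether to write environment output to train_output.')
tf.app.flags.DEFINE_boolean(
    'env_processes', True, 'Whether to step environments in separate processes.')

if __name__ == '__main__':
    tf.app.run()
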
Example #3
def main(_):
    """ Create or load configuration and launch the trainer.
    """
    if FLAGS.config == 'offense':
        data = np.load('bball_strategies/pretrain/data/off_obs.npy')
        label = np.load('bball_strategies/pretrain/data/off_actions.npy')
    elif FLAGS.config == 'defense':
        data = np.load('bball_strategies/pretrain/data/def_obs.npy')
        label = np.load('bball_strategies/pretrain/data/def_actions.npy')
    else:
        raise ValueError('{} is not an available config'.format(FLAGS.config))
    utility.set_up_logging()
    if not FLAGS.resume:
        logdir = FLAGS.logdir and os.path.expanduser(
            os.path.join(FLAGS.logdir, '{}-{}'.format(FLAGS.timestamp,
                                                      FLAGS.config)))
    else:
        logdir = FLAGS.logdir
    if FLAGS.vis:
        outdir = os.path.join(logdir, 'train_output')
    else:
        outdir = None
    try:
        config = utility.load_config(logdir)
    except IOError:
        if not FLAGS.config:
            raise KeyError('You must specify a configuration.')
        config = tools.AttrDict(getattr(configs, FLAGS.config)())
        config = utility.save_config(config, logdir)
    train(config, data, label, outdir)
Example #4
def define_simulation_graph(batch_env, algo_cls, config):
    """Define the algortihm and environment interaction.

    Args:
      batch_env: In-graph environments object.
      algo_cls: Constructor of a batch algorithm.
      config: Configuration object for the algorithm.

    Returns:
      Object providing graph elements via attributes.
    """
    # pylint: disable=unused-variable
    step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')
    should_log = tf.placeholder(tf.bool, name='should_log')
    do_report = tf.placeholder(tf.bool, name='do_report')
    force_reset = tf.placeholder(tf.bool, name='force_reset')
    # Extended
    is_optimizing_offense = tf.placeholder(tf.bool,
                                           name='is_optimizing_offense')
    algo = algo_cls(batch_env,
                    step,
                    is_training,
                    should_log,
                    config,
                    is_optimizing_offense=is_optimizing_offense)
    done, score, summary, gail_summary = tools.simulate(
        batch_env, algo, should_log, force_reset)
    message = 'Graph contains {} trainable variables.'
    tf.logging.info(message.format(tools.count_weights()))
    # pylint: enable=unused-variable
    return tools.AttrDict(locals())
Example #5
def test_samples():
    env = gym.make('SpaceInvaders-v0')
    config = tools.AttrDict(default_config())

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        init.run()
        buffer = initialize_memory(sess, env, config)
        batch_transition = buffer.sample(config.batch_size)

        assert batch_transition.action.shape == (config.batch_size, )
        assert batch_transition.observ.shape == (32, 84, 84, 4)
Example #6
    def test_initialize(self):
        env = gym.make('SpaceInvaders-v0')

        _config = tools.AttrDict(default_config())

        _structure = functools.partial(_config.network, _config,
                                       env.observation_space, env.action_space)
        _network = tf.make_template('network', _structure)
        _target = tf.make_template('target', _structure)
        network = _network()
        target = _target(network.picked_action)
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
Example #7
def main():
    """ test
    """
    from bball_strategies.scripts.gail import configs
    from agents import tools
    import gym
    config = tools.AttrDict(configs.default())
    with config.unlocked:
        config.logdir = 'test'
    env = gym.make(config.env)
    D1 = Discriminator(config, env)
    D2 = Discriminator()
    dummy = tf.ones(shape=[128, 10, 14, 2])
    r = D2.get_rewards(dummy)
    print(r)
Example #8
    def _define_config(self):
        # Start from the example configuration.
        locals().update(configs.default())
        # pylint: disable=unused-variable
        # General
        algorithm = algorithms.PPO
        num_agents = 2
        update_every = 4
        use_gpu = False
        # Network
        policy_layers = 20, 10
        value_layers = 20, 10
        # Optimization
        update_epochs_policy = 2
        update_epochs_value = 2
        # pylint: enable=unused-variable
        return tools.AttrDict(locals())
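
Since the method returns tools.AttrDict(locals()), each assigned name becomes an attribute of the configuration. A small illustrative use from another method of the same (not shown) class, with values mirroring the assignments above:

config = self._define_config()
assert config.algorithm is algorithms.PPO
assert config.num_agents == 2
assert config.policy_layers == (20, 10)
assert config.update_epochs_value == 2
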
Example #9
def main(_):
    """Create or load configuration and launch the trainer."""
    FLAGS.logdir = '../../Log'
    FLAGS.config = 'pendulum'
    FLAGS.env_processes = False
    utility.set_up_logging()
    if not FLAGS.config:
        raise KeyError('You must specify a configuration.')
    logdir = FLAGS.logdir and os.path.expanduser(
        os.path.join(FLAGS.logdir, '{}-{}'.format(FLAGS.timestamp,
                                                  FLAGS.config)))
    try:
        config = utility.load_config(logdir)
    except IOError:
        config = tools.AttrDict(getattr(configs, FLAGS.config)())
        config = utility.save_config(config, logdir)
    global globalConfig
    globalConfig = config
    for score in train(config, FLAGS.env_processes):
        tf.logging.info('Score {}.'.format(score))
Example #10
def main(_):
    """ Create or load configuration and launch the trainer.
    """
    utility.set_up_logging()
    if FLAGS.resume:
        logdir = FLAGS.logdir
    else:
        logdir = FLAGS.logdir and os.path.expanduser(os.path.join(
            FLAGS.logdir, '{}-{}'.format(FLAGS.timestamp, FLAGS.config)))
    if FLAGS.vis:
        outdir = os.path.join(logdir, 'train_output')
    else:
        outdir = None

    if not FLAGS.config:
        raise KeyError('You must specify a configuration.')
    config = tools.AttrDict(getattr(configs, FLAGS.config)())
    config = utility.save_config(config, logdir)

    # collecting
    testing(config, FLAGS.env_processes, outdir)
Example #11
def main(_):
    """ Create or load configuration and launch the trainer.
    """
    off_data = np.load('bball_strategies/pretrain/data/off_obs.npy')
    off_label = np.load('bball_strategies/pretrain/data/off_actions.npy')
    def_data = np.load('bball_strategies/pretrain/data/def_obs.npy')
    def_label = np.load('bball_strategies/pretrain/data/def_actions.npy')

    utility.set_up_logging()

    logdir = FLAGS.logdir
    try:
        config = utility.load_config(logdir)
    except IOError:
        if not FLAGS.config:
            raise KeyError('You must specify a configuration.')
        config = tools.AttrDict(getattr(configs, FLAGS.config)())
        config = utility.save_config(config, logdir)
    outdir = os.path.expanduser(os.path.join(FLAGS.logdir, 'vis'))

    vis_data(off_data, off_label, def_data, def_label, outdir, start_idx=0)
    testing(config, off_data, off_label, def_data, def_label, outdir)
Example #12
def main(_):
    env = gym.make('SpaceInvaders-v0')
    env = wrap_deepmind(env)

    atari_actions = np.arange(env.action_space.n, dtype=np.int32)

    _config = tools.AttrDict(default_config())

    # Initialize networks.
    with tf.variable_scope('q_network'):
        q_network = ValueFunction(_config, env.observation_space,
                                  env.action_space)
    with tf.variable_scope('target'):
        target = ValueFunction(_config, env.observation_space,
                               env.action_space, q_network)
    # Initialize global step so tf.train.get_global_step() below returns a variable.
    tf.train.create_global_step()
    # Epsilon
    eps = np.linspace(_config.epsilon_start, _config.epsilon_end,
                      _config.epsilon_decay_steps)

    sess = make_session()
    initialize_variables(sess)
    saver, checkpoint_path = make_saver(sess)

    # Initialize memory
    memory = initialize_memory(sess, env, _config)
    # Initialize policy
    policy = eps_greedy_policy(q_network, env.action_space.n)

    total_step = sess.run(tf.train.get_global_step())
    print('total_step', total_step)

    for episode in range(_config.num_episodes):
        observ = env.reset()
        observ = atari_preprocess(sess, observ)
        observ = np.stack([observ] * 4, axis=2)
        for t in itertools.count():
            action_prob = policy(
                sess, observ, eps[min(total_step,
                                      _config.epsilon_decay_steps - 1)])
            action = np.random.choice(atari_actions, size=1, p=action_prob)[0]
            next_observ, reward, terminal, _ = env.step(action)
            # next_observ = atari_preprocess(sess, next_observ)
            next_observ = np.concatenate(
                [observ[..., 1:], next_observ[..., None]], axis=2)
            memory.append(
                transition(observ, reward, terminal, next_observ, action))

            batch_transition = memory.sample(_config.batch_size)
            best_actions = q_network.best_action(sess,
                                                 batch_transition.next_observ)
            target_values = target.estimate(sess, batch_transition.reward,
                                            batch_transition.terminal,
                                            batch_transition.next_observ,
                                            best_actions)

            loss = q_network.update_step(sess, batch_transition.observ,
                                         batch_transition.action,
                                         target_values)
            print('\r({}/{}) loss: {}'.format(total_step,
                                              _config.max_total_step_size,
                                              loss),
                  end='',
                  flush=True)

            if total_step % _config.update_target_estimator_every == 0:
                print('\nUpdate Target Network...')
                target.assign(sess)

            if terminal:
                break

            total_step += 1
        saver.save(sess,
                   checkpoint_path,
                   global_step=tf.train.get_global_step())
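
The replay memory above appends transition(observ, reward, terminal, next_observ, action) records and later samples batches whose fields are accessed by name. The definition is not part of this snippet; a minimal sketch compatible with that usage could be a namedtuple whose field order matches the positional call:

import collections

# Field order matches transition(observ, reward, terminal, next_observ, action) above.
transition = collections.namedtuple(
    'transition', ['observ', 'reward', 'terminal', 'next_observ', 'action'])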