Ejemplo n.º 1
0
    def __build_rl_helper_n_agent(self, sess):
        """Create the RL helper and the DDPG controller / agent.

        Args:
        * sess: TensorFlow session shared by the helper and the agent

        Returns:
        * rl_helper: RLHelper built over self.vars_full['maskable']
        * agent: DdpgAgent with a single scalar action bounded to [0.0, 1.0]
        """

        maskable_vars = self.vars_full['maskable']

        # head & tail layers are skipped only when running on CIFAR-10
        helper = RLHelper(sess, maskable_vars, self.dataset_name == 'cifar_10')

        # state size comes from the helper; one action in [0.0, 1.0];
        # buffer size = (#maskable variables) * FLAGS.ws_nb_rlouts_min
        agent = DdpgAgent(
            sess,
            helper.s_dims,
            1,
            FLAGS.ws_nb_rlouts,
            len(maskable_vars) * FLAGS.ws_nb_rlouts_min,
            0.0,
            1.0)

        return helper, agent
Ejemplo n.º 2
0
def build_env_n_agent(sess):
  """Create the environment together with a DDPG agent that acts on it.

  Args:
  * sess: TensorFlow session handed to the agent

  Returns:
  * env: a freshly constructed Env instance
  * agent: DdpgAgent whose state and action sizes are both FLAGS.nb_dims,
      with actions bounded to [-1.0, 1.0]
  """

  env = Env()

  # states and actions share the same dimensionality
  nb_dims = FLAGS.nb_dims
  nb_rlouts = FLAGS.nb_rlouts
  agent = DdpgAgent(
      sess,
      nb_dims,                                   # s_dims
      nb_dims,                                   # a_dims
      nb_rlouts,
      int(FLAGS.rlout_len * nb_rlouts * 0.25),   # buf_size: 25% of rlout_len * nb_rlouts
      -1.0,                                      # a_lbnd
      1.0)                                       # a_ubnd

  return env, agent
Ejemplo n.º 3
0
def build_env_n_agent(sess):
  """Build the Pendulum-v0 gym environment and a DDPG agent to solve it.

  Args:
  * sess: TensorFlow session passed to the agent

  Returns:
  * env: environment created by gym.make('Pendulum-v0')
  * agent: DdpgAgent sized from the environment's observation / action spaces

  Note: only the first component of the action-space bounds (low[0] /
  high[0]) is forwarded to the agent, so this presumes every action
  dimension shares the same bounds — confirm before reusing with an
  environment whose action space has per-dimension bounds.
  """

  env = gym.make('Pendulum-v0')
  s_dims = env.observation_space.shape[-1]
  a_dims = env.action_space.shape[-1]
  nb_rlouts = FLAGS.nb_rlouts
  buf_size = int(FLAGS.rlout_len * nb_rlouts * 0.25)
  # scalar bounds taken from the first action dimension (see Note above)
  a_lbnd = env.action_space.low[0]
  a_ubnd = env.action_space.high[0]
  agent = DdpgAgent(sess, s_dims, a_dims, nb_rlouts, buf_size, a_lbnd, a_ubnd)

  # pass values as logger arguments so formatting is deferred to the logger
  tf.logging.info('s_dims = %d, a_dims = %d', s_dims, a_dims)
  tf.logging.info('a_lbnd = %f, a_ubnd = %f', a_lbnd, a_ubnd)

  return env, agent
Ejemplo n.º 4
0
  def __init__(self,
               dataset_name,
               weights,
               statistics,
               bit_placeholders,
               ops,
               layerwise_tune_list,
               sess_train,
               sess_eval,
               saver_train,
               saver_eval,
               barrier_fn):
    """Store pre-built graph handles and set up the RL bit-allocation agent.

    Because the ops are passed in, the training / testing graphs do not need
    to be rebuilt here.

    Args:
    * dataset_name: a string that indicates which dataset to use
    * weights: a list of Tensors, the weights of the network to quantize
    * statistics: a dict of statistics (e.g. number of weights, activations);
        only its 'num_weights' entry is read here
    * bit_placeholders: a dict of placeholder Tensors that feed the bit-widths
    * ops: a dict of ops (e.g. train_op, eval_op)
    * layerwise_tune_list: a sequence in which [0] holds the layerwise op and
        [1] holds the layerwise l2_norm
    * sess_train: session used for training
    * sess_eval: session used for evaluation
    * saver_train: TensorFlow Saver for the training graph
    * saver_eval: TensorFlow Saver for the eval graph
    * barrier_fn: a callable implementing a barrier (kept as self.auto_barrier)
    """
    # keep references to everything handed in by the caller
    self.dataset_name = dataset_name
    self.weights = weights
    self.statistics = statistics
    self.bit_placeholders = bit_placeholders
    self.ops = ops
    self.layerwise_tune_ops = layerwise_tune_list[0]
    self.layerwise_diff = layerwise_tune_list[1]
    self.sess_train = sess_train
    self.sess_eval = sess_eval
    self.saver_train = saver_train
    self.saver_eval = saver_eval
    self.auto_barrier = barrier_fn

    # total bits = total number of weights * FLAGS.uql_equivalent_bits
    num_weights = self.statistics['num_weights']
    self.total_num_weights = sum(num_weights)
    self.total_bits = self.total_num_weights * FLAGS.uql_equivalent_bits

    self.w_rl_helper = RLHelper(self.sess_train,
                                self.total_bits,
                                num_weights,
                                self.weights,
                                random_layers=FLAGS.uql_enbl_random_layers)

    # global step counts are divided by the number of workers (1 on a single GPU)
    if FLAGS.enbl_multi_gpu:
      self.mgw_size = int(mgw.size())
    else:
      self.mgw_size = 1
    self.tune_global_steps = int(FLAGS.uql_tune_global_steps / self.mgw_size)
    self.tune_global_disp_steps = int(FLAGS.uql_tune_disp_steps / self.mgw_size)

    # the RL agent lives in its own graph and session, with the visible GPU
    # set to this worker's local rank (or 0 without multi-GPU)
    with tf.Graph().as_default():
      config = tf.ConfigProto()
      device_id = mgw.local_rank() if FLAGS.enbl_multi_gpu else 0
      config.gpu_options.visible_device_list = str(device_id)
      self.sess_rl = tf.Session(config=config)

      # train an RL agent through multiple roll-outs; one scalar action in
      # [0, uql_w_bit_max - uql_w_bit_min]
      self.s_dims = self.w_rl_helper.s_dims
      self.a_dims = 1
      buf_size = len(self.weights) * int(FLAGS.uql_nb_rlouts // 4)
      self.agent = DdpgAgent(self.sess_rl,
                             self.s_dims,
                             self.a_dims,
                             FLAGS.uql_nb_rlouts,
                             buf_size,
                             a_min=0.,
                             a_max=FLAGS.uql_w_bit_max - FLAGS.uql_w_bit_min)
Ejemplo n.º 5
0
  def __prune_rl(self): # pylint: disable=too-many-locals
    """Search for a channel-pruning strategy with reinforcement learning.

    A DDPG agent is rolled out FLAGS.cp_nb_rlouts times. In each roll-out a
    fresh pruner is created and compressed step by step using the agent's
    noisy actions; the ratios actually applied by the pruner are recorded.
    A single reward is computed from the returned acc_flops pair via
    self.__calc_reward, and every step's transition is recorded into the
    agent. The best-rewarded roll-out is stored in self.bestinfo as
    [strategy, acc_flops[0], acc_flops[1]] (logged as accuracy and pruned
    ratio).
    """
    tf.logging.info(
      'preserve lower bound: {}, preserve ratio: {}, preserve upper bound: {}'.format(
        self.lbound, FLAGS.cp_preserve_ratio, self.rbound))

    # the agent gets its own session with only device 0 visible
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(0) # pylint: disable=no-member
    buf_size = len(self.pruner.states) * FLAGS.cp_nb_rlouts_min
    nb_rlouts = FLAGS.cp_nb_rlouts
    # state size = length of one pruner state row; one scalar action;
    # self.lbound / self.rbound are passed as the action lower / upper bounds
    self.agent = DdpgAgent(
      tf.Session(config=config),
      len(self.pruner.states.loc[0].tolist()),
      1,
      nb_rlouts,
      buf_size,
      self.lbound,
      self.rbound)
    self.agent.init()
    self.bestinfo = None
    # np.NINF was removed in NumPy 2.0; -np.inf works on every version
    reward_best = -np.inf

    for idx_rlout in range(nb_rlouts):
      # execute roll-outs to obtain pruning ratios
      self.agent.init_rlout()
      states_n_actions = []
      self.create_pruner()
      self.pruner.initialize_state()
      self.pruner.extract_features()
      state = np.array(self.pruner.currentStates.loc[0].tolist())[None, :]

      start = timer()
      while True:
        tf.logging.info('state is {}'.format(state))
        action = self.agent.sess.run(self.agent.actions_noisy, feed_dict={self.agent.states: state})
        tf.logging.info('RL choosed preserv ratio: {}'.format(action))
        state_next, acc_flops, done, real_action = self.pruner.compress(action)
        tf.logging.info('Actural preserv ratio: {}'.format(real_action))
        # store the ratio the pruner actually applied, not the raw action
        states_n_actions += [(state, real_action * np.ones((1, 1)))]
        state = state_next[None, :]
        actor_loss, critic_loss, noise_std = self.agent.train()
        if done:
          break
      tf.logging.info('roll-out #%d: a-loss = %.2e | c-loss = %.2e | noise std. = %.2e'
                      % (idx_rlout, actor_loss, critic_loss, noise_std))

      reward = self.__calc_reward(acc_flops[0], acc_flops[1])

      rewards = reward * np.ones(len(self.pruner.states))
      self.agent.finalize_rlout(rewards)

      # record transitions for RL training; the last step is terminal and
      # gets an all-zero next state
      strategy = []
      idx_last = len(states_n_actions) - 1
      for idx, (step_state, step_action) in enumerate(states_n_actions):
        strategy.append(step_action[0, 0])
        if idx != idx_last:
          terminal = np.zeros((1, 1))
          step_state_next = states_n_actions[idx + 1][0]
        else:
          terminal = np.ones((1, 1))
          step_state_next = np.zeros_like(step_state)
        self.agent.record(step_state, step_action, reward, terminal, step_state_next)

      # record the best combination of pruning ratios
      if reward_best < reward:
        tf.logging.info('best reward updated: %.4f -> %.4f' % (reward_best, reward))
        reward_best = reward
        self.bestinfo = [strategy, acc_flops[0], acc_flops[1]]
        tf.logging.info("""The best pruned model occured with
                strategy: {},
                accuracy: {} and
                pruned ratio: {}""".format(self.bestinfo[0], self.bestinfo[1], self.bestinfo[2]))

      tf.logging.info('automatic channl pruning time cost: {}s'.format(timer() - start))