예제 #1
0
class FullPrecLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
  """Full-precision learner (no model compression applied).

  Two independent TensorFlow graphs are built, each with its own session: one
  for training (momentum optimizer, optionally plus a distillation loss) and
  one for evaluation. Checkpoints written by the training graph are restored
  into the evaluation graph before evaluating.
  """

  def __init__(self, sm_writer, model_helper, model_scope=None, enbl_dst=None):
    """Constructor function.

    Args:
    * sm_writer: TensorFlow's summary writer
    * model_helper: model helper with definitions of model & dataset
    * model_scope: name scope in which to define the model; if None,
        self.model_scope is left unchanged (presumably set by AbstractLearner)
    * enbl_dst: whether to create a model with distillation loss; if None,
        FLAGS.enbl_dst is used
    """

    # class-independent initialization
    super(FullPrecLearner, self).__init__(sm_writer, model_helper)

    # over-ride the model scope and distillation loss switch
    if model_scope is not None:
      self.model_scope = model_scope
    self.enbl_dst = enbl_dst if enbl_dst is not None else FLAGS.enbl_dst

    # class-dependent initialization
    # (the distillation helper must exist before __build(), which uses it)
    if self.enbl_dst:
      self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
    self.__build(is_train=True)
    self.__build(is_train=False)

  def train(self):
    """Train a model and periodically produce checkpoint files.

    Every FLAGS.summ_step iterations, summaries and logging metrics are fetched
    together with the training op. Every FLAGS.save_step iterations, the
    primary worker saves a training checkpoint and evaluates it. After the
    loop, the primary worker saves the final training checkpoint, loads it
    into the evaluation graph, saves that as the evaluation checkpoint, and
    evaluates it.
    """

    # initialization
    self.sess_train.run(self.init_op)
    self.warm_start(self.sess_train)
    # copy worker 0's variables to all workers so they start from identical weights
    if FLAGS.enbl_multi_gpu:
      self.sess_train.run(self.bcast_op)

    # train the model through iterations and periodically save & evaluate the model
    time_prev = timer()
    for idx_iter in range(self.nb_iters_train):
      # train the model
      if (idx_iter + 1) % FLAGS.summ_step != 0:
        self.sess_train.run(self.train_op)
      else:
        # summary step: also fetch summaries and logged scalars; only the primary
        # worker reports them, timing the interval since its previous report
        __, summary, log_rslt = self.sess_train.run([self.train_op, self.summary_op, self.log_op])
        if self.is_primary_worker('global'):
          time_step = timer() - time_prev
          self.__monitor_progress(summary, log_rslt, idx_iter, time_step)
          time_prev = timer()

      # save & evaluate the model at certain steps
      if self.is_primary_worker('global') and (idx_iter + 1) % FLAGS.save_step == 0:
        self.__save_model(is_train=True)
        self.evaluate()

    # save the final model
    # (the latest training checkpoint is restored into the evaluation session,
    # then re-saved under FLAGS.save_path_eval)
    if self.is_primary_worker('global'):
      self.__save_model(is_train=True)
      self.__restore_model(is_train=False)
      self.__save_model(is_train=False)
      self.evaluate()

  def evaluate(self):
    """Restore a model from the latest checkpoint files and then evaluate it.

    Runs ceil(FLAGS.nb_smpls_eval / FLAGS.batch_size_eval) evaluation batches,
    passes each batch's outputs to dump_n_eval(), and logs the mean of every
    evaluation op over all batches (each batch weighted equally).
    """

    self.__restore_model(is_train=False)
    nb_iters = int(np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
    eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
    self.dump_n_eval(outputs=None, action='init')
    for idx_iter in range(nb_iters):
      eval_rslts[idx_iter], outputs = self.sess_eval.run([self.eval_op, self.outputs_eval])
      self.dump_n_eval(outputs=outputs, action='dump')
    self.dump_n_eval(outputs=None, action='eval')
    for idx, name in enumerate(self.eval_op_names):
      tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

  def __build(self, is_train):  # pylint: disable=too-many-locals
    """Build the training / evaluation graph.

    Each call creates a fresh tf.Graph and tf.Session. The training build sets
    self.sess_train, self.global_step, self.nb_iters_train, self.train_op,
    self.summary_op, self.log_op(_names), self.init_op, self.saver_train and,
    under multi-GPU, self.bcast_op. The evaluation build sets self.sess_eval,
    self.eval_op(_names), self.outputs_eval and self.saver_eval.

    Args:
    * is_train: whether to create the training graph
    """

    with tf.Graph().as_default():
      # TensorFlow session
      # (pinned to the local rank's GPU under multi-GPU, otherwise to GPU 0)
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      sess = tf.Session(config=config)

      # data input pipeline
      # NOTE(review): 'images_final' / 'logits_final' collections are presumably
      # consumed by model-export code elsewhere - TODO confirm
      with tf.variable_scope(self.data_scope):
        iterator = self.build_dataset_train() if is_train else self.build_dataset_eval()
        images, labels = iterator.get_next()
        tf.add_to_collection('images_final', images)

      # model definition - distilled model
      # (built outside self.model_scope; presumably so the teacher's variables are
      # not treated as part of this model - TODO confirm)
      if self.enbl_dst:
        logits_dst = self.helper_dst.calc_logits(sess, images)

      # model definition - primary model
      with tf.variable_scope(self.model_scope):
        # forward pass
        logits = self.forward_train(images) if is_train else self.forward_eval(images)
        tf.add_to_collection('logits_final', logits)

        # loss & extra evalution metrics
        # (self.trainable_vars / self.vars / self.update_ops come from AbstractLearner;
        # presumably they are collected under self.model_scope in this graph)
        loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
        if self.enbl_dst:
          loss += self.helper_dst.calc_loss(logits, logits_dst)
        tf.summary.scalar('loss', loss)
        for key, value in metrics.items():
          tf.summary.scalar(key, value)

        # optimizer & gradients
        # (lrn_rate / optimizer / grads exist only for the training graph and are
        # only used in the is_train branch below)
        if is_train:
          self.global_step = tf.train.get_or_create_global_step()
          lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)
          optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
          if FLAGS.enbl_multi_gpu:
            optimizer = mgw.DistributedOptimizer(optimizer)
          grads = optimizer.compute_gradients(loss, self.trainable_vars)

      # TF operations & model saver
      if is_train:
        self.sess_train = sess
        # run update ops (e.g. batch-norm statistics, if any) before each update step
        with tf.control_dependencies(self.update_ops):
          self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
        self.summary_op = tf.summary.merge_all()
        self.log_op = [lrn_rate, loss] + list(metrics.values())
        self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
        self.init_op = tf.variables_initializer(self.vars)
        if FLAGS.enbl_multi_gpu:
          self.bcast_op = mgw.broadcast_global_variables(0)
        self.saver_train = tf.train.Saver(self.vars)
      else:
        self.sess_eval = sess
        self.eval_op = [loss] + list(metrics.values())
        self.eval_op_names = ['loss'] + list(metrics.keys())
        self.outputs_eval = logits
        self.saver_eval = tf.train.Saver(self.vars)

  def __save_model(self, is_train):
    """Save the model to checkpoint files for training or evaluation.

    Training checkpoints go to FLAGS.save_path (suffixed with the global step);
    evaluation checkpoints go to FLAGS.save_path_eval.

    Args:
    * is_train: whether to save a model for training
    """

    if is_train:
      save_path = self.saver_train.save(self.sess_train, FLAGS.save_path, self.global_step)
    else:
      save_path = self.saver_eval.save(self.sess_eval, FLAGS.save_path_eval)
    tf.logging.info('model saved to ' + save_path)

  def __restore_model(self, is_train):
    """Restore a model from the latest checkpoint files.

    Both the training and the evaluation graph restore from the latest
    checkpoint in the directory of FLAGS.save_path (i.e. training checkpoints).

    Args:
    * is_train: whether to restore a model for training
    """

    save_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.save_path))
    if is_train:
      self.saver_train.restore(self.sess_train, save_path)
    else:
      self.saver_eval.restore(self.sess_eval, save_path)
    tf.logging.info('model restored from ' + save_path)

  def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
    """Monitor the training progress.

    Args:
    * summary: summary protocol buffer
    * log_rslt: logging operations' results
    * idx_iter: index of the training iteration
    * time_step: time step between two summary operations
    """

    # write summaries for TensorBoard visualization
    self.sm_writer.add_summary(summary, idx_iter)

    # compute the training speed
    # (images per second over the last FLAGS.summ_step iterations, scaled by the
    # number of workers under multi-GPU; assumes FLAGS.batch_size is per worker)
    speed = FLAGS.batch_size * FLAGS.summ_step / time_step
    if FLAGS.enbl_multi_gpu:
      speed *= mgw.size()

    # display monitored statistics
    log_str = ' | '.join(['%s = %.4e' % (name, value)
                          for name, value in zip(self.log_op_names, log_rslt)])
    tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
예제 #2
0
class UniformQuantTFLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Uniform quantization learner with TensorFlow's quantization APIs.

    A pre-trained full-precision model is built under the 'model' scope and a
    copy is built under the 'quant_model' scope, where TF's fake-quantization
    rewrite is applied. The quantized copy is initialized from the restored
    full-precision weights and then fine-tuned with Adam. Training and
    evaluation use separate graphs and sessions.
    """
    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """

        # class-independent initialization
        super(UniformQuantTFLearner, self).__init__(sm_writer, model_helper)

        # define scopes for full & uniform quantized models
        self.model_scope_full = 'model'
        self.model_scope_quan = 'quant_model'

        # download the pre-trained model
        # (only the local primary worker downloads; all workers then wait at the barrier)
        if self.is_primary_worker('local'):
            self.download_model()  # pre-trained model is required
        self.auto_barrier()
        tf.logging.info('model files: ' + ', '.join(os.listdir('./models'))) 

        # detect unquantized activations nodes
        # (nodes listed here later get quantization ops inserted manually
        # via insert_quant_op() in both the training and evaluation graphs)
        self.unquant_node_names = []
        if FLAGS.uqtf_enbl_manual_quant:
            self.unquant_node_names = find_unquant_act_nodes(
                model_helper, self.data_scope, self.model_scope_quan,
                self.mpi_comm)
        tf.logging.info('unquantized activation nodes: {}'.format(
            self.unquant_node_names))

        # class-dependent initialization
        # (the distillation helper must exist before the graphs that use it are built)
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper,
                                                 self.mpi_comm)
        self.__build_train()
        self.__build_eval()

    def train(self):
        """Train a model and periodically produce checkpoint files."""

        # restore the full model from pre-trained checkpoints
        save_path = tf.train.latest_checkpoint(
            os.path.dirname(self.save_path_full))
        self.saver_full.restore(self.sess_train, save_path)

        # initialization
        # (init_op re-initializes the quantized model and then copies the restored
        # full-precision weights into it; init_opt_op initializes Adam's variables)
        self.sess_train.run([self.init_op, self.init_opt_op])
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # train the model through iterations and periodically save & evaluate the model
        time_prev = timer()
        for idx_iter in range(self.nb_iters_train):
            # train the model
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    time_step = timer() - time_prev
                    self.__monitor_progress(summary, log_rslt, idx_iter,
                                            time_step)
                    time_prev = timer()

            # save the model at certain steps
            # NOTE(review): auto_barrier() runs on every iteration, not only after
            # save steps; presumably it keeps other workers from running ahead
            # while the primary worker saves/evaluates - confirm the cost is intended
            if self.is_primary_worker('global') and (idx_iter +
                                                     1) % FLAGS.save_step == 0:
                self.__save_model(is_train=True)
                self.evaluate()
            self.auto_barrier()

        # save the final model
        # (the latest training checkpoint is restored into the evaluation graph
        # and re-saved under FLAGS.uqtf_save_path_eval)
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it.

        Runs ceil(FLAGS.nb_smpls_eval / FLAGS.batch_size_eval) evaluation
        batches, passes each batch's outputs to dump_n_eval(), and logs the
        mean of every evaluation op over all batches (each batch weighted
        equally).
        """

        self.__restore_model(is_train=False)
        nb_iters = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        self.dump_n_eval(outputs=None, action='init')
        for idx_iter in range(nb_iters):
            if (idx_iter + 1) % 100 == 0:
                tf.logging.info('process the %d-th mini-batch for evaluation' %
                                (idx_iter + 1))
            eval_rslts[idx_iter], outputs = self.sess_eval.run(
                [self.eval_op, self.outputs_eval])
            self.dump_n_eval(outputs=outputs, action='dump')
        self.dump_n_eval(outputs=None, action='eval')
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph.

        Builds, in one fresh graph, the fake-quantized model (quant scope), an
        optional distillation teacher, and the full-precision model (full scope)
        that only serves as the source of pre-trained weights. It then adds the
        loss, the Adam fine-tuning ops, and the op that copies full-precision
        weights into the quantized model.
        """

        with tf.Graph().as_default() as graph:
            # create a TF session for the current graph
            # (pinned to the local rank's GPU under multi-GPU, otherwise to GPU 0)
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - uniform quantized model - part 1
            with tf.variable_scope(self.model_scope_quan):
                # the forward pass may return a dict; 'cls_pred' then holds the
                # classification logits
                # NOTE(review): `outputs` is not used further in this method; the
                # softmax op is still added to the graph
                logits_quan = self.forward_train(images)
                if not isinstance(logits_quan, dict):
                    outputs = tf.nn.softmax(logits_quan)
                else:
                    outputs = tf.nn.softmax(logits_quan['cls_pred'])
                # apply TF's fake-quantization rewrite to ops under the quant scope,
                # then manually quantize the activation nodes it missed
                tf.contrib.quantize.experimental_create_training_graph(
                    weight_bits=FLAGS.uqtf_weight_bits,
                    activation_bits=FLAGS.uqtf_activation_bits,
                    quant_delay=FLAGS.uqtf_quant_delay,
                    freeze_bn_delay=FLAGS.uqtf_freeze_bn_delay,
                    scope=self.model_scope_quan)
                for node_name in self.unquant_node_names:
                    insert_quant_op(graph, node_name, is_train=True)
                # snapshot taken after the rewrite, so it may include variables the
                # rewrite created (e.g. quantization ranges) - TODO confirm
                self.vars_quan = get_vars_by_scope(self.model_scope_quan)
                self.global_step = tf.train.get_or_create_global_step()
                self.saver_quan_train = tf.train.Saver(self.vars_quan['all'] +
                                                       [self.global_step])

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            # (only needed to hold the pre-trained weights restored in train(); its
            # forward output is discarded)
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - uniform quantized model - part 2
            with tf.variable_scope(self.model_scope_quan):
                # loss & extra evaluation metrics
                loss_bsc, metrics = self.calc_loss(labels, logits_quan,
                                                   self.vars_quan['trainable'])
                if not FLAGS.enbl_dst:
                    loss_fnl = loss_bsc
                else:
                    loss_fnl = loss_bsc + self.helper_dst.calc_loss(
                        logits_quan, logits_dst)
                tf.summary.scalar('loss_bsc', loss_bsc)
                tf.summary.scalar('loss_fnl', loss_fnl)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                # (the regular schedule scaled down by FLAGS.uqtf_lrn_rate_dcy for fine-tuning)
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                    self.global_step)
                lrn_rate *= FLAGS.uqtf_lrn_rate_dcy

                # decrease the learning rate by a constant factor
                # if self.dataset_name == 'cifar_10':
                #  lrn_rate *= 1e-3
                # elif self.dataset_name == 'ilsvrc_12':
                #  lrn_rate *= 1e-4
                # else:
                #  raise ValueError('unrecognized dataset\'s name: ' + self.dataset_name)

                # obtain the full list of trainable variables & update operations
                # (collected before the optimizer is created, so Adam's own variables
                # are not included and are initialized separately by init_opt_op)
                self.vars_all = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_quan)
                self.trainable_vars_all = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.model_scope_quan)
                self.update_ops_all = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_quan)

                # TF operations for initializing the uniform quantized model
                # (initialize every quant-scope variable first, then copy each full
                # variable into its quantized counterpart)
                # NOTE(review): full/quant variables are paired purely by position;
                # zip() silently truncates if the two lists differ in length - confirm
                # get_vars_by_scope() returns them in matching order
                init_ops = []
                with tf.control_dependencies(
                    [tf.variables_initializer(self.vars_all)]):
                    for var_full, var_quan in zip(self.vars_full['all'],
                                                  self.vars_quan['all']):
                        init_ops += [var_quan.assign(var_full)]
                init_ops += [self.global_step.initializer]
                self.init_op = tf.group(init_ops)

                # TF operations for fine-tuning
                # optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
                optimizer_base = tf.train.AdamOptimizer(lrn_rate)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads = optimizer.compute_gradients(loss_fnl,
                                                    self.trainable_vars_all)
                with tf.control_dependencies(self.update_ops_all):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)
                self.init_opt_op = tf.variables_initializer(
                    optimizer_base.variables())

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss_fnl] + list(metrics.values())
            self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

    def __build_eval(self):
        """Build the evaluation graph.

        Builds the quantized model in evaluation mode in a fresh graph and
        session. It sets self.sess_eval, self.saver_quan_eval,
        self.eval_op(_names) and self.outputs_eval, and registers the input and
        output tensors in the 'images_final' / 'logits_final' collections.
        """

        with tf.Graph().as_default() as graph:
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - uniform quantized model - part 1
            with tf.variable_scope(self.model_scope_quan):
                # NOTE(review): `outputs` is not used further in this method
                logits = self.forward_eval(images)
                if not isinstance(logits, dict):
                    outputs = tf.nn.softmax(logits)
                else:
                    outputs = tf.nn.softmax(logits['cls_pred'])
                # evaluation-mode fake-quantization rewrite (no quant/freeze-bn delays)
                tf.contrib.quantize.experimental_create_eval_graph(
                    weight_bits=FLAGS.uqtf_weight_bits,
                    activation_bits=FLAGS.uqtf_activation_bits,
                    scope=self.model_scope_quan)
                for node_name in self.unquant_node_names:
                    insert_quant_op(graph, node_name, is_train=False)
                vars_quan = get_vars_by_scope(self.model_scope_quan)
                global_step = tf.train.get_or_create_global_step()
                self.saver_quan_eval = tf.train.Saver(vars_quan['all'] +
                                                      [global_step])

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition - uniform quantized model -part 2
            with tf.variable_scope(self.model_scope_quan):
                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits,
                                               vars_quan['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # TF operations for evaluation
                self.eval_op = [loss] + list(metrics.values())
                self.eval_op_names = ['loss'] + list(metrics.keys())
                self.outputs_eval = logits

            # add input & output tensors to certain collections
            # (dict-valued inputs / outputs are reduced to their 'image' / 'cls_pred'
            # entries; the collections are presumably read by export code elsewhere)
            if not isinstance(images, dict):
                tf.add_to_collection('images_final', images)
            else:
                tf.add_to_collection('images_final', images['image'])
            if not isinstance(logits, dict):
                tf.add_to_collection('logits_final', logits)
            else:
                tf.add_to_collection('logits_final', logits['cls_pred'])

    def __save_model(self, is_train):
        """Save the current model for training or evaluation.

        Training checkpoints go to FLAGS.uqtf_save_path (suffixed with the
        global step); evaluation checkpoints go to FLAGS.uqtf_save_path_eval.

        Args:
        * is_train: whether to save a model for training
        """

        if is_train:
            save_path = self.saver_quan_train.save(self.sess_train,
                                                   FLAGS.uqtf_save_path,
                                                   self.global_step)
        else:
            save_path = self.saver_quan_eval.save(self.sess_eval,
                                                  FLAGS.uqtf_save_path_eval)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Both graphs restore from the latest checkpoint in the directory of
        FLAGS.uqtf_save_path (i.e. training checkpoints).

        Args:
        * is_train: whether to restore a model for training
        """

        save_path = tf.train.latest_checkpoint(
            os.path.dirname(FLAGS.uqtf_save_path))
        if is_train:
            self.saver_quan_train.restore(self.sess_train, save_path)
        else:
            self.saver_quan_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Monitor the training progress.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed
        # (images per second, scaled by the number of workers under multi-GPU;
        # assumes FLAGS.batch_size is per worker)
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # display monitored statistics
        log_str = ' | '.join([
            '%s = %.4e' % (name, value)
            for name, value in zip(self.log_op_names, log_rslt)
        ])
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' %
                        (idx_iter + 1, log_str, speed))
예제 #3
0
class WeightSparseLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Weight sparsification learner."""
    def __init__(self, sm_writer, model_helper):
        """Set up the weight sparsification learner.

        Determines the per-variable pruning ratios, optionally creates the
        distillation helper, and builds the training and evaluation graphs.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """

        super(WeightSparseLearner, self).__init__(sm_writer, model_helper)

        # variable scope under which the pruning masks are created
        self.mask_scope = 'mask'

        # pruning-ratio search; the 'optimal' protocol needs the pre-trained
        # model on disk, fetched by the local primary worker while others wait
        ratio_searcher = PROptimizer(model_helper, self.mpi_comm)
        if FLAGS.ws_prune_ratio_prtl == 'optimal':
            if self.is_primary_worker('local'):
                self.download_model()
            self.auto_barrier()
            model_files = os.listdir('./models')
            tf.logging.info('model files: ' + ', '.join(model_files))
        self.var_names_n_prune_ratios = ratio_searcher.run()

        # optional teacher model, then the training & evaluation graphs
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(
                sm_writer, model_helper, self.mpi_comm)
        self.__build_train()
        self.__build_eval()

    def train(self):
        """Run the training loop, pruning weights on schedule and checkpointing.

        Every FLAGS.ws_mask_update_step iterations the pruning masks are
        refreshed (together with re-initializing the optimizer's variables),
        as long as the training progress lies within
        [FLAGS.ws_iter_ratio_beg, FLAGS.ws_iter_ratio_end]. Past the end of
        that window, they are refreshed exactly one more time. The primary
        worker saves and evaluates the model every FLAGS.save_step iterations
        and once more after the loop.
        """

        # initialize variables and, under multi-GPU, sync them from worker 0
        self.sess_train.run(self.init_op)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        final_prune_done = False
        tic = timer()
        for idx_iter in range(self.nb_iters_train):
            step = idx_iter + 1

            # training step, optionally fetching summaries & logged scalars
            if step % FLAGS.summ_step == 0:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    self.__monitor_progress(summary, log_rslt, idx_iter,
                                            timer() - tic)
                    tic = timer()
            else:
                self.sess_train.run(self.train_op)

            # refresh the pruning masks on schedule
            if step % FLAGS.ws_mask_update_step == 0:
                progress = float(step) / self.nb_iters_train
                within_window = progress <= FLAGS.ws_iter_ratio_end
                if progress >= FLAGS.ws_iter_ratio_beg and (
                        within_window or not final_prune_done):
                    if not within_window:
                        final_prune_done = True
                    self.sess_train.run([self.prune_op, self.init_opt_op])

            # periodic checkpoint & evaluation on the primary worker
            if self.is_primary_worker('global') and step % FLAGS.save_step == 0:
                self.__save_model()
                self.evaluate()

        # final checkpoint & evaluation
        if self.is_primary_worker('global'):
            self.__save_model()
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it.

        Runs ceil(FLAGS.nb_smpls_eval / FLAGS.batch_size_eval) batches through
        the evaluation graph and logs the mean of every evaluation op over all
        batches (each batch weighted equally).
        """

        self.__restore_model(is_train=False)
        # the number of evaluation batches must be derived from the evaluation
        # batch size, as the other learners do; using the training batch size
        # evaluates the wrong number of samples whenever the two differ
        # (presumably build_dataset_eval() batches by FLAGS.batch_size_eval)
        nb_iters = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        for idx_iter in range(nb_iters):
            eval_rslts[idx_iter] = self.sess_eval.run(self.eval_op)
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    def __build_train(self):  # pylint: disable=too-many-locals
        """Build the training graph.

        Creates a fresh graph and session. It sets self.sess_train,
        self.global_step, self.nb_iters_train, self.maskable_var_names,
        self.masks, self.prune_op, self.train_op, self.summary_op,
        self.log_op / self.log_op_names, self.init_op, self.init_opt_op,
        self.saver_train and, under multi-GPU only, self.bcast_op.
        """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            # (pinned to the local rank's GPU under multi-GPU, otherwise to GPU 0)
            config = tf.ConfigProto()
            if FLAGS.enbl_multi_gpu:
                config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
            else:
                config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - weight-sparsified model
            # (self.trainable_vars / self.maskable_vars / self.vars / self.update_ops
            # come from AbstractLearner; presumably collected under self.model_scope
            # in the current graph - TODO confirm)
            with tf.variable_scope(self.model_scope):
                # loss & extra evaluation metrics
                logits = self.forward_train(images)
                # names recorded for use elsewhere (not used within this method)
                self.maskable_var_names = [
                    var.name for var in self.maskable_vars
                ]
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                # NOTE(review): calls a module-level setup_lrn_rate(), while other
                # learners call self.setup_lrn_rate(global_step) - confirm intended
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = setup_lrn_rate(
                    self.global_step, self.model_name, self.dataset_name)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.trainable_vars)
                pr_maskable = calc_prune_ratio(self.maskable_vars)
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_maskable', pr_maskable)

                # build masks and corresponding operations for weight sparsification
                self.masks, self.prune_op = self.__build_masks()

                # optimizer & gradients
                # (raw gradients are post-processed by __calc_grads_pruned(), defined
                # outside this view; presumably it masks out pruned weights' gradients)
                optimizer_base = tf.train.MomentumOptimizer(
                    lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(
                    loss, self.trainable_vars)
                grads_pruned = self.__calc_grads_pruned(grads_origin)

            # TF operations & model saver
            self.sess_train = sess
            with tf.control_dependencies(self.update_ops):
                self.train_op = optimizer.apply_gradients(
                    grads_pruned, global_step=self.global_step)
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_maskable] + list(
                metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(
                metrics.keys())
            self.init_op = tf.variables_initializer(self.vars)
            # separate initializer for the optimizer's own variables; train() re-runs
            # it together with prune_op whenever the masks are refreshed
            self.init_opt_op = tf.variables_initializer(
                optimizer_base.variables())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
            self.saver_train = tf.train.Saver(self.vars)

    def __build_eval(self):
        """Build the evaluation graph.

        Creates a fresh graph and session holding the model in evaluation mode
        (no masks or optimizer). It sets self.sess_eval, self.eval_op /
        self.eval_op_names and self.saver_eval.
        """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            # (pinned to the local rank's GPU under multi-GPU, otherwise to GPU 0)
            config = tf.ConfigProto()
            if FLAGS.enbl_multi_gpu:
                config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
            else:
                config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition - weight-sparsified model
            with tf.variable_scope(self.model_scope):
                # loss & extra evaluation metrics
                logits = self.forward_eval(images)
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # overall pruning ratios of trainable & maskable variables
                # (reported alongside the loss so evaluation shows the sparsity level)
                pr_trainable = calc_prune_ratio(self.trainable_vars)
                pr_maskable = calc_prune_ratio(self.maskable_vars)

                # TF operations for evaluation
                self.eval_op = [loss, pr_trainable, pr_maskable] + list(
                    metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list(
                    metrics.keys())
                self.saver_eval = tf.train.Saver(self.vars)

    def __build_masks(self):
        """Build masks and corresponding operations for weight sparsification.

        For each maskable variable, a mask (initialized to ones) and a backup
        copy of the variable are created under ``self.mask_scope``. For every
        variable, the returned pruning operation does the following in order:
          1. copies the current values of non-masked entries into the backup
             (masked entries keep their last non-masked value);
          2. recomputes the mask, keeping backup entries whose magnitude exceeds
             the dynamic-pruning-ratio percentile of the backup magnitudes;
          3. writes ``backup * mask`` into the variable.

        Returns:
        * masks: list of masks for weight sparsification
        * prune_op: pruning operation (grouped over all maskable variables)

        Raises:
        * ValueError: if the maskable variables do not match
          ``self.var_names_n_prune_ratios`` one-to-one and in order
        """

        maskable_vars = list(self.maskable_vars)
        var_names_n_prune_ratios = list(self.var_names_n_prune_ratios)
        # explicit checks (not ``assert``) so they survive ``python -O``; ``zip``
        # alone would silently drop an unmatched tail
        if len(maskable_vars) != len(var_names_n_prune_ratios):
            raise ValueError('unmatched number of variables: %d vs. %d' %
                             (len(maskable_vars), len(var_names_n_prune_ratios)))

        masks, prune_ops = [], []
        with tf.variable_scope(self.mask_scope):
            for var, var_name_n_prune_ratio in zip(maskable_vars,
                                                   var_names_n_prune_ratios):
                # obtain the dynamic pruning ratio
                if var.name != var_name_n_prune_ratio[0]:
                    raise ValueError('unmatched variable names: %s vs. %s' %
                                     (var.name, var_name_n_prune_ratio[0]))
                prune_ratio = self.__calc_prune_ratio_dyn(
                    var_name_n_prune_ratio[1])

                # create a mask and non-masked backup for each variable
                mask = tf.get_variable(var.name.replace(':0', '_mask'),
                                       initializer=tf.ones(var.shape),
                                       trainable=False)
                var_bkup = tf.get_variable(var.name.replace(':0', '_var_bkup'),
                                           initializer=var.initialized_value(),
                                           trainable=False)

                # step 1: refresh the backup with the non-masked entries
                var_bkup_update_op = var_bkup.assign(
                    tf.where(mask > 0.5, var, var_bkup))

                # step 2: threshold the backup rather than the masked variable.
                # Thresholding ``var`` would rank already-pruned entries by zero
                # (or by optimizer drift) and leave the stored backup values
                # effectively unused. ``read_value()`` creates a fresh read op
                # inside the control-dependency context, so it observes the
                # refreshed backup; a plain variable reference may reuse a cached
                # snapshot that is not ordered after the update.
                with tf.control_dependencies([var_bkup_update_op]):
                    abs_bkup = tf.abs(var_bkup.read_value())
                    mask_thres = tf.contrib.distributions.percentile(
                        abs_bkup, prune_ratio * 100)
                    mask_update_op = mask.assign(
                        tf.cast(abs_bkup > mask_thres, tf.float32))

                # step 3: write the masked backup into the variable (reads are
                # created after, and hence ordered after, the mask update)
                with tf.control_dependencies([mask_update_op]):
                    prune_op = var.assign(var_bkup.read_value() * mask.read_value())

                # record pruning masks & operations
                masks += [mask]
                prune_ops += [prune_op]

        return masks, tf.group(prune_ops)

    def __calc_prune_ratio_dyn(self, prune_ratio_fnl):
        """Calculate the dynamic pruning ratio at the current global step.

        With ``t`` the training progress between iterations
        ``nb_iters_train * FLAGS.ws_iter_ratio_beg`` and
        ``nb_iters_train * FLAGS.ws_iter_ratio_end``, clipped to [0, 1], the
        ratio is ``prune_ratio_fnl * (1 - (1 - t) ** FLAGS.ws_prune_ratio_exp)``.
        It is 0 before the ramp and ``prune_ratio_fnl`` after it.

        Args:
        * prune_ratio_fnl: final pruning ratio

        Returns:
        * prune_ratio_dyn: dynamic pruning ratio (scalar tensor)
        """

        idx_iter_beg = int(self.nb_iters_train * FLAGS.ws_iter_ratio_beg)
        idx_iter_end = int(self.nb_iters_train * FLAGS.ws_iter_ratio_end)
        # guard against a zero-length ramp (beg == end, or a small nb_iters_train),
        # which would divide by zero and yield inf / NaN; with length 1 the
        # ratio simply jumps to its final value right after ``idx_iter_beg``
        nb_iters_ramp = max(idx_iter_end - idx_iter_beg, 1)
        base = tf.cast(self.global_step - idx_iter_beg,
                       tf.float32) / nb_iters_ramp
        base = tf.minimum(1.0, tf.maximum(0.0, base))
        prune_ratio_dyn = prune_ratio_fnl * (
            1.0 - tf.pow(1.0 - base, FLAGS.ws_prune_ratio_exp))

        return prune_ratio_dyn

    def __calc_grads_pruned(self, grads_origin):
        """Calculate the mask-pruned gradients.

        The gradient of each maskable variable (looked up by name in
        ``self.maskable_var_names``) is multiplied element-wise by the matching
        entry of ``self.masks``. All other pairs, including pairs whose gradient
        is ``None``, are passed through unchanged.

        Args:
        * grads_origin: list of (gradient, variable) pairs, e.g. as returned by
          ``Optimizer.compute_gradients``

        Returns:
        * grads_pruned: list of mask-pruned (gradient, variable) pairs
        """

        # name -> mask index, built once instead of two linear list scans
        # (``in`` + ``index``) per gradient; keep the first match like list.index()
        idxs_mask = {}
        for idx, name in enumerate(self.maskable_var_names):
            idxs_mask.setdefault(name, idx)

        grads_pruned = []
        for grad, var in grads_origin:
            idx_mask = idxs_mask.get(var.name)
            if idx_mask is None or grad is None:
                # non-maskable variable, or no gradient (the variable does not
                # affect the loss): ``None * mask`` would raise a TypeError
                grads_pruned += [(grad, var)]
            else:
                grads_pruned += [(grad * self.masks[idx_mask], var)]

        return grads_pruned

    def __save_model(self):
        """Write a training checkpoint of ``self.sess_train``.

        The checkpoint goes to ``FLAGS.ws_save_path``, suffixed with the current
        global step, and its path is logged.
        """

        ckpt_path = self.saver_train.save(
            self.sess_train, FLAGS.ws_save_path, self.global_step)
        tf.logging.info('model saved to ' + ckpt_path)

    def __restore_model(self, is_train):
        """Restore the latest checkpoint in the directory of ``FLAGS.ws_save_path``.

        Args:
        * is_train: restore into the training session/saver if True, otherwise
          into the evaluation session/saver
        """

        ckpt_dir = os.path.dirname(FLAGS.ws_save_path)
        ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
        saver, sess = ((self.saver_train, self.sess_train) if is_train
                       else (self.saver_eval, self.sess_eval))
        saver.restore(sess, ckpt_path)
        tf.logging.info('model restored from ' + ckpt_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Write TensorBoard summaries and log the training statistics.

        Args:
        * summary: summary protocol buffer
        * log_rslt: evaluated logging ops, aligned with ``self.log_op_names``
        * idx_iter: index of the training iteration (logged as ``idx_iter + 1``)
        * time_step: time elapsed between two summary operations
        """

        self.sm_writer.add_summary(summary, idx_iter)

        # pictures per time unit over the last FLAGS.summ_step iterations,
        # scaled by the number of workers in multi-GPU mode
        nb_workers = mgw.size() if FLAGS.enbl_multi_gpu else 1
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step * nb_workers

        stats = ' | '.join(
            '%s = %.4e' % pair for pair in zip(self.log_op_names, log_rslt))
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' %
                        (idx_iter + 1, stats, speed))

    @property
    def maskable_vars(self):
        """Subset of ``self.trainable_vars`` selected by ``get_maskable_vars``."""

        trainable_vars = self.trainable_vars
        return get_maskable_vars(trainable_vars)
예제 #4
0
class DisChnPrunedLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
  """Discrimination-aware channel pruning learner.

  A channel-pruned copy (scope 'pruned_model') of a pre-trained full model
  (scope 'model') is initialized from the full model's weights. The pruned
  model's input channels are then chosen block by block using regression and
  discrimination-aware losses, and the pruned model is finally fine-tuned.
  """

  def __init__(self, sm_writer, model_helper):
    """Constructor function.

    Args:
    * sm_writer: TensorFlow's summary writer
    * model_helper: model helper with definitions of model & dataset
    """

    # class-independent initialization
    super(DisChnPrunedLearner, self).__init__(sm_writer, model_helper)

    # define scopes for full & channel-pruned models
    self.model_scope_full = 'model'
    self.model_scope_prnd = 'pruned_model'

    # download the pre-trained model on the local primary worker, then
    # synchronize workers before listing './models'
    if self.is_primary_worker('local'):
      self.download_model()  # pre-trained model is required
    self.auto_barrier()
    tf.logging.info('model files: ' + ', '.join(os.listdir('./models')))

    # class-dependent initialization
    if FLAGS.enbl_dst:
      self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
    self.__build_train()
    self.__build_eval()

  def train(self):
    """Train a model and periodically produce checkpoint files.

    Restores the pre-trained full model, initializes the channel-pruned model and
    all optimizers, chooses discrimination-aware channels, then fine-tunes the
    channel-pruned model. Checkpoints are saved and evaluated every
    ``FLAGS.save_step`` iterations and at the end.
    """

    # restore the full model from pre-trained checkpoints
    save_path = tf.train.latest_checkpoint(os.path.dirname(self.save_path_full))
    self.saver_full.restore(self.sess_train, save_path)

    # initialization (``init_op`` copies the full model's weights into the pruned model)
    self.sess_train.run([self.init_op, self.init_opt_op])
    self.sess_train.run(self.layer_init_opt_ops)  # initialization for layer-wise fine-tuning
    self.sess_train.run(self.block_init_opt_ops)  # initialization for block-wise fine-tuning
    if FLAGS.enbl_multi_gpu:
      self.sess_train.run(self.bcast_op)

    # choose discrimination-aware channels
    self.__choose_discr_chns()

    # fine-tune the model with chosen channels only
    time_prev = timer()
    for idx_iter in range(self.nb_iters_train):
      # train the model
      if (idx_iter + 1) % FLAGS.summ_step != 0:
        self.sess_train.run(self.train_op)
      else:
        __, summary, log_rslt = self.sess_train.run([self.train_op, self.summary_op, self.log_op])
        if self.is_primary_worker('global'):
          time_step = timer() - time_prev
          self.__monitor_progress(summary, log_rslt, idx_iter, time_step)
          time_prev = timer()

      # save the model at certain steps
      if self.is_primary_worker('global') and (idx_iter + 1) % FLAGS.save_step == 0:
        self.__save_model(is_train=True)
        self.evaluate()

    # save the final model; the evaluation checkpoint is produced by restoring the
    # latest training checkpoint into the evaluation graph and re-saving it
    if self.is_primary_worker('global'):
      self.__save_model(is_train=True)
      self.__restore_model(is_train=False)
      self.__save_model(is_train=False)
      self.evaluate()

  def evaluate(self):
    """Restore a model from the latest checkpoint files and then evaluate it.

    Runs ``ceil(FLAGS.nb_smpls_eval / FLAGS.batch_size_eval)`` evaluation batches
    and logs the mean of each entry of ``self.eval_op``.
    """

    self.__restore_model(is_train=False)
    nb_iters = int(np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
    eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
    for idx_iter in range(nb_iters):
      eval_rslts[idx_iter] = self.sess_eval.run(self.eval_op)
    for idx, name in enumerate(self.eval_op_names):
      tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

  def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
    """Build the training graph.

    Defines the full and channel-pruned models on the same input batch, the
    per-variable masks and mask-update ops, the extra regression and
    discrimination-aware losses, and the layer-wise, block-wise and
    whole-network training ops.
    """

    with tf.Graph().as_default():
      # create a TF session for the current graph
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      sess = tf.Session(config=config)

      # data input pipeline
      with tf.variable_scope(self.data_scope):
        iterator = self.build_dataset_train()
        images, labels = iterator.get_next()

      # model definition - distilled model
      if FLAGS.enbl_dst:
        logits_dst = self.helper_dst.calc_logits(sess, images)

      # model definition - full model
      with tf.variable_scope(self.model_scope_full):
        __ = self.forward_train(images)
        self.vars_full = get_vars_by_scope(self.model_scope_full)
        self.saver_full = tf.train.Saver(self.vars_full['all'])
        self.save_path_full = FLAGS.save_path

      # model definition - channel-pruned model
      with tf.variable_scope(self.model_scope_prnd):
        logits_prnd = self.forward_train(images)
        self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
        self.maskable_var_names = [var.name for var in self.vars_prnd['maskable']]
        self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])

        # loss & extra evaluation metrics
        loss_bsc, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
        if not FLAGS.enbl_dst:
          loss_fnl = loss_bsc
        else:
          loss_fnl = loss_bsc + self.helper_dst.calc_loss(logits_prnd, logits_dst)
        tf.summary.scalar('loss_bsc', loss_bsc)
        tf.summary.scalar('loss_fnl', loss_fnl)
        for key, value in metrics.items():
          tf.summary.scalar(key, value)

        # learning rate schedule
        self.global_step = tf.train.get_or_create_global_step()
        lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)

        # overall pruning ratios of trainable & maskable variables
        pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
        pr_maskable = calc_prune_ratio(self.vars_prnd['maskable'])
        tf.summary.scalar('pr_trainable', pr_trainable)
        tf.summary.scalar('pr_maskable', pr_maskable)

        # create masks and corresponding operations for channel pruning:
        # mask (ones at creation), a delta placeholder fed from numpy, an op that
        # zeroes the mask, an op that adds the delta, and an op that applies it
        self.masks = []
        self.mask_deltas = []
        self.mask_init_ops = []
        self.mask_updt_ops = []
        self.prune_ops = []
        for idx, var in enumerate(self.vars_prnd['maskable']):
          name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
          self.masks += [tf.get_variable(name, initializer=tf.ones(var.shape), trainable=False)]
          name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask_delta')
          self.mask_deltas += [tf.placeholder(tf.float32, shape=var.shape, name=name)]
          self.mask_init_ops += [self.masks[idx].assign(tf.zeros(var.shape))]
          self.mask_updt_ops += [self.masks[idx].assign_add(self.mask_deltas[idx])]
          self.prune_ops += [var.assign(var * self.masks[idx])]

        # build extra losses for regression & discrimination
        self.reg_losses, self.dis_losses, self.idxs_layer_to_block = \
          self.__build_extra_losses(labels)
        self.dis_losses += [loss_bsc]  # append discrimination-aware loss for the last block
        self.nb_layers = len(self.reg_losses)
        self.nb_blocks = len(self.dis_losses)
        for idx, reg_loss in enumerate(self.reg_losses):
          tf.summary.scalar('reg_loss_%d' % idx, reg_loss)
        for idx, dis_loss in enumerate(self.dis_losses):
          tf.summary.scalar('dis_loss_%d' % idx, dis_loss)

        # obtain the full list of trainable variables & update operations
        # (collected after the extra losses, so auxiliary-head variables are included)
        self.vars_all = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_prnd)
        self.trainable_vars_all = tf.get_collection(
          tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_prnd)
        self.update_ops_all = tf.get_collection(
          tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)

        # TF operations for initializing the channel-pruned model
        init_ops = []
        with tf.control_dependencies([tf.variables_initializer(self.vars_all)]):
          for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
            init_ops += [var_prnd.assign(var_full)]
        self.init_op = tf.group(init_ops)

        # TF operations for layer-wise, block-wise, and whole-network fine-tuning
        self.layer_train_ops, self.layer_init_opt_ops, self.grad_norms = self.__build_layer_ops()
        self.block_train_ops, self.block_init_opt_ops = self.__build_block_ops()
        self.train_op, self.init_opt_op = self.__build_network_ops(loss_fnl, lrn_rate)

      # TF operations for logging & summarizing
      self.sess_train = sess
      self.summary_op = tf.summary.merge_all()
      self.log_op = [lrn_rate, loss_fnl, pr_trainable, pr_maskable] + list(metrics.values())
      self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)

  def __build_eval(self):
    """Build the evaluation graph of the channel-pruned model."""

    with tf.Graph().as_default():
      # create a TF session for the current graph
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      self.sess_eval = tf.Session(config=config)

      # data input pipeline
      with tf.variable_scope(self.data_scope):
        iterator = self.build_dataset_eval()
        images, labels = iterator.get_next()

      # model definition - distilled model
      if FLAGS.enbl_dst:
        logits_dst = self.helper_dst.calc_logits(self.sess_eval, images)

      # model definition - channel-pruned model
      with tf.variable_scope(self.model_scope_prnd):
        # loss & extra evaluation metrics
        logits = self.forward_eval(images)
        vars_prnd = get_vars_by_scope(self.model_scope_prnd)
        loss, metrics = self.calc_loss(labels, logits, vars_prnd['trainable'])
        if FLAGS.enbl_dst:
          loss += self.helper_dst.calc_loss(logits, logits_dst)

        # overall pruning ratios of trainable & maskable variables
        pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
        pr_maskable = calc_prune_ratio(vars_prnd['maskable'])

        # TF operations for evaluation
        self.eval_op = [loss, pr_trainable, pr_maskable] + list(metrics.values())
        self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
        self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'])

      # add input & output tensors to certain collections
      tf.add_to_collection('images_final', images)
      tf.add_to_collection('logits_final', logits)

  def __build_extra_losses(self, labels):
    """Build extra losses for regression & discrimination.

    Layers (ops whose name matches 'Conv2D$') are split into consecutive blocks
    of ``ceil((nb_layers + 1) / (FLAGS.dcp_nb_stages + 1))`` layers. Every layer
    gets an L2 regression loss between the full and pruned models' outputs. The
    last layer of each complete block also gets an auxiliary classifier
    (batch-norm on axis 3 -> ReLU -> mean over axes 1 and 2 -> dense) whose
    softmax cross-entropy is that block's discrimination-aware loss.

    Args:
    * labels: one-hot label vectors

    Returns:
    * reg_losses: list of regression losses (one per layer)
    * dis_losses: list of discrimination-aware losses (one per complete block;
      the caller appends the final block's loss)
    * idxs_layer_to_block: list of mappings from layer index to block index
    """

    # insert additional losses to intermediate layers
    pattern = re.compile('Conv2D$')
    core_ops_full = get_ops_by_scope_n_pattern(self.model_scope_full, pattern)
    core_ops_prnd = get_ops_by_scope_n_pattern(self.model_scope_prnd, pattern)
    nb_layers = len(core_ops_full)
    nb_blocks = int(FLAGS.dcp_nb_stages + 1)
    # ``float`` keeps the ceiling meaningful under Python 2's integer division
    nb_layers_per_block = int(math.ceil(float(nb_layers + 1) / nb_blocks))
    reg_losses = []
    dis_losses = []
    idxs_layer_to_block = []
    for idx_layer in range(nb_layers):
      reg_losses += \
        [tf.nn.l2_loss(core_ops_full[idx_layer].outputs[0] - core_ops_prnd[idx_layer].outputs[0])]
      idxs_layer_to_block += [idx_layer // nb_layers_per_block]
      if (idx_layer + 1) % nb_layers_per_block == 0:
        x = core_ops_prnd[idx_layer].outputs[0]
        x = tf.layers.batch_normalization(x, axis=3, training=True)
        x = tf.nn.relu(x)
        x = tf.reduce_mean(x, axis=[1, 2])
        x = tf.layers.dense(x, FLAGS.nb_classes)
        dis_losses += [tf.losses.softmax_cross_entropy(labels, x)]
    tf.logging.info('layer-to-block mapping: {}'.format(idxs_layer_to_block))

    return reg_losses, dis_losses, idxs_layer_to_block

  def __build_layer_ops(self):
    """Build layer-wise fine-tuning operations.

    Each maskable variable gets its own Adam optimizer minimizing its layer's
    regression loss plus its block's discrimination-aware loss; only that
    variable is updated, with mask-pruned gradients.

    Returns:
    * layer_train_ops: list of training operations for each layer
    * layer_init_opt_ops: list of initialization operations for each layer's optimizer
    * grad_norms: list of per-input-channel squared gradient norms for each layer
      (unmasked gradients summed over kernel axes 0, 1 and 3)
    """

    layer_train_ops = []
    layer_init_opt_ops = []
    grad_norms = []
    for idx, var_prnd in enumerate(self.vars_prnd['maskable']):
      optimizer_base = tf.train.AdamOptimizer(FLAGS.dcp_lrn_rate_adam)
      if not FLAGS.enbl_multi_gpu:
        optimizer = optimizer_base
      else:
        optimizer = mgw.DistributedOptimizer(optimizer_base)
      loss_all = self.reg_losses[idx] + self.dis_losses[self.idxs_layer_to_block[idx]]
      grads_origin = optimizer.compute_gradients(loss_all, [var_prnd])
      grads_pruned = self.__calc_grads_pruned(grads_origin)
      with tf.control_dependencies(self.update_ops_all):
        layer_train_ops += [optimizer.apply_gradients(grads_pruned)]
      layer_init_opt_ops += [tf.variables_initializer(optimizer_base.variables())]
      grad_norms += [tf.reduce_sum(grads_origin[0][0] ** 2, axis=[0, 1, 3])]

    return layer_train_ops, layer_init_opt_ops, grad_norms

  def __build_block_ops(self):
    """Build block-wise fine-tuning operations.

    Each block gets its own Adam optimizer over all trainable variables of the
    pruned model, minimizing the block's discrimination-aware loss plus the
    final loss (for the last block the two terms coincide).

    Returns:
    * block_train_ops: list of training operations for each block
    * block_init_opt_ops: list of initialization operations for each block's optimizer
    """

    block_train_ops = []
    block_init_opt_ops = []
    for dis_loss in self.dis_losses:
      optimizer_base = tf.train.AdamOptimizer(FLAGS.dcp_lrn_rate_adam)
      if not FLAGS.enbl_multi_gpu:
        optimizer = optimizer_base
      else:
        optimizer = mgw.DistributedOptimizer(optimizer_base)
      loss_all = dis_loss + self.dis_losses[-1]  # current stage + final loss
      grads_origin = optimizer.compute_gradients(loss_all, self.trainable_vars_all)
      grads_pruned = self.__calc_grads_pruned(grads_origin)
      with tf.control_dependencies(self.update_ops_all):
        block_train_ops += [optimizer.apply_gradients(grads_pruned)]
      block_init_opt_ops += [tf.variables_initializer(optimizer_base.variables())]

    return block_train_ops, block_init_opt_ops

  def __build_network_ops(self, loss, lrn_rate):
    """Build network training operations.

    Args:
    * loss: final training loss of the channel-pruned model
    * lrn_rate: learning rate (tensor or float) for the momentum optimizer

    Returns:
    * train_op: training operation of the whole network
    * init_opt_op: initialization operation of the whole network's optimizer
    """

    optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
    if not FLAGS.enbl_multi_gpu:
      optimizer = optimizer_base
    else:
      optimizer = mgw.DistributedOptimizer(optimizer_base)
    # intermediate discrimination-aware losses enter with zero weight: the
    # objective is unchanged, but their auxiliary-head variables get zero instead
    # of ``None`` gradients. ``tf.add_n`` rejects an empty list, which occurs when
    # there are no intermediate stages (only the final loss).
    if len(self.dis_losses) > 1:
      loss_all = tf.add_n(self.dis_losses[:-1]) * 0 + loss  # all stages + final loss
    else:
      loss_all = loss
    grads_origin = optimizer.compute_gradients(loss_all, self.trainable_vars_all)
    grads_pruned = self.__calc_grads_pruned(grads_origin)
    with tf.control_dependencies(self.update_ops_all):
      train_op = optimizer.apply_gradients(grads_pruned, global_step=self.global_step)
    init_opt_op = tf.variables_initializer(optimizer_base.variables())

    return train_op, init_opt_op

  def __calc_grads_pruned(self, grads_origin):
    """Calculate the mask-pruned gradients.

    The gradient of each maskable variable (looked up by name in
    ``self.maskable_var_names``) is multiplied element-wise by the matching
    entry of ``self.masks``. All other pairs, including pairs whose gradient is
    ``None``, are passed through unchanged.

    Args:
    * grads_origin: list of (gradient, variable) pairs

    Returns:
    * grads_pruned: list of mask-pruned (gradient, variable) pairs
    """

    # name -> mask index, built once instead of two linear list scans per
    # gradient; keep the first match like list.index()
    idxs_mask = {}
    for idx, name in enumerate(self.maskable_var_names):
      idxs_mask.setdefault(name, idx)

    grads_pruned = []
    for grad, var in grads_origin:
      idx_mask = idxs_mask.get(var.name)
      if idx_mask is None or grad is None:
        # non-maskable variable, or no gradient: ``None * mask`` would raise
        grads_pruned += [(grad, var)]
      else:
        grads_pruned += [(grad * self.masks[idx_mask], var)]

    return grads_pruned

  def __choose_discr_chns(self):  # pylint: disable=too-many-locals
    """Choose discrimination-aware channels.

    For each block, the block is first fine-tuned with its block-wise op. Then,
    for every layer of that block except the network's first layer, the layer's
    mask is reset to all-zero. Input channels are re-enabled one at a time (the
    not-yet-chosen channel with the largest gradient norm), and the layer is
    fine-tuned after each addition, until the layer's pruning ratio is at most
    ``FLAGS.dcp_prune_ratio`` or every channel has been chosen.
    """

    # select the most discriminative channels through multiple stages
    nb_workers = mgw.size() if FLAGS.enbl_multi_gpu else 1
    nb_iters_block = int(FLAGS.dcp_nb_iters_block / nb_workers)
    nb_iters_layer = int(FLAGS.dcp_nb_iters_layer / nb_workers)
    for idx_block in range(self.nb_blocks):
      # fine-tune the current block
      for idx_iter in range(nb_iters_block):
        if (idx_iter + 1) % FLAGS.summ_step != 0:
          self.sess_train.run(self.block_train_ops[idx_block])
        else:
          summary, __ = self.sess_train.run([self.summary_op, self.block_train_ops[idx_block]])
          if self.is_primary_worker('global'):
            self.sm_writer.add_summary(summary, nb_iters_block * idx_block + idx_iter)

      # select the most discriminative channels for each layer
      for idx_layer in range(1, self.nb_layers):  # do not prune the first layer
        if self.idxs_layer_to_block[idx_layer] != idx_block:
          continue

        # take the mask's shape from its evaluated value; creating a new
        # ``tf.shape`` op here would add nodes to the graph on every call
        mask_val = self.sess_train.run(self.masks[idx_layer])
        mask_shape = mask_val.shape
        tf.logging.info('layer #{}: mask\'s shape is {}'.format(idx_layer, mask_shape))
        nb_chns = mask_shape[2]
        is_chn_chosen = np.zeros(nb_chns, dtype=np.bool_)
        mask_vec = np.sum(mask_val, axis=(0, 1, 3))
        prune_ratio = 1.0 - float(np.count_nonzero(mask_vec)) / mask_vec.size
        tf.logging.info('layer #%d: prune_ratio = %.4f' % (idx_layer, prune_ratio))
        is_first_entry = True
        while is_first_entry or prune_ratio > FLAGS.dcp_prune_ratio:
          if np.all(is_chn_chosen):
            break  # every input channel is already in the non-pruned set

          # choose the most important not-yet-chosen channel and then update the
          # mask. Chosen channels are excluded with -inf rather than by zeroing
          # their norm: with a zeroed norm, a tie at zero gradient norm could pick
          # an already-chosen channel again, raising its mask entry to 2 and
          # rescaling its weights.
          grad_norm = self.sess_train.run(self.grad_norms[idx_layer])
          idx_chn_input = int(np.argmax(np.where(is_chn_chosen, -np.inf, grad_norm)))
          is_chn_chosen[idx_chn_input] = True
          tf.logging.info('adding channel #%d to the non-pruned set' % idx_chn_input)
          mask_delta = np.zeros(mask_shape)
          mask_delta[:, :, idx_chn_input, :] = 1.0
          if is_first_entry:
            is_first_entry = False
            self.sess_train.run(self.mask_init_ops[idx_layer])
          self.sess_train.run(self.mask_updt_ops[idx_layer],
                              feed_dict={self.mask_deltas[idx_layer]: mask_delta})
          self.sess_train.run(self.prune_ops[idx_layer])

          # fine-tune the current layer
          for __ in range(nb_iters_layer):
            self.sess_train.run(self.layer_train_ops[idx_layer])

          # re-compute the pruning ratio
          mask_vec = np.sum(self.sess_train.run(self.masks[idx_layer]), axis=(0, 1, 3))
          prune_ratio = 1.0 - float(np.count_nonzero(mask_vec)) / mask_vec.size
          tf.logging.info('layer #%d: prune_ratio = %.4f' % (idx_layer, prune_ratio))

      # compute overall pruning ratios
      if self.is_primary_worker('global'):
        log_rslt = self.sess_train.run(self.log_op)
        log_str = ' | '.join(['%s = %.4e' % (name, value)
                              for name, value in zip(self.log_op_names, log_rslt)])
        tf.logging.info('block #%d: %s' % (idx_block + 1, log_str))

  def __save_model(self, is_train):
    """Save the current model for training or evaluation.

    Args:
    * is_train: whether to save a model for training
    """

    if is_train:
      save_path = self.saver_prnd_train.save(self.sess_train, FLAGS.dcp_save_path, self.global_step)
    else:
      save_path = self.saver_prnd_eval.save(self.sess_eval, FLAGS.dcp_save_path_eval)
    tf.logging.info('model saved to ' + save_path)

  def __restore_model(self, is_train):
    """Restore a model from the latest checkpoint files.

    Args:
    * is_train: whether to restore a model for training
    """

    save_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.dcp_save_path))
    if is_train:
      self.saver_prnd_train.restore(self.sess_train, save_path)
    else:
      self.saver_prnd_eval.restore(self.sess_eval, save_path)
    tf.logging.info('model restored from ' + save_path)

  def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
    """Monitor the training progress.

    Args:
    * summary: summary protocol buffer
    * log_rslt: logging operations' results
    * idx_iter: index of the training iteration
    * time_step: time step between two summary operations
    """

    # write summaries for TensorBoard visualization
    self.sm_writer.add_summary(summary, idx_iter)

    # compute the training speed
    speed = FLAGS.batch_size * FLAGS.summ_step / time_step
    if FLAGS.enbl_multi_gpu:
      speed *= mgw.size()

    # display monitored statistics
    log_str = ' | '.join(['%s = %.4e' % (name, value)
                          for name, value in zip(self.log_op_names, log_rslt)])
    tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
예제 #5
0
class ChannelPrunedGpuLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Channel pruning learner with GPU-based optimization.

    The training graph holds two copies of the network:
    * the pre-trained full model (scope ``model``);
    * the model being pruned (scope ``pruned_model``).

    Channels are chosen layer by layer with stochastic proximal gradient steps
    on an L2 regression loss between the two models' Conv2D outputs. Each layer
    is then fine-tuned on its selected channels. Finally, the whole pruned
    network is fine-tuned with gradients masked to the surviving channels.
    """

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """

        # class-independent initialization
        super(ChannelPrunedGpuLearner, self).__init__(sm_writer, model_helper)

        # define scopes for full & channel-pruned models
        self.model_scope_full = 'model'
        self.model_scope_prnd = 'pruned_model'

        # download the pre-trained model; only the primary local worker
        # downloads, then all workers synchronize via auto_barrier()
        if self.is_primary_worker('local'):
            self.download_model()  # pre-trained model is required
        self.auto_barrier()
        tf.logging.info('model files: ' + ', '.join(os.listdir('./models')))

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper,
                                                 self.mpi_comm)
        self.__build_train()
        self.__build_eval()

    def train(self):
        """Train a model and periodically produce checkpoint files.

        The steps are:
        * restore the pre-trained full model and copy its weights into the
          pruned model;
        * choose channels for every convolutional layer;
        * fine-tune the pruned model for ``self.nb_iters_train`` iterations,
          saving & evaluating every ``FLAGS.save_step`` iterations.
        """

        # restore the full model from pre-trained checkpoints
        save_path = tf.train.latest_checkpoint(
            os.path.dirname(self.save_path_full))
        self.saver_full.restore(self.sess_train, save_path)

        # initialization: init_op (re)initializes the pruned-model-scope
        # variables and then copies the full model's weights into the pruned
        # model; the *init_opt* ops initialize the optimizers' variables
        self.sess_train.run([self.init_op, self.init_opt_op])
        self.sess_train.run(
            [layer_op['init_opt'] for layer_op in self.layer_ops])
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # choose channels and evaluate the model before re-training
        self.__choose_channels()
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.evaluate()
        self.auto_barrier()

        # fine-tune the model with chosen channels only
        time_prev = timer()
        for idx_iter in range(self.nb_iters_train):
            # train the model
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    time_step = timer() - time_prev
                    self.__monitor_progress(summary, log_rslt, idx_iter,
                                            time_step)
                    time_prev = timer()

            # save the model at certain steps
            if (self.is_primary_worker('global')
                    and (idx_iter + 1) % FLAGS.save_step == 0):
                self.__save_model(is_train=True)
                self.evaluate()
            # NOTE(review): this barrier runs on every iteration, not only on
            # save steps
            self.auto_barrier()

        # save the final model; the evaluation graph restores the latest
        # training checkpoint and re-saves it under the evaluation save path
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it.

        Runs ``ceil(FLAGS.nb_smpls_eval / FLAGS.batch_size_eval)`` evaluation
        batches. It then logs the mean of every operation in ``self.eval_op``.
        """

        self.__restore_model(is_train=False)
        nb_iters = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        for idx_iter in range(nb_iters):
            eval_rslts[idx_iter] = self.sess_eval.run(self.eval_op)
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph.

        Creates ``self.sess_train`` together with:
        * the full and channel-pruned models;
        * per-variable channel masks;
        * layer-wise regression losses;
        * the layer-wise and whole-network training operations.
        """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model; its logits are unused, but its
            # weights initialize the pruned model and its Conv2D outputs are the
            # regression targets
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images)
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.maskable_var_names = [
                    var.name for var in self.vars_prnd['maskable']
                ]
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits_prnd,
                                               self.vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                    self.global_step)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
                pr_maskable = calc_prune_ratio(self.vars_prnd['maskable'])
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_maskable', pr_maskable)

                # create masks and corresponding operations for channel pruning
                # NOTE(review): the reduction over axes [0, 1, 3] and the tile
                # pattern assume 4-D kernels laid out as [kh, kw, c_in, c_out],
                # i.e. one mask value per input channel. Confirm this holds for
                # all maskable variables.
                self.masks = []
                self.mask_updt_ops = []
                for var in self.vars_prnd['maskable']:
                    tf.logging.info(
                        'creating a pruning mask for {} of size {}'.format(
                            var.name, var.shape))
                    # drop the leading scope component and the ':0' suffix;
                    # tf.get_variable() re-applies the current variable scope
                    name = '/'.join(var.name.split('/')[1:]).replace(
                        ':0', '_mask')
                    self.masks += [
                        tf.get_variable(name,
                                        initializer=tf.ones(var.shape),
                                        trainable=False)
                    ]
                    # a channel stays unmasked iff its weights are not all zero
                    var_norm = tf.sqrt(
                        tf.reduce_sum(tf.square(var),
                                      axis=[0, 1, 3],
                                      keepdims=True))
                    mask_vec = tf.cast(var_norm > 0.0, tf.float32)
                    mask_new = tf.tile(
                        mask_vec,
                        [var.shape[0], var.shape[1], 1, var.shape[3]])
                    self.mask_updt_ops += [self.masks[-1].assign(mask_new)]

                # build extra (layer-wise) regression losses
                self.reg_losses = self.__build_extra_losses()
                self.nb_layers = len(self.reg_losses)
                for idx, reg_loss in enumerate(self.reg_losses):
                    tf.summary.scalar('reg_loss_%d' % idx, reg_loss)

                # obtain the full list of trainable variables & update operations
                self.vars_all = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_prnd)
                self.trainable_vars_all = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.model_scope_prnd)
                self.update_ops_all = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)

                # TF operations for initializing the channel-pruned model: first
                # initialize every variable in the pruned scope (masks included),
                # then overwrite the model variables with the full model's values,
                # paired by position in the two variable lists
                init_ops = []
                with tf.control_dependencies(
                        [tf.variables_initializer(self.vars_all)]):
                    for var_full, var_prnd in zip(self.vars_full['all'],
                                                  self.vars_prnd['all']):
                        init_ops += [var_prnd.assign(var_full)]
                self.init_op = tf.group(init_ops)

                # TF operations for layer-wise and whole-network fine-tuning
                self.layer_ops, self.lrn_rates_pgd, self.prune_perctls = \
                    self.__build_layer_ops()
                self.train_op, self.init_opt_op = self.__build_network_ops(
                    loss, lrn_rate)

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_maskable] + list(
                metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(
                metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

    def __build_eval(self):
        """Build the evaluation graph.

        Creates ``self.sess_eval``, the channel-pruned model in inference mode,
        ``self.eval_op`` / ``self.eval_op_names`` and ``self.saver_prnd_eval``.
        """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                # loss & extra evaluation metrics
                logits = self.forward_eval(images)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                loss, metrics = self.calc_loss(labels, logits,
                                               vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
                pr_maskable = calc_prune_ratio(vars_prnd['maskable'])

                # TF operations for evaluation
                self.eval_op = [loss, pr_trainable, pr_maskable] + list(
                    metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list(
                    metrics.keys())
                self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'])

            # add input & output tensors to certain collections
            tf.add_to_collection('images_final', images)
            tf.add_to_collection('logits_final', logits)

    def __build_extra_losses(self):
        """Build extra losses for regression.

        Matches the Conv2D operations of the full and channel-pruned models in
        the order returned by get_ops_by_scope_n_pattern(). For each pair it
        takes tf.nn.l2_loss of the difference between their outputs.
        NOTE(review): this assumes both scopes yield their Conv2D ops in the
        same order; confirm against get_ops_by_scope_n_pattern().

        Returns:
        * reg_losses: list of regression losses (one per layer)
        """

        # insert additional losses to intermediate layers
        pattern = re.compile('Conv2D$')
        core_ops_full = get_ops_by_scope_n_pattern(self.model_scope_full,
                                                   pattern)
        core_ops_prnd = get_ops_by_scope_n_pattern(self.model_scope_prnd,
                                                   pattern)
        reg_losses = []
        for core_op_full, core_op_prnd in zip(core_ops_full, core_ops_prnd):
            reg_losses += [
                tf.nn.l2_loss(core_op_full.outputs[0] -
                              core_op_prnd.outputs[0])
            ]

        return reg_losses

    def __build_layer_ops(self):
        """Build layer-wise fine-tuning operations.

        For the idx-th maskable variable, two operations are built on
        ``self.reg_losses[idx]``:
        * 'prune' takes one gradient step with the fed learning rate. It then
          scales each channel by max(1 - threshold / norm, 0), where threshold
          is the fed percentile of the channel norms.
        * 'finetune' takes one Adam step with the gradient masked by
          ``self.masks``.

        Returns:
        * layer_ops: list of training and initialization operations for each layer
        * lrn_rates_pgd: list of layer-wise learning rate
        * prune_perctls: list of layer-wise pruning percentiles
        """

        layer_ops = []
        lrn_rates_pgd = []  # list of layer-wise learning rate
        prune_perctls = []  # list of layer-wise pruning percentiles
        for idx, var_prnd in enumerate(self.vars_prnd['maskable']):
            # NOTE(review): assumes the idx-th maskable variable feeds the idx-th
            # Conv2D op (and hence self.reg_losses[idx]); confirm ordering
            # create placeholders
            lrn_rate_pgd = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='lrn_rate_pgd_%d' % idx)
            prune_perctl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='prune_perctl_%d' % idx)

            # select channels for the current convolutional layer; the
            # optimizer is only used to compute the gradient, and the step itself
            # is applied manually below
            optimizer = tf.train.GradientDescentOptimizer(lrn_rate_pgd)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)
            grads = optimizer.compute_gradients(self.reg_losses[idx],
                                                [var_prnd])
            with tf.control_dependencies(self.update_ops_all):
                var_prnd_new = var_prnd - lrn_rate_pgd * grads[0][0]
                var_norm = tf.sqrt(
                    tf.reduce_sum(tf.square(var_prnd_new),
                                  axis=[0, 1, 3],
                                  keepdims=True))
                threshold = tf.contrib.distributions.percentile(
                    var_norm, prune_perctl)
                # NOTE(review): channels whose norm is exactly 0 divide by zero
                # here; confirm the resulting inf/NaN is handled as intended
                shrk_vec = tf.maximum(1.0 - threshold / var_norm, 0.0)
                prune_op = var_prnd.assign(var_prnd_new * shrk_vec)

            # fine-tune with selected channels only
            optimizer_base = tf.train.AdamOptimizer(FLAGS.cpg_lrn_rate_adam)
            if not FLAGS.enbl_multi_gpu:
                optimizer = optimizer_base
            else:
                optimizer = mgw.DistributedOptimizer(optimizer_base)
            grads_origin = optimizer.compute_gradients(self.reg_losses[idx],
                                                       [var_prnd])
            grads_pruned = self.__calc_grads_pruned(grads_origin)
            with tf.control_dependencies(self.update_ops_all):
                finetune_op = optimizer.apply_gradients(grads_pruned)
            # the Adam slot variables live on the base (undistributed) optimizer
            init_opt_op = tf.variables_initializer(optimizer_base.variables())

            # append layer-wise operations & variables
            layer_ops += [{
                'prune': prune_op,
                'finetune': finetune_op,
                'init_opt': init_opt_op
            }]
            lrn_rates_pgd += [lrn_rate_pgd]
            prune_perctls += [prune_perctl]

        return layer_ops, lrn_rates_pgd, prune_perctls

    def __build_network_ops(self, loss, lrn_rate):
        """Build network training operations.

        Args:
        * loss: training loss of the channel-pruned model
        * lrn_rate: learning rate tensor for the momentum optimizer

        Returns:
        * train_op: training operation of the whole network
        * init_opt_op: initialization operation of the whole network's optimizer
        """

        optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        if not FLAGS.enbl_multi_gpu:
            optimizer = optimizer_base
        else:
            optimizer = mgw.DistributedOptimizer(optimizer_base)
        grads_origin = optimizer.compute_gradients(loss,
                                                   self.trainable_vars_all)
        # masked gradients keep pruned channels at zero during fine-tuning
        grads_pruned = self.__calc_grads_pruned(grads_origin)
        with tf.control_dependencies(self.update_ops_all):
            train_op = optimizer.apply_gradients(grads_pruned,
                                                 global_step=self.global_step)
        init_opt_op = tf.variables_initializer(optimizer_base.variables())

        return train_op, init_opt_op

    def __calc_grads_pruned(self, grads_origin):
        """Calculate the mask-pruned gradients.

        A maskable variable's gradient is multiplied element-wise by its entry in
        ``self.masks``, matched by variable name. Every other (gradient, variable)
        pair is passed through unchanged.

        Args:
        * grads_origin: list of original gradients

        Returns:
        * grads_pruned: list of mask-pruned gradients
        """

        mask_by_name = dict(zip(self.maskable_var_names, self.masks))
        grads_pruned = []
        for grad, var in grads_origin:
            mask = mask_by_name.get(var.name)
            # a None gradient (variable not reached by the loss) cannot be
            # multiplied by a mask; keep it as-is so apply_gradients skips it
            if mask is None or grad is None:
                grads_pruned += [(grad, var)]
            else:
                grads_pruned += [(grad * mask, var)]

        return grads_pruned

    def __choose_channels(self):  # pylint: disable=too-many-locals
        """Choose channels for all convolutional layers.

        Per-layer target pruning ratios depend on FLAGS.cpg_prune_ratio_type:
        * 'uniform' uses FLAGS.cpg_prune_ratio for every layer, optionally
          zeroing the first and last layers' ratio;
        * 'list' reads one comma-separated line from
          FLAGS.cpg_prune_ratio_file.

        Each layer with a non-zero ratio then goes through three steps:
        * proximal gradient steps select its channels;
        * its mask is updated;
        * it is fine-tuned on the selected channels.
        The overall pruning ratios are logged at the end.

        Raises:
        * ValueError: if FLAGS.cpg_prune_ratio_type is not recognized
        """

        # obtain each layer's pruning ratio
        if FLAGS.cpg_prune_ratio_type == 'uniform':
            ratio_list = [FLAGS.cpg_prune_ratio] * self.nb_layers
            if FLAGS.cpg_skip_ht_layers:
                # leave the first and last layers unpruned
                ratio_list[0] = 0.0
                ratio_list[-1] = 0.0
        elif FLAGS.cpg_prune_ratio_type == 'list':
            with open(FLAGS.cpg_prune_ratio_file, 'r') as i_file:
                i_line = i_file.readline().strip()
                ratio_list = [float(sub_str) for sub_str in i_line.split(',')]
        else:
            raise ValueError('unrecognized pruning ratio type: ' +
                             FLAGS.cpg_prune_ratio_type)

        # select channels for all convolutional layers; the per-layer iteration
        # budget is divided among the workers
        nb_workers = mgw.size() if FLAGS.enbl_multi_gpu else 1
        nb_iters_layer = int(FLAGS.cpg_nb_iters_layer / nb_workers)
        for idx_layer in range(self.nb_layers):
            # skip if no pruning is required
            if ratio_list[idx_layer] == 0.0:
                continue
            if self.is_primary_worker('global'):
                tf.logging.info('layer #%d: pr = %.2f (target)' %
                                (idx_layer, ratio_list[idx_layer]))
                tf.logging.info('mask.shape = {}'.format(
                    self.masks[idx_layer].shape))

            # select channels for the current convolutional layer
            time_prev = timer()
            # the regression loss is non-negative, so the first comparison below
            # always applies the decrement factor
            reg_loss_prev = 0.0
            lrn_rate_pgd = FLAGS.cpg_lrn_rate_pgd_init
            for idx_iter in range(nb_iters_layer):
                # take a stochastic proximal gradient descent step; the pruning
                # percentile ramps linearly and reaches ratio * 100 on the
                # last iteration
                prune_perctl = ratio_list[idx_layer] * 100.0 * (
                    idx_iter + 1) / nb_iters_layer
                __, reg_loss = self.sess_train.run(
                    [
                        self.layer_ops[idx_layer]['prune'],
                        self.reg_losses[idx_layer]
                    ],
                    feed_dict={
                        self.lrn_rates_pgd[idx_layer]: lrn_rate_pgd,
                        self.prune_perctls[idx_layer]: prune_perctl
                    })
                # NOTE(review): this layer's mask is only refreshed by
                # mask_updt_ops after this loop, so nnz-chns logged here reflects
                # the mask from before this layer's channel selection
                mask = self.sess_train.run(self.masks[idx_layer])
                if self.is_primary_worker('global'):
                    nb_chns_nnz = np.count_nonzero(np.sum(mask,
                                                          axis=(0, 1, 3)))
                    tf.logging.info(
                        'iter %d: nnz-chns = %d | loss = %.2e | lr = %.2e | percentile = %.2f'
                        % (idx_iter + 1, nb_chns_nnz, reg_loss, lrn_rate_pgd,
                           prune_perctl))

                # adjust the learning rate: apply the increment factor when the
                # loss decreased, the decrement factor otherwise
                if reg_loss < reg_loss_prev:
                    lrn_rate_pgd *= FLAGS.cpg_lrn_rate_pgd_incr
                else:
                    lrn_rate_pgd *= FLAGS.cpg_lrn_rate_pgd_decr
                reg_loss_prev = reg_loss

            # fine-tune with selected channels only
            self.sess_train.run(self.mask_updt_ops[idx_layer])
            for idx_iter in range(nb_iters_layer):
                __, reg_loss = self.sess_train.run([
                    self.layer_ops[idx_layer]['finetune'],
                    self.reg_losses[idx_layer]
                ])
                mask = self.sess_train.run(self.masks[idx_layer])
                if self.is_primary_worker('global'):
                    nb_chns_nnz = np.count_nonzero(np.sum(mask,
                                                          axis=(0, 1, 3)))
                    tf.logging.info('iter %d: nnz-chns = %d | loss = %.2e' %
                                    (idx_iter + 1, nb_chns_nnz, reg_loss))

            # re-compute the pruning ratio from the fraction of all-zero
            # channels in the updated mask
            mask_vec = np.mean(np.square(
                self.sess_train.run(self.masks[idx_layer])),
                               axis=(0, 1, 3))
            prune_ratio = 1.0 - float(
                np.count_nonzero(mask_vec)) / mask_vec.size
            if self.is_primary_worker('global'):
                tf.logging.info('layer #%d: pr = %.2f (actual) | time = %.2f' %
                                (idx_layer, prune_ratio, timer() - time_prev))

        # compute overall pruning ratios and report them
        if self.is_primary_worker('global'):
            log_rslt = self.sess_train.run(self.log_op)
            log_str = ' | '.join([
                '%s = %.4e' % (name, value)
                for name, value in zip(self.log_op_names, log_rslt)
            ])
            tf.logging.info('overall pruning ratios: %s' % log_str)

    def __save_model(self, is_train):
        """Save the current model for training or evaluation.

        Args:
        * is_train: whether to save a model for training. If true, the training
          session is saved to FLAGS.cpg_save_path with self.global_step.
          Otherwise the evaluation session is saved to FLAGS.cpg_save_path_eval.
        """

        if is_train:
            save_path = self.saver_prnd_train.save(self.sess_train,
                                                   FLAGS.cpg_save_path,
                                                   self.global_step)
        else:
            save_path = self.saver_prnd_eval.save(self.sess_eval,
                                                  FLAGS.cpg_save_path_eval)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Both modes read the newest checkpoint in the directory of
        FLAGS.cpg_save_path, i.e. the training checkpoints.

        Args:
        * is_train: whether to restore a model for training
        """

        save_path = tf.train.latest_checkpoint(
            os.path.dirname(FLAGS.cpg_save_path))
        if is_train:
            self.saver_prnd_train.restore(self.sess_train, save_path)
        else:
            self.saver_prnd_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Monitor the training progress.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # display monitored statistics
        log_str = ' | '.join([
            '%s = %.4e' % (name, value)
            for name, value in zip(self.log_op_names, log_rslt)
        ])
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' %
                        (idx_iter + 1, log_str, speed))
예제 #6
0
class NonUniformQuantLearner(AbstractLearner):
    # pylint: disable=too-many-instance-attributes
    '''
  Nonuniform quantization for weights and activations
  '''
    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        The constructor runs these steps in order:
        * build the training and evaluation graphs;
        * fetch the pre-trained model;
        * run a BitOptimizer to obtain the weight & activation bit-width lists
          used for fine-tuning and evaluation.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """
        super(NonUniformQuantLearner, self).__init__(sm_writer, model_helper)

        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper,
                                                 self.mpi_comm)

        # shared containers populated while the graphs are built
        self.ops, self.bit_placeholders, self.statistics = {}, {}, {}

        self.__build_train()
        self.__build_eval()

        # only the primary local worker downloads; everyone then synchronizes
        if self.is_primary_worker('local'):
            self.download_model()
        self.auto_barrier()

        # search for the optimal bit-width policy
        policy_search = BitOptimizer(
            self.dataset_name, self.weights, self.statistics,
            self.bit_placeholders, self.ops, self.layerwise_tune_list,
            self.sess_train, self.sess_eval, self.saver_train,
            self.saver_quant, self.saver_eval, self.auto_barrier)
        self.optimal_w_bit_list, self.optimal_a_bit_list = policy_search.run()
        self.auto_barrier()

    def train(self):
        """Fine-tune the quantized model with the optimal bit-width lists.

        Initialization happens in this order:
        * the non-cluster variables are initialized, then optionally
          warm-started from the latest checkpoint;
        * the cluster variables are initialized only afterwards.

        The model is saved & evaluated every FLAGS.save_step iterations and
        once more after the last iteration.
        """
        self.sess_train.run(self.ops['non_cluster_init'])
        nb_iters = self.finetune_steps

        if FLAGS.enbl_warm_start:
            # warm-start from the latest training checkpoint
            self.__restore_model(is_train=True)

        # NOTE: clusters must be initialized after the weights are restored
        w_bit_feed = {self.bit_placeholders['w_train']: self.optimal_w_bit_list}
        self.sess_train.run(self.ops['cluster_init'], feed_dict=w_bit_feed)

        bit_feed = {
            self.bit_placeholders['w_train']: self.optimal_w_bit_list,
            self.bit_placeholders['a_train']: self.optimal_a_bit_list,
        }

        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.ops['bcast'])

        time_prev = timer()
        for idx_iter in range(nb_iters):
            step = idx_iter + 1
            if step % FLAGS.summ_step == 0:
                fetches = [self.ops['train'], self.ops['summary'], self.ops['log']]
                _, summary, log_rslt = self.sess_train.run(fetches,
                                                           feed_dict=bit_feed)
                time_prev = self.__monitor_progress(summary, log_rslt,
                                                    time_prev, idx_iter)
            else:
                self.sess_train.run(self.ops['train'], feed_dict=bit_feed)

            if step % FLAGS.save_step == 0:
                self.__save_model()
                self.evaluate()
                tf.logging.info("Optimal Weight Quantization:{}".format(
                    self.optimal_w_bit_list))
                self.auto_barrier()

        # final checkpoint & evaluation
        self.__save_model()
        self.evaluate()

    def evaluate(self):
        """Evaluate the latest checkpoint with the optimal bit-width lists.

        Non-primary workers return immediately. The primary worker:
        * restores the evaluation model;
        * runs ceil(FLAGS.nb_smpls_eval / FLAGS.batch_size_eval) batches of
          self.ops['eval'];
        * logs the mean of result element 0 (loss) and element 1 (accuracy);
        * logs the bucket storage when FLAGS.nuql_use_buckets is set.
        """
        if not self.is_primary_worker():
            return

        self.__restore_model(is_train=False)
        nb_iters = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))

        # quantization bits for the evaluation graph
        bit_feed = {
            self.bit_placeholders['w_eval']: self.optimal_w_bit_list,
            self.bit_placeholders['a_eval']: self.optimal_a_bit_list,
        }

        batch_rslts = [
            self.sess_eval.run(self.ops['eval'], feed_dict=bit_feed)
            for _ in range(nb_iters)
        ]
        mean_loss = np.mean(np.array([rslt[0] for rslt in batch_rslts]))
        mean_acc = np.mean(np.array([rslt[1] for rslt in batch_rslts]))
        tf.logging.info('loss: {}'.format(mean_loss))
        tf.logging.info('accuracy: {}'.format(mean_acc))
        tf.logging.info("Optimal Weight Quantization:{}".format(
            self.optimal_w_bit_list))

        if FLAGS.nuql_use_buckets:
            bucket_storage = self.sess_eval.run(self.ops['bucket_storage'],
                                                feed_dict=bit_feed)
            self.__show_bucket_storage(bucket_storage)

    def __build_train(self):
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(mgw.local_rank() \
                if FLAGS.enbl_multi_gpu else 0)
            self.sess_train = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size, images.shape[1], images.shape[2], \
                    images.shape[3]))

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_train, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_train(images)

                self.saver_train = tf.train.Saver(self.vars)

                self.weights = [
                    v for v in self.trainable_vars
                    if 'kernel' in v.name or 'weight' in v.name
                ]
                if not FLAGS.nuql_quantize_all_layers:
                    self.weights = self.weights[1:-1]
                self.statistics['num_weights'] = \
                    [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

                self.__quantize_train_graph()

                # loss & accuracy
                # Strictly speaking, clusters should be not included for regularization.
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                else:
                    raise ValueError("Unrecognized dataset name")

                model_loss = loss
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss
                    tf.summary.scalar('dst_loss', dst_loss)
                tf.summary.scalar('model_loss', model_loss)
                tf.summary.scalar('loss', loss)
                tf.summary.scalar('acc_top1', acc_top1)
                tf.summary.scalar('acc_top5', acc_top5)

                self.ft_step = tf.get_variable('finetune_step',
                                               shape=[],
                                               dtype=tf.int32,
                                               trainable=False)

            # optimizer & gradients
            init_lr, bnds, decay_rates, self.finetune_steps = \
                setup_bnds_decay_rates(self.model_name, self.dataset_name)
            lrn_rate = tf.train.piecewise_constant(self.ft_step, [i for i in bnds], \
                [init_lr * decay_rate for decay_rate in decay_rates])

            # optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
            optimizer = tf.train.AdamOptimizer(lrn_rate)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)

            # acquire non-cluster-vars and clusters.
            clusters = [v for v in self.trainable_vars if 'clusters' in v.name]
            rest_trainable_vars = [v for v in self.trainable_vars \
                if v not in clusters]

            # determine the var_list optimize
            if FLAGS.nuql_opt_mode in ['cluster', 'both']:
                if FLAGS.nuql_opt_mode == 'both':
                    optimizable_vars = self.trainable_vars
                else:
                    optimizable_vars = clusters
                if FLAGS.nuql_enbl_rl_agent:
                    optimizer_fintune = tf.train.GradientDescentOptimizer(
                        lrn_rate)
                    if FLAGS.enbl_multi_gpu:
                        optimizer_fintune = mgw.DistributedOptimizer(
                            optimizer_fintune)
                    grads_fintune = optimizer_fintune.compute_gradients(
                        loss, var_list=optimizable_vars)

            elif FLAGS.nuql_opt_mode == 'weights':
                optimizable_vars = rest_trainable_vars
            else:
                raise ValueError("Unknown optimization mode")

            grads = optimizer.compute_gradients(loss,
                                                var_list=optimizable_vars)

            # sm write graph
            self.sm_writer.add_graph(self.sess_train.graph)

            # define the ops
            with tf.control_dependencies(self.update_ops):
                self.ops['train'] = optimizer.apply_gradients(
                    grads, global_step=self.ft_step)
                if FLAGS.nuql_opt_mode in ['both', 'cluster'
                                           ] and FLAGS.nuql_enbl_rl_agent:
                    self.ops['rl_fintune'] = optimizer_fintune.apply_gradients(
                        grads_fintune, global_step=self.ft_step)
                else:
                    self.ops['rl_fintune'] = self.ops['train']
            self.ops['summary'] = tf.summary.merge_all()
            if FLAGS.enbl_dst:
                self.ops['log'] = [
                    lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5
                ]
            else:
                self.ops['log'] = [
                    lrn_rate, model_loss, loss, acc_top1, acc_top5
                ]

            # NOTE: first run non_cluster_init_op, then cluster_init_op
            cluster_global_vars = [
                v for v in self.vars if 'clusters' in v.name
            ]
            # pay attention to beta1_power, beta2_power.
            non_cluster_global_vars = [
                v for v in tf.global_variables()
                if v not in cluster_global_vars
            ]

            self.ops['non_cluster_init'] = tf.variables_initializer(
                var_list=non_cluster_global_vars)
            self.ops['cluster_init'] = tf.variables_initializer(
                var_list=cluster_global_vars)
            self.ops['bcast'] = mgw.broadcast_global_variables(0) \
                if FLAGS.enbl_multi_gpu else None
            self.ops['reset_ft_step'] = tf.assign(self.ft_step, \
                tf.constant(0, dtype=tf.int32))

            self.saver_quant = tf.train.Saver(self.vars)

    def __build_eval(self):
        """Build the quantized evaluation graph and its session.

        Creates a fresh graph with its own session (`self.sess_eval`), the
        evaluation input pipeline, the forward pass, the non-uniform
        quantization ops (via `__quantize_eval_graph`), and the loss/accuracy
        tensors. Results are stored in `self.ops['eval']` as
        [loss, acc_top1, acc_top5], and a saver over `self.vars` in
        `self.saver_eval`.

        Raises:
        * ValueError: if `self.dataset_name` is neither 'cifar_10' nor 'ilsvrc_12'
        """
        with tf.Graph().as_default():
            # TensorFlow session (pinned to this worker's local GPU under multi-GPU)
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(mgw.local_rank() \
                if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()
                # pin the batch dimension to FLAGS.batch_size; images are rank-4 here
                images.set_shape((FLAGS.batch_size, images.shape[1], images.shape[2], \
                    images.shape[3]))
                self.images_eval = images

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_eval(images)

                # must run after the forward pass: it searches the graph built above
                self.__quantize_eval_graph()

                # loss & accuracy
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    # CIFAR-10 reports only top-1 accuracy; top-5 is a constant 0
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                else:
                    raise ValueError("Unrecognized dataset name")

                # add the distillation term to the evaluation loss when enabled
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss

            self.ops['eval'] = [loss, acc_top1, acc_top5]
            self.saver_eval = tf.train.Saver(self.vars)

    def __quantize_train_graph(self):
        """Insert non-uniform quantization ops into the training graph.

        Matmul and activation ops found in `self.sess_train`'s graph receive
        quantization ops whose bit-widths are fed at run time through the
        int64 placeholders `self.bit_placeholders['w_train']` and
        `self.bit_placeholders['a_train']`. The number of ops of each kind
        is recorded in `self.statistics`, and `self.layerwise_tune_list` is set.
        """
        quantizer = NonUniformQuantization(
            self.sess_train, FLAGS.nuql_bucket_size, FLAGS.nuql_use_buckets,
            FLAGS.nuql_init_style, FLAGS.nuql_bucket_type)

        # locate the weight (matmul) ops and the activation ops to quantize
        weight_ops = quantizer.search_matmul_op(FLAGS.nuql_quantize_all_layers)
        activation_ops = quantizer.search_activation_op()
        nb_matmuls = len(weight_ops)
        nb_activations = len(activation_ops)
        self.statistics['nb_matmuls'] = nb_matmuls
        self.statistics['nb_activations'] = nb_activations

        # one bit-width entry per quantized op, fed at run time
        w_bits = tf.placeholder(tf.int64, shape=[nb_matmuls], name="w_bit_list")
        a_bits = tf.placeholder(tf.int64, shape=[nb_activations], name="a_bit_list")
        self.bit_placeholders['w_train'] = w_bits
        self.bit_placeholders['a_train'] = a_bits

        # map every op name onto the matching entry of its placeholder
        w_bit_dict_train = self.__build_quant_dict(
            [op.name for op in weight_ops], w_bits)
        a_bit_dict_train = self.__build_quant_dict(
            [op.name for op in activation_ops], a_bits)

        quantizer.insert_quant_op_for_weights(w_bit_dict_train)
        quantizer.insert_quant_op_for_activations(a_bit_dict_train)

        # TODO: layer-wise fine-tuning is optional and does not work very well yet
        if FLAGS.nuql_enbl_rl_layerwise_tune:
            self.layerwise_tune_list = quantizer.get_layerwise_tune_op(self.weights)
        else:
            self.layerwise_tune_list = (None, None)

    def __quantize_eval_graph(self):
        """Insert non-uniform quantization ops into the evaluation graph.

        This mirrors the training-graph quantization on `self.sess_eval`'s
        graph. The evaluation bit placeholders are shaped from the op counts
        recorded in `self.statistics` when the training graph was quantized.
        The numbers of matmul and activation ops found here must therefore
        match those counts.

        Raises:
        * ValueError: if the op counts differ from the training graph's
        """
        nonuni_quant = NonUniformQuantization(self.sess_eval,
                                              FLAGS.nuql_bucket_size,
                                              FLAGS.nuql_use_buckets,
                                              FLAGS.nuql_init_style,
                                              FLAGS.nuql_bucket_type)

        # Find Matmul Op and activation Op
        matmul_ops = nonuni_quant.search_matmul_op(
            FLAGS.nuql_quantize_all_layers)
        act_ops = nonuni_quant.search_activation_op()

        # Explicit checks rather than `assert`: asserts are stripped under
        # `python -O`, and a mismatch would silently mis-slice the placeholders.
        if self.statistics['nb_matmuls'] != len(matmul_ops):
            raise ValueError(
                'the length of matmul_ops does not match: %d (train) vs. %d (eval)'
                % (self.statistics['nb_matmuls'], len(matmul_ops)))
        if self.statistics['nb_activations'] != len(act_ops):
            raise ValueError(
                'the length of act_ops does not match: %d (train) vs. %d (eval)'
                % (self.statistics['nb_activations'], len(act_ops)))

        # names of the ops that receive quantization ops
        matmul_op_names = [op.name for op in matmul_ops]
        act_op_names = [op.name for op in act_ops]

        # build the placeholder for eval
        self.bit_placeholders['w_eval'] = tf.placeholder(tf.int64, \
            shape=[self.statistics['nb_matmuls']], name="w_bit_list")
        self.bit_placeholders['a_eval'] = tf.placeholder(tf.int64, \
            shape=[self.statistics['nb_activations']], name="a_bit_list")

        w_bit_dict_eval = self.__build_quant_dict(matmul_op_names, \
            self.bit_placeholders['w_eval'])
        a_bit_dict_eval = self.__build_quant_dict(act_op_names, \
            self.bit_placeholders['a_eval'])

        # Insert Quant Op for weights and activations
        nonuni_quant.insert_quant_op_for_weights(w_bit_dict_eval)
        nonuni_quant.insert_quant_op_for_activations(a_bit_dict_eval)

        # bucket storage is only reported when buckets are used; otherwise a constant 0
        self.ops['bucket_storage'] = nonuni_quant.bucket_storage if FLAGS.nuql_use_buckets \
            else tf.constant(0, tf.int32)

    def __save_model(self):
        """Save the quantized model from the training session.

        Only the primary worker writes a checkpoint. The path prefix is
        FLAGS.nuql_save_quant_model_path, suffixed with the fine-tuning step.
        """
        if self.is_primary_worker():
            ckpt_path = self.saver_quant.save(
                self.sess_train, FLAGS.nuql_save_quant_model_path, self.ft_step)
            tf.logging.info('quantized model saved to ' + ckpt_path)

    def __restore_model(self, is_train):
        """Restore the latest checkpoint into the training or evaluation session.

        Args:
        * is_train: if True, restore `self.sess_train` via `self.saver_train`
          from the directory of FLAGS.save_path; otherwise restore
          `self.sess_eval` via `self.saver_eval` from the directory of
          FLAGS.nuql_save_quant_model_path
        """
        ckpt_source = FLAGS.save_path if is_train else FLAGS.nuql_save_quant_model_path
        save_path = tf.train.latest_checkpoint(os.path.dirname(ckpt_source))
        if is_train:
            self.saver_train.restore(self.sess_train, save_path)
        else:
            self.saver_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, time_prev, idx_iter):
        """Write TensorBoard summaries and log training statistics.

        Args:
        * summary: serialized summary produced by the summary op
        * log_rslt: evaluated log values; [lrn_rate, dst_loss, model_loss, loss,
          acc_top1, acc_top5] when FLAGS.enbl_dst is set, otherwise the same
          list without dst_loss
        * time_prev: timestamp taken at the previous report
        * idx_iter: index of the current iteration

        Returns:
        * a fresh timestamp on the primary worker; None on other workers
        """
        if not self.is_primary_worker():
            return None

        self.sm_writer.add_summary(summary, idx_iter)

        # throughput over the last FLAGS.summ_step iterations, scaled by the
        # number of workers under multi-GPU training
        elapsed = timer() - time_prev
        speed = FLAGS.batch_size * FLAGS.summ_step / elapsed
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        if FLAGS.enbl_dst:
            lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5 = \
                [log_rslt[pos] for pos in range(6)]
            tf.logging.info('iter #%d: lr = %e | dst_loss = %.4f | model_loss = %.4f | loss = %.4f | acc_top1 = %.4f | acc_top5 = %.4f | speed = %.2f pics / sec'
                            % (idx_iter + 1, lrn_rate, dst_loss, model_loss, loss,
                               acc_top1, acc_top5, speed))
        else:
            lrn_rate, model_loss, loss, acc_top1, acc_top5 = \
                [log_rslt[pos] for pos in range(5)]
            tf.logging.info('iter #%d: lr = %e | model_loss = %.4f | loss = %.4f | acc_top1 = %.4f | acc_top5 = %.4f | speed = %.2f pics / sec'
                            % (idx_iter + 1, lrn_rate, model_loss, loss,
                               acc_top1, acc_top5, speed))

        return timer()

    def __show_bucket_storage(self, bucket_storage):
        """Log bucket storage against quantized-weight storage.

        Weight storage is the total weight count multiplied by the per-weight
        bit-width. That bit-width is FLAGS.nuql_equivalent_bits when the RL
        agent is enabled, otherwise FLAGS.nuql_weight_bits.

        Args:
        * bucket_storage: storage cost of the quantization buckets, in bits
        """
        if FLAGS.nuql_enbl_rl_agent:
            bits_per_weight = FLAGS.nuql_equivalent_bits
        else:
            bits_per_weight = FLAGS.nuql_weight_bits
        weight_storage = sum(self.statistics['num_weights']) * bits_per_weight

        bits_per_kb = 8. * 1024.
        tf.logging.info('bucket storage: %d bit / %.3f kb | weight storage: %d bit / %.3f kb | ratio: %.3f'
                        % (bucket_storage, bucket_storage / bits_per_kb,
                           weight_storage, weight_storage / bits_per_kb,
                           bucket_storage * 1. / weight_storage))

    @staticmethod
    def __build_quant_dict(keys, values):
        """Pair each op name with the bit-width entry at the same position.

        Args:
        * keys: list of op names
        * values: indexable object (e.g. a Tensor) with at least len(keys) elements

        Returns:
        * dict mapping keys[i] -> values[i]
        """
        return {name: values[pos] for pos, name in enumerate(keys)}
예제 #7
0
class ChannelPrunedLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
  """Learner with channel/filter pruning.

  FLAGS.cp_prune_option selects how preserve ratios are chosen:
  * 'list': ratios read from a user-supplied file
  * 'uniform': one ratio for every layer
  * 'auto': ratios searched by a DDPG agent

  After pruning, the model is fine-tuned or re-trained.
  """

  def __init__(self, sm_writer, model_helper):
    """Constructor function.

    Args:
    * sm_writer: TensorFlow's summary writer
    * model_helper: model helper with definitions of model & dataset
    """
    # class-independent initialization
    super(ChannelPrunedLearner, self).__init__(sm_writer, model_helper)

    # class-dependent initialization
    if FLAGS.enbl_dst:
      self.learner_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)

    self.model_scope = 'model'

    self.sm_writer = sm_writer
    self.max_save_path = ''
    self.saver = None
    self.saver_train = None
    self.saver_eval = None
    self.model = None
    self.pruner = None
    self.sess_train = None
    self.sess_eval = None
    self.log_op = None
    self.train_op = None
    self.bcast_op = None
    self.train_init_op = None
    self.time_prev = None
    self.agent = None
    self.idx_iter = None
    self.accuracy_keys = None
    self.eval_op = None
    self.global_step = None
    self.summary_op = None
    self.nb_iters_train = 0
    self.bestinfo = None

    self.__build(is_train=True)
    self.__build(is_train=False)

  def train(self):
    """Prune the pre-trained model, then fine-tune the pruned model.

    Raises:
    * ValueError: if FLAGS.cp_prune_option is not 'list', 'auto' or 'uniform'
    """
    prune_fns = {
      'list': self.__prune_and_finetune_list,
      'auto': self.__prune_and_finetune_auto,
      'uniform': self.__prune_and_finetune_uniform,
    }
    # validate before downloading anything, instead of silently doing nothing
    if FLAGS.cp_prune_option not in prune_fns:
      raise ValueError('unrecognized pruning option: ' + str(FLAGS.cp_prune_option))

    # download the pre-trained model, snapshot it and build the pruner
    if self.__is_primary_worker():
      self.download_model()
      self.__restore_model(True)
      self.saver_train.save(self.sess_train, FLAGS.cp_original_path)
      self.create_pruner()

    if FLAGS.enbl_multi_gpu:
      self.mpi_comm.Barrier()

    tf.logging.info('Start pruning')

    # channel pruning and finetuning
    prune_fns[FLAGS.cp_prune_option]()

  def create_pruner(self):
    """Create a channel pruner from the saved original (unpruned) model.

    The meta-graph saved at FLAGS.cp_original_path is re-imported into a fresh
    graph. Its session becomes `self.sess_train`, and the model is wrapped in
    a ChannelPruner stored in `self.pruner`.
    """
    with tf.Graph().as_default():
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(0) # pylint: disable=no-member
      sess = tf.Session(config=config)
      self.saver = tf.train.import_meta_graph(FLAGS.cp_original_path + '.meta')
      self.saver.restore(sess, FLAGS.cp_original_path)
      self.sess_train = sess
      self.sm_writer.add_graph(sess.graph)
      # tensors registered as graph collections by __build(is_train=True)
      train_images = tf.get_collection('train_images')[0]
      train_labels = tf.get_collection('train_labels')[0]
      mem_images = tf.get_collection('mem_images')[0]
      mem_labels = tf.get_collection('mem_labels')[0]
      summary_op = tf.get_collection('summary_op')[0]
      loss = tf.get_collection('loss')[0]

      accuracy = tf.get_collection('accuracy')[0]
      metrics = {'loss': loss, 'accuracy': accuracy}
      for key in self.accuracy_keys:
        metrics[key] = tf.get_collection(key)[0]
      self.model = Model(self.sess_train)
      pruner = ChannelPruner(
        self.model,
        images=train_images,
        labels=train_labels,
        mem_images=mem_images,
        mem_labels=mem_labels,
        metrics=metrics,
        lbound=self.lbound,
        summary_op=summary_op,
        sm_writer=self.sm_writer)

      self.pruner = pruner

  def evaluate(self):
    """Evaluate the model and log the mean loss and metrics.

    Runs FLAGS.nb_smpls_eval // FLAGS.batch_size_eval batches of `self.eval_op`
    on the primary worker only.
    """
    # early break for non-primary workers
    if not self.__is_primary_worker():
      return

    if self.saver_eval is None:
      self.saver_eval = tf.train.Saver()
    self.__restore_model(is_train=False)
    losses, accuracy = [], []

    nb_iters = FLAGS.nb_smpls_eval // FLAGS.batch_size_eval

    self.sm_writer.add_graph(self.sess_eval.graph)

    # eval_op = [loss] + accuracy values, in the order of self.accuracy_keys
    accuracies = [[] for _ in range(len(self.accuracy_keys))]
    for _ in range(nb_iters):
      eval_rslt = self.sess_eval.run(self.eval_op)
      losses.append(eval_rslt[0])
      for i in range(len(self.accuracy_keys)):
        accuracies[i].append(eval_rslt[i + 1])
    loss = np.mean(np.array(losses))
    tf.logging.info('loss: {}'.format(loss))
    for i in range(len(self.accuracy_keys)):
      accuracy.append(np.mean(np.array(accuracies[i])))
      tf.logging.info('{}: {}'.format(self.accuracy_keys[i], accuracy[i]))

  def __build(self, is_train): # pylint: disable=too-many-locals
    """Build the training graph, or the pruned evaluation graph.

    For training, the data pipeline and model are created in a fresh graph. The
    tensors the pruner needs are registered as graph collections, and the
    preserve-ratio bounds `self.lbound` / `self.rbound` are set. Only the
    primary worker builds anything.

    Args:
    * is_train: True for the training graph, False for the evaluation graph
    """
    # early break for non-primary workers
    if not self.__is_primary_worker():
      return

    if not is_train:
      self.__build_pruned_evaluate_model()
      return

    with tf.Graph().as_default():
      # create a TF session for the current graph
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(0) # pylint: disable=no-member
      sess = tf.Session(config=config)

      # data input pipeline
      with tf.variable_scope(self.data_scope):
        train_images, train_labels = self.build_dataset_train().get_next()
        eval_images, eval_labels = self.build_dataset_eval().get_next()
        image_shape = train_images.shape.as_list()
        label_shape = train_labels.shape.as_list()
        image_shape[0] = FLAGS.batch_size
        label_shape[0] = FLAGS.batch_size

        # the model is built on placeholders so the pruner can feed arbitrary batches
        mem_images = tf.placeholder(dtype=train_images.dtype,
                                    shape=image_shape)
        mem_labels = tf.placeholder(dtype=train_labels.dtype,
                                    shape=label_shape)

        tf.add_to_collection('train_images', train_images)
        tf.add_to_collection('train_labels', train_labels)
        tf.add_to_collection('eval_images', eval_images)
        tf.add_to_collection('eval_labels', eval_labels)
        tf.add_to_collection('mem_images', mem_images)
        tf.add_to_collection('mem_labels', mem_labels)

      # model definition
      with tf.variable_scope(self.model_scope):
        # forward pass
        logits = self.forward_train(mem_images)
        loss, accuracy = self.calc_loss(mem_labels, logits, self.trainable_vars)
        self.accuracy_keys = list(accuracy.keys())
        for key in self.accuracy_keys:
          tf.add_to_collection(key, accuracy[key])
        tf.add_to_collection('loss', loss)
        tf.add_to_collection('logits', logits)

        tf.summary.scalar('loss', loss)
        for key in accuracy.keys():
          tf.summary.scalar(key, accuracy[key])

      # learning rate & pruning ratio
      self.sess_train = sess
      self.summary_op = tf.summary.merge_all()
      tf.add_to_collection('summary_op', self.summary_op)
      self.saver_train = tf.train.Saver(self.vars)

      self.lbound = math.log(FLAGS.cp_preserve_ratio + 1, 10) * 1.5
      self.rbound = 1.0

  def __build_pruned_evaluate_model(self, path=None):
    """Build an evaluation model from a pruned model's checkpoint.

    The model's placeholder inputs (`mem_images` / `mem_labels`) are rerouted
    to the evaluation pipeline tensors. Nothing is done if the checkpoint
    does not exist or this is not the primary worker.

    Args:
    * path: checkpoint path prefix; defaults to FLAGS.save_path
    """
    # early break for non-primary workers
    if not self.__is_primary_worker():
      return

    if path is None:
      path = FLAGS.save_path

    if not tf.train.checkpoint_exists(path):
      return

    with tf.Graph().as_default():
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(# pylint: disable=no-member
        mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
      self.sess_eval = tf.Session(config=config)
      self.saver_eval = tf.train.import_meta_graph(path + '.meta')
      self.saver_eval.restore(self.sess_eval, path)
      eval_logits = tf.get_collection('logits')[0]
      tf.add_to_collection('logits_final', eval_logits)
      eval_images = tf.get_collection('eval_images')[0]
      tf.add_to_collection('images_final', eval_images)
      eval_labels = tf.get_collection('eval_labels')[0]
      mem_images = tf.get_collection('mem_images')[0]
      mem_labels = tf.get_collection('mem_labels')[0]

      # the graph is edited in place, so close the session and re-open it afterwards
      self.sess_eval.close()

      graph_editor.reroute_ts(eval_images, mem_images)
      graph_editor.reroute_ts(eval_labels, mem_labels)

      self.sess_eval = tf.Session(config=config)
      self.saver_eval.restore(self.sess_eval, path)
      trainable_vars = self.trainable_vars
      loss, accuracy = self.calc_loss(eval_labels, eval_logits, trainable_vars)
      self.eval_op = [loss] + list(accuracy.values())
      self.sm_writer.add_graph(self.sess_eval.graph)

  def __build_pruned_train_model(self, path=None, finetune=False): # pylint: disable=too-many-locals
    """Build a training model from a pruned model's checkpoint.

    Args:
    * path: checkpoint path prefix; defaults to FLAGS.save_path
    * finetune: if True and FLAGS.cp_retrain is off, use Adam with the constant
      FLAGS.cp_lrn_rate_ft; otherwise use momentum SGD with the learning-rate
      schedule from setup_lrn_rate
    """
    if path is None:
      path = FLAGS.save_path

    with tf.Graph().as_default():
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(# pylint: disable=no-member
        mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
      self.sess_train = tf.Session(config=config)
      self.saver_train = tf.train.import_meta_graph(path + '.meta')
      self.saver_train.restore(self.sess_train, path)
      logits = tf.get_collection('logits')[0]
      train_images = tf.get_collection('train_images')[0]
      train_labels = tf.get_collection('train_labels')[0]
      mem_images = tf.get_collection('mem_images')[0]
      mem_labels = tf.get_collection('mem_labels')[0]

      # the graph is edited in place, so close the session and re-open it afterwards
      self.sess_train.close()

      graph_editor.reroute_ts(train_images, mem_images)
      graph_editor.reroute_ts(train_labels, mem_labels)

      self.sess_train = tf.Session(config=config)
      self.saver_train.restore(self.sess_train, path)

      trainable_vars = self.trainable_vars
      loss, accuracy = self.calc_loss(train_labels, logits, trainable_vars)
      self.accuracy_keys = list(accuracy.keys())

      if FLAGS.enbl_dst:
        logits_dst = self.learner_dst.calc_logits(self.sess_train, train_images)
        loss += self.learner_dst.calc_loss(logits, logits_dst)

      tf.summary.scalar('loss', loss)
      for key in accuracy.keys():
        tf.summary.scalar(key, accuracy[key])
      self.summary_op = tf.summary.merge_all()

      global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, trainable=False)
      self.global_step = global_step
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)

      if finetune and not FLAGS.cp_retrain:
        mom_optimizer = tf.train.AdamOptimizer(FLAGS.cp_lrn_rate_ft)
        self.log_op = [tf.constant(FLAGS.cp_lrn_rate_ft), loss, list(accuracy.values())]
      else:
        mom_optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        self.log_op = [lrn_rate, loss, list(accuracy.values())]

      if FLAGS.enbl_multi_gpu:
        optimizer = mgw.DistributedOptimizer(mom_optimizer)
      else:
        optimizer = mom_optimizer
      grads_origin = optimizer.compute_gradients(loss, trainable_vars)
      # zero the gradients of pruned channels so they stay pruned during training
      grads_pruned, masks = self.__calc_grads_pruned(grads_origin)

      with tf.control_dependencies(self.update_ops):
        self.train_op = optimizer.apply_gradients(grads_pruned, global_step=global_step)

      self.sm_writer.add_graph(tf.get_default_graph())
      # only the optimizer slots, the step counter and the masks are new variables;
      # model weights come from the restored checkpoint
      self.train_init_op = \
        tf.variables_initializer(mom_optimizer.variables() + [global_step] + masks)

      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)

  def __calc_grads_pruned(self, grads_origin):
    """Calculate the pruned gradients.

    Each gradient of a variable that the pruner fake-pruned is multiplied by a
    constant 0/1 mask. The mask zeroes the pruned input channels (axis 2) and
    the pruned output channels (axis 3). This assumes 4-D conv kernels laid
    out as (h, w, c_in, c_out); TODO confirm against ChannelPruner.

    Args:
    * grads_origin: list of (gradient, variable) pairs

    Return:
    * the pruned gradients, as (gradient, variable) pairs
    * the mask variables created for the pruned gradients
    """
    grads_pruned = []
    masks = []
    maskable_var_names = {}
    fake_pruning_dict = {}
    if self.__is_primary_worker():
      fake_pruning_dict = self.pruner.fake_pruning_dict
      maskable_var_names = {
        self.pruner.model.get_var_by_op(
          self.pruner.model.g.get_operation_by_name(op_name)).name: op_name
        for op_name in fake_pruning_dict}
      tf.logging.debug('maskable var names {}'.format(maskable_var_names))

    # only the primary worker owns a pruner; share its decisions with the others
    if FLAGS.enbl_multi_gpu:
      fake_pruning_dict = self.mpi_comm.bcast(fake_pruning_dict, root=0)
      maskable_var_names = self.mpi_comm.bcast(maskable_var_names, root=0)

    for grad in grads_origin:
      if grad[1].name not in maskable_var_names:
        grads_pruned.append(grad)
      else:
        pruned_idxs = fake_pruning_dict[maskable_var_names[grad[1].name]]
        mask_tensor = np.ones(grad[0].shape)
        mask_tensor[:, :, [not i for i in pruned_idxs[0]], :] = 0
        mask_tensor[:, :, :, [not i for i in pruned_idxs[1]]] = 0
        mask_initializer = tf.constant_initializer(mask_tensor)
        mask = tf.get_variable(
          grad[1].name.split(':')[0] + '_mask',
          shape=mask_tensor.shape, initializer=mask_initializer, trainable=False)
        masks.append(mask)
        grads_pruned.append((grad[0] * mask, grad[1]))

    return grads_pruned, masks

  def __train_pruned_model(self, finetune=False):
    """Train the pruned model, saving and evaluating it periodically.

    When training ends, the final model is evaluated and saved to
    FLAGS.cp_best_path. The primary worker then restores that checkpoint into
    the pruner's model.

    Args:
    * finetune: if True and FLAGS.cp_retrain is off, train for only
      FLAGS.cp_nb_iters_ft_ratio of the full schedule
    """
    # Initialize variables
    self.sess_train.run(self.train_init_op)

    if FLAGS.enbl_multi_gpu:
      self.sess_train.run(self.bcast_op)

    ## Fine-tuning & distilling
    self.time_prev = timer()

    nb_iters = int(FLAGS.cp_nb_iters_ft_ratio * self.nb_iters_train) \
      if finetune and not FLAGS.cp_retrain else self.nb_iters_train

    for self.idx_iter in range(nb_iters):
      # train the model
      if (self.idx_iter + 1) % FLAGS.summ_step != 0:
        self.sess_train.run(self.train_op)
      else:
        __, summary, log_rslt = self.sess_train.run([self.train_op, self.summary_op, self.log_op])
        self.__monitor_progress(summary, log_rslt)

      # save the model at certain steps
      if (self.idx_iter + 1) % FLAGS.save_step == 0:
        if self.__is_primary_worker():
          self.__save_model()
          self.evaluate()

        if FLAGS.enbl_multi_gpu:
          self.mpi_comm.Barrier()

    if self.__is_primary_worker():
      self.__save_model()
      self.evaluate()
      self.__save_in_progress_pruned_model()

    if FLAGS.enbl_multi_gpu:
      self.max_save_path = self.mpi_comm.bcast(self.max_save_path, root=0)
    if self.__is_primary_worker():
      # load the checkpoint saved by __save_in_progress_pruned_model into the pruner's model
      with self.pruner.model.g.as_default():
        self.pruner.saver = tf.train.Saver()
        self.pruner.saver.restore(self.pruner.model.sess, self.max_save_path)

  def __save_best_pruned_model(self):
    """Save the pruner's model to FLAGS.cp_best_path as the best pruned model."""
    best_path = tf.train.Saver().save(self.pruner.model.sess, FLAGS.cp_best_path)
    tf.logging.info('model saved best model to ' + best_path)

  def __save_in_progress_pruned_model(self):
    """Save the evaluation model to FLAGS.cp_best_path and record the path."""
    self.max_save_path = self.saver_eval.save(self.sess_eval, FLAGS.cp_best_path)
    tf.logging.info('model saved best model to ' + self.max_save_path)

  def __save_model(self):
    """Save the training model to FLAGS.save_path, suffixed with the global step."""
    save_path = self.saver_train.save(self.sess_train, FLAGS.save_path, self.global_step)
    tf.logging.info('model saved to ' + save_path)

  def __restore_model(self, is_train):
    """Restore the latest checkpoint under dirname(FLAGS.save_path).

    Args:
    * is_train: restore into the training session if True, else the evaluation session
    """
    save_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.save_path))
    if is_train:
      self.saver_train.restore(self.sess_train, save_path)
    else:
      self.saver_eval.restore(self.sess_eval, save_path)
    tf.logging.info('model restored from ' + save_path)

  def __monitor_progress(self, summary, log_rslt):
    """Write summaries and log training statistics (primary worker only).

    Args:
    * summary: serialized summary produced by `self.summary_op`
    * log_rslt: evaluated `self.log_op`, i.e. [lrn_rate, loss, accuracy values]
    """
    # early break for non-primary workers
    if not self.__is_primary_worker():
      return
    # write summaries for TensorBoard visualization
    self.sm_writer.add_summary(summary, self.idx_iter)

    # display monitored statistics
    lrn_rate, loss, accuracy = log_rslt[0], log_rslt[1], log_rslt[2]
    speed = FLAGS.batch_size * FLAGS.summ_step / (timer() - self.time_prev)
    if FLAGS.enbl_multi_gpu:
      speed *= mgw.size()
    tf.logging.info('iter #%d: lr = %e | loss = %e | speed = %.2f pics / sec'
                    % (self.idx_iter + 1, lrn_rate, loss, speed))
    for i in range(len(self.accuracy_keys)):
      tf.logging.info('{} = {}'.format(self.accuracy_keys[i], accuracy[i]))
    self.time_prev = timer()

  def __prune_and_finetune_uniform(self):
    """Prune every layer with FLAGS.cp_uniform_preserve_ratio, then fine-tune."""
    if self.__is_primary_worker():
      done = False
      self.pruner.extract_features()

      start = timer()
      while not done:
        _, _, done, _ = self.pruner.compress(FLAGS.cp_uniform_preserve_ratio)

      tf.logging.info('uniform channl pruning time cost: {}s'.format(timer() - start))
      self.pruner.save_model()

    if FLAGS.enbl_multi_gpu:
      self.mpi_comm.Barrier()

    self.__finetune_pruned_model(path=FLAGS.cp_channel_pruned_path)

  def __prune_and_finetune_list(self):
    """Prune with the preserve ratios listed in FLAGS.cp_prune_list_file.

    The file holds comma-delimited floats, one ratio per layer. A single
    value is accepted, and values spread over several equal-length lines are
    flattened in reading order.

    Raises:
    * IOError / ValueError: if the file is missing or cannot be parsed
    """
    try:
      # np.loadtxt returns a 0-d array for a single value; ravel keeps it iterable
      ratio_list = list(np.ravel(np.loadtxt(FLAGS.cp_prune_list_file, delimiter=',')))
    except (IOError, ValueError) as err:
      tf.logging.error('The prune list file format is not correct. \n \
        It\'s content should be a float list delimited by a comma.')
      raise err
    # the queue is consumed from the right, so reverse to apply ratios in file order
    ratio_list.reverse()
    queue = deque(ratio_list)

    done = False
    while not done:
      done = self.__prune_list_layers(queue, [FLAGS.cp_list_group])

  def __prune_list_layers(self, queue, ps=None):
    """Run one prune-and-finetune round per group size in `ps`.

    Args:
    * queue: deque of preserve ratios still to apply (consumed from the right)
    * ps: list of group sizes (layers pruned per round); defaults to
      [FLAGS.cp_list_group]

    Returns:
    * whether the last round reported that pruning is done
    """
    if ps is None:
      ps = [FLAGS.cp_list_group]
    done = False
    for p in ps:
      done = self.__prune_n_layers(p, queue)
    return done

  def __prune_n_layers(self, n, queue):
    """Prune up to `n` layers with ratios popped from `queue`, then fine-tune.

    A ratio of 1 (no pruning) is used once the queue is exhausted.

    Args:
    * n: maximal number of layers to prune in this round
    * queue: deque of preserve ratios (consumed from the right)

    Returns:
    * whether the pruner reported that all layers are processed
    """
    done = False
    if self.__is_primary_worker():
      self.pruner.extract_features()
      i = 0
      while not done and i < n:
        if not queue:
          ratio = 1
        else:
          ratio = queue.pop()
        _, _, done, _ = self.pruner.compress(ratio)
        i += 1

      self.pruner.save_model()

    if FLAGS.enbl_multi_gpu:
      self.mpi_comm.Barrier()
      done = self.mpi_comm.bcast(done, root=0)

    if done:
      self.__finetune_pruned_model(path=FLAGS.cp_channel_pruned_path, finetune=False)
    else:
      self.__finetune_pruned_model(path=FLAGS.cp_channel_pruned_path, finetune=FLAGS.cp_finetune)

    return done

  def __finetune_pruned_model(self, path=None, finetune=False):
    """Build evaluation & training models from a pruned checkpoint and train it.

    Args:
    * path: checkpoint path prefix; defaults to FLAGS.cp_channel_pruned_path
    * finetune: forwarded to the training-model builder and the training loop
    """
    if path is None:
      path = FLAGS.cp_channel_pruned_path
    start = timer()
    tf.logging.info('build pruned evaluating model')
    self.__build_pruned_evaluate_model(path)
    tf.logging.info('build pruned training model')
    self.__build_pruned_train_model(path, finetune=finetune)
    tf.logging.info('training pruned model')
    self.__train_pruned_model(finetune=finetune)
    tf.logging.info('fintuning time cost: {}s'.format(timer() - start))

  def __prune_and_finetune_auto(self):
    """Search preserve ratios with RL, then prune & fine-tune with the best ones.

    Raises:
    * RuntimeError: if the RL search produced no strategy
    """
    if self.__is_primary_worker():
      self.__prune_rl()
      self.pruner.initialize_state()

    if FLAGS.enbl_multi_gpu:
      self.mpi_comm.Barrier()
      self.bestinfo = self.mpi_comm.bcast(self.bestinfo, root=0)

    if self.bestinfo is None:
      raise RuntimeError('no pruning strategy was found by the RL search')

    # copy so that reversing does not mutate the recorded best strategy
    ratio_list = list(self.bestinfo[0])
    tf.logging.info('best split ratio is: {}'.format(ratio_list))
    ratio_list.reverse()
    queue = deque(ratio_list)

    done = False
    while not done:
      done = self.__prune_list_layers(queue, [FLAGS.cp_list_group])

  @classmethod
  def __calc_reward(cls, accuracy, flops):
    """Compute the RL reward of a pruned model, as a (1, 1) array.

    Args:
    * accuracy: accuracy of the pruned model
    * flops: FLOPs-related value reported by the pruner

    Returns:
    * the accuracy for policy 'accuracy', or
      -max(FLAGS.cp_noise_tolerance, 1 - accuracy) * log(flops) for policy 'flops'

    Raises:
    * ValueError: if FLAGS.cp_reward_policy is unrecognized
    """
    if FLAGS.cp_reward_policy == 'accuracy':
      reward = accuracy * np.ones((1, 1))
    elif FLAGS.cp_reward_policy == 'flops':
      reward = -np.maximum(
        FLAGS.cp_noise_tolerance, (1 - accuracy)) * np.log(flops) * np.ones((1, 1))
    else:
      raise ValueError('unrecognized reward type: ' + FLAGS.cp_reward_policy)

    return reward

  def __prune_rl(self): # pylint: disable=too-many-locals
    """Search a pruning strategy with a DDPG agent.

    Each roll-out prunes a fresh copy of the original model layer by layer with
    the agent's (noisy) actions. The highest-reward strategy is stored in
    `self.bestinfo` as [ratios, accuracy, flops].
    """
    tf.logging.info(
      'preserve lower bound: {}, preserve ratio: {}, preserve upper bound: {}'.format(
        self.lbound, FLAGS.cp_preserve_ratio, self.rbound))
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(0) # pylint: disable=no-member
    buf_size = len(self.pruner.states) * FLAGS.cp_nb_rlouts_min
    nb_rlouts = FLAGS.cp_nb_rlouts
    self.agent = DdpgAgent(
      tf.Session(config=config),
      len(self.pruner.states.loc[0].tolist()),
      1,
      nb_rlouts,
      buf_size,
      self.lbound,
      self.rbound)
    self.agent.init()
    self.bestinfo = None
    # np.NINF was removed in NumPy 2.0; -np.inf is the portable spelling
    reward_best = -np.inf

    for idx_rlout in range(FLAGS.cp_nb_rlouts):
      # execute roll-outs to obtain pruning ratios
      self.agent.init_rlout()
      states_n_actions = []
      self.create_pruner()
      self.pruner.initialize_state()
      self.pruner.extract_features()
      state = np.array(self.pruner.currentStates.loc[0].tolist())[None, :]

      start = timer()
      while True:
        tf.logging.info('state is {}'.format(state))
        action = self.agent.sess.run(self.agent.actions_noisy, feed_dict={self.agent.states: state})
        tf.logging.info('RL choosed preserv ratio: {}'.format(action))
        state_next, acc_flops, done, real_action = self.pruner.compress(action)
        tf.logging.info('Actural preserv ratio: {}'.format(real_action))
        states_n_actions += [(state, real_action * np.ones((1, 1)))]
        state = state_next[None, :]
        actor_loss, critic_loss, noise_std = self.agent.train()
        if done:
          break
      tf.logging.info('roll-out #%d: a-loss = %.2e | c-loss = %.2e | noise std. = %.2e'
                      % (idx_rlout, actor_loss, critic_loss, noise_std))

      reward = self.__calc_reward(acc_flops[0], acc_flops[1])

      rewards = reward * np.ones(len(self.pruner.states))
      self.agent.finalize_rlout(rewards)

      # record transitions for RL training
      strategy = []
      for idx, (state, action) in enumerate(states_n_actions):
        strategy.append(action[0, 0])
        if idx != len(states_n_actions) - 1:
          terminal = np.zeros((1, 1))
          state_next = states_n_actions[idx + 1][0]
        else:
          terminal = np.ones((1, 1))
          state_next = np.zeros_like(state)
        self.agent.record(state, action, reward, terminal, state_next)

      # record the best combination of pruning ratios; compare & log as a plain
      # scalar (formatting an ndim>0 array with %f is deprecated in NumPy)
      reward_val = np.asarray(reward).item()
      if reward_best < reward_val:
        tf.logging.info('best reward updated: %.4f -> %.4f' % (reward_best, reward_val))
        reward_best = reward_val
        self.bestinfo = [strategy, acc_flops[0], acc_flops[1]]
        tf.logging.info("""The best pruned model occured with
                strategy: {},
                accuracy: {} and
                pruned ratio: {}""".format(self.bestinfo[0], self.bestinfo[1], self.bestinfo[2]))

      tf.logging.info('automatic channl pruning time cost: {}s'.format(timer() - start))

  @classmethod
  def __is_primary_worker(cls):
    """Whether this is the primary worker (always True without multi-GPU)."""
    return not FLAGS.enbl_multi_gpu or mgw.rank() == 0
예제 #8
0
class ChannelPrunedRmtLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Channel pruning learner - remastered.

    The training graph holds two copies of the network: the full model (scope
    'model', restored from pre-trained checkpoints) and a channel-pruned model
    (scope 'pruned_model', initialized from the full one). For each convolutional
    layer, input channels are selected by solving a LASSO problem against the full
    model's outputs, the remaining weights are re-fitted by L2-regularized least
    squares, and the pruned model is then fine-tuned with gradients masked to the
    surviving channels.
    """

    # Upper bound on bisection steps when searching for the LASSO penalty. After
    # roughly 53 halvings a float64 interval [v / 2, v] can no longer shrink, so an
    # exact-count search that has not succeeded by then would never terminate.
    _NB_ITERS_BSRCH_MAX = 64

    def __init__(self, sm_writer, model_helper):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """

        # class-independent initialization
        super(ChannelPrunedRmtLearner, self).__init__(sm_writer, model_helper)

        # define scopes for full & channel-pruned models
        self.model_scope_full = 'model'
        self.model_scope_prnd = 'pruned_model'

        # download the pre-trained model (required); only the local primary worker
        # downloads it, auto_barrier() presumably synchronizes the rest
        if self.is_primary_worker('local'):
            self.download_model()
        self.auto_barrier()
        tf.logging.info('model files: ' + ', '.join(os.listdir('./models')))

        # class-dependent initialization
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper, self.mpi_comm)
        self.__build_train()
        self.__build_eval()

    def train(self):
        """Select channels, fine-tune the pruned model, and periodically save checkpoints."""

        # restore the full model from pre-trained checkpoints
        save_path = tf.train.latest_checkpoint(os.path.dirname(self.save_path_full))
        self.saver_full.restore(self.sess_train, save_path)

        # initialization: copy the full model's weights into the pruned model
        self.sess_train.run(self.init_op)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # choose channels and evaluate the model before re-training
        time_prev = timer()
        self.__choose_channels()
        tf.logging.info('time (channel selection): %.2f (s)' % (timer() - time_prev))
        # set each mask to 1 where the kernel's input channel is non-zero, 0 otherwise
        self.sess_train.run(self.mask_updt_op)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # evaluate the model before fine-tuning
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.evaluate()
        self.auto_barrier()

        # fine-tune the model with chosen channels only
        time_prev = timer()
        for idx_iter in range(self.nb_iters_train):
            # train the model
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    time_step = timer() - time_prev
                    self.__monitor_progress(summary, log_rslt, idx_iter, time_step)
                    time_prev = timer()

            # save the model at certain steps
            if self.is_primary_worker('global') and (idx_iter + 1) % FLAGS.save_step == 0:
                self.__save_model(is_train=True)
                self.evaluate()
            self.auto_barrier()

        # save the final model (training checkpoint, then re-saved via the eval graph)
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it."""

        self.__restore_model(is_train=False)
        nb_iters = int(np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        self.dump_n_eval(outputs=None, action='init')
        for idx_iter in range(nb_iters):
            if (idx_iter + 1) % 100 == 0:
                tf.logging.info('process the %d-th mini-batch for evaluation' % (idx_iter + 1))
            eval_rslts[idx_iter], outputs = self.sess_eval.run(
                [self.eval_op, self.outputs_eval])
            self.dump_n_eval(outputs=outputs, action='dump')
        self.dump_n_eval(outputs=None, action='eval')
        for idx, name in enumerate(self.eval_op_names):
            tf.logging.info('%s = %.4e' % (name, np.mean(eval_rslts[:, idx])))

    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph (full model + channel-pruned model)."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            config.gpu_options.visible_device_list = \
              str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model (only used as the regression target)
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images)
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.conv_krnl_var_names = [var.name for var in self.vars_prnd['conv_krnl']]
                self.global_step = tf.train.get_or_create_global_step()
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'] + [self.global_step])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(self.vars_prnd['conv_krnl'])
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_conv_krnl', pr_conv_krnl)

                # create one input-channel mask per conv kernel; each mask is updated
                # from the kernel's values (1 where the input channel is non-zero)
                self.masks = []
                mask_updt_ops = []
                for var in self.vars_prnd['conv_krnl']:
                    tf.logging.info('creating a pruning mask for {} of size {}'.format(
                        var.name, var.shape))
                    mask_name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
                    mask_shape = [1, 1, var.shape[2], 1]  # 1 x 1 x c_in x 1
                    mask = tf.get_variable(mask_name,
                                           initializer=tf.ones(mask_shape),
                                           trainable=False)
                    # squared L2 norm of each input channel's weights
                    var_norm = tf.reduce_sum(tf.square(var), axis=[0, 1, 3], keepdims=True)
                    self.masks += [mask]
                    mask_updt_ops += [mask.assign(tf.cast(var_norm > 0.0, tf.float32))]
                self.mask_updt_op = tf.group(mask_updt_ops)

                # build operations for channel selection
                self.__build_chn_select_ops()

                # optimizer & gradients
                optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(loss, self.vars_prnd['trainable'])
                grads_pruned = self.__calc_grads_pruned(grads_origin)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope=self.model_scope_prnd)
                with tf.control_dependencies(update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads_pruned, global_step=self.global_step)

                # TF operations for initializing the channel-pruned model: copy every
                # full-model variable into its pruned counterpart (paired by position)
                init_ops = []
                for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
                    init_ops += [var_prnd.assign(var_full)]
                init_ops += [self.global_step.initializer]  # initialize the global step
                init_ops += [tf.variables_initializer(optimizer_base.variables())]
                self.init_op = tf.group(init_ops)

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_conv_krnl] + list(metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_krn'] + list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                # NOTE(review): the pruning masks and the meta-LASSO variables are created
                # after `self.vars_prnd` is collected and are not covered by `init_op`,
                # yet train() runs this op right after `init_op`; confirm that
                # mgw.broadcast_global_variables() tolerates uninitialized variables.
                self.bcast_op = mgw.broadcast_global_variables(0)

    def __build_eval(self):
        """Build the evaluation graph (channel-pruned model only)."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            config.gpu_options.visible_device_list = \
              str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(self.sess_eval, images)

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits = self.forward_eval(images)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                global_step = tf.train.get_or_create_global_step()
                self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'] + [global_step])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits, vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(vars_prnd['conv_krnl'])

                # TF operations for evaluation
                self.eval_op = [loss, pr_trainable, pr_conv_krnl] + list(metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_krn'] + list(metrics.keys())
                self.outputs_eval = logits

            # add input & output tensors to certain collections
            tf.add_to_collection('images_final', images)
            tf.add_to_collection('logits_final', logits)

    def __build_chn_select_ops(self):
        """Build channel selection operations for convolutional layers.

        Conv2D ops of the full and pruned models are paired by their order in the
        graph. Nothing is returned; the following attributes are set:
        * self.reg_losses: per-layer L2 losses between full & pruned conv outputs
        * self.conv_info_list: per-layer dicts with kernels, inputs/outputs, strides,
            padding, and a placeholder-fed op to overwrite the pruned kernel
        * self.nb_conv_layers: number of convolutional layers
        * self.meta_lasso / self.meta_lstsq: reusable optimization sub-graphs
        """

        # build layer-wise regression losses
        pattern = re.compile(r'/Conv2D$')
        conv_ops_full = get_ops_by_scope_n_pattern(self.model_scope_full, pattern)
        conv_ops_prnd = get_ops_by_scope_n_pattern(self.model_scope_prnd, pattern)
        reg_losses = [
            tf.nn.l2_loss(conv_op_full.outputs[0] - conv_op_prnd.outputs[0])
            for conv_op_full, conv_op_prnd in zip(conv_ops_full, conv_ops_prnd)
        ]

        # build layer-wise sampling operations
        conv_info_list = []
        for idx_layer, (conv_op_full, conv_op_prnd) in enumerate(
                zip(conv_ops_full, conv_ops_prnd)):
            conv_krnl_prnd = self.vars_prnd['conv_krnl'][idx_layer]
            conv_krnl_prnd_ph = tf.placeholder(tf.float32,
                                               shape=conv_krnl_prnd.shape,
                                               name='conv_krnl_prnd_ph_%d' % idx_layer)
            conv_info_list += [{
                'conv_krnl_full': self.vars_full['conv_krnl'][idx_layer],
                'conv_krnl_prnd': conv_krnl_prnd,
                'conv_krnl_prnd_ph': conv_krnl_prnd_ph,
                'update_op': conv_krnl_prnd.assign(conv_krnl_prnd_ph),
                'input_full': conv_op_full.inputs[0],
                'input_prnd': conv_op_prnd.inputs[0],
                'output_full': conv_op_full.outputs[0],
                'output_prnd': conv_op_prnd.outputs[0],
                'strides': conv_op_full.get_attr('strides'),
                'padding': conv_op_full.get_attr('padding').decode('utf-8'),
            }]

        # build meta LASSO/least-square optimization problems
        self.meta_lasso = self.__build_meta_lasso()
        self.meta_lstsq = self.__build_meta_lstsq()

        self.reg_losses = reg_losses
        self.conv_info_list = conv_info_list
        self.nb_conv_layers = len(self.reg_losses)

    def __build_meta_lasso(self):
        """Build a meta LASSO optimization problem solved by ISTA.

        The problem is parameterized through placeholders (X^T X, X^T y, the initial
        mask, and the penalty gamma), so one sub-graph serves every layer.

        Returns:
        * meta_lasso: dict of placeholders, variables, and TF operations
        """

        # build a meta LASSO optimization problem
        with tf.variable_scope('meta_lasso'):
            # create placeholders to customize the LASSO problem
            xt_x_ph = tf.placeholder(tf.float32, name='xt_x_ph')
            xt_y_ph = tf.placeholder(tf.float32, name='xt_y_ph')
            mask_ph = tf.placeholder(tf.float32, name='mask_ph')
            gamma = tf.placeholder(tf.float32, shape=[], name='gamma')

            # create variables (shapes vary per layer, hence validate_shape=False)
            xt_x = tf.get_variable('xt_x',
                                   initializer=xt_x_ph,
                                   trainable=False,
                                   validate_shape=False)
            xt_y = tf.get_variable('xt_y',
                                   initializer=xt_y_ph,
                                   trainable=False,
                                   validate_shape=False)
            mask = tf.get_variable('mask',
                                   initializer=mask_ph,
                                   trainable=True,
                                   validate_shape=False)

            # TF operations
            def prox_mapping(x, thres):
                # soft-thresholding: proximal operator of thres * ||x||_1
                return tf.where(
                    x > thres, x - thres,
                    tf.where(x < -thres, x + thres, tf.zeros_like(x)))

            # one ISTA step: gradient step on the quadratic term, then soft-threshold
            mask_gd = mask - FLAGS.cpr_ista_lrn_rate * (tf.matmul(xt_x, mask) - xt_y)
            train_op = mask.assign(prox_mapping(mask_gd, gamma * FLAGS.cpr_ista_lrn_rate))
            init_op = tf.variables_initializer([xt_x, xt_y, mask])

        # pack placeholders, variables, and TF operations into dict
        meta_lasso = {
            'xt_x_ph': xt_x_ph,
            'xt_y_ph': xt_y_ph,
            'mask_ph': mask_ph,
            'gamma': gamma,
            'xt_x': xt_x,
            'xt_y': xt_y,
            'mask': mask,
            'init_op': init_op,
            'train_op': train_op,
        }

        return meta_lasso

    def __build_meta_lstsq(self):
        """Build a meta least-square optimization problem.

        Returns:
        * meta_lstsq: dict with the feature/response placeholders and the
            closed-form solution (L2-regularized with FLAGS.loss_w_dcy)
        """

        # build a meta least-square optimization problem
        with tf.variable_scope('meta_lstsq'):
            # create placeholders to customize the least-square problem
            feat_mat_ph = tf.placeholder(tf.float32, name='feat_mat_ph')
            rspn_mat_ph = tf.placeholder(tf.float32, name='rspn_mat_ph')

            # compute the closed-form solution
            wei_mat = tf.linalg.lstsq(feat_mat_ph, rspn_mat_ph, FLAGS.loss_w_dcy)

        # pack placeholders and variables into dict
        meta_lstsq = {
            'feat_mat_ph': feat_mat_ph,
            'rspn_mat_ph': rspn_mat_ph,
            'wei_mat': wei_mat,
        }

        return meta_lstsq

    def __calc_grads_pruned(self, grads_origin):
        """Calculate the mask-pruned gradients.

        Gradients of convolutional kernels are multiplied by their input-channel
        masks so that pruned channels stay zero; all other gradients (and missing
        `None` gradients) are passed through unchanged.

        Args:
        * grads_origin: list of original (gradient, variable) pairs

        Returns:
        * grads_pruned: list of mask-pruned (gradient, variable) pairs
        """

        idx_by_name = {var.name: idx for idx, var in enumerate(self.vars_prnd['conv_krnl'])}
        grads_pruned = []
        for grad, var in grads_origin:
            idx_mask = idx_by_name.get(var.name)
            if idx_mask is None or grad is None:
                grads_pruned += [(grad, var)]
            else:
                grads_pruned += [(grad * self.masks[idx_mask], var)]

        return grads_pruned

    def __choose_channels(self):  # pylint: disable=too-many-locals
        """Choose channels for all convolutional layers."""

        # obtain each layer's pruning ratio
        prune_ratios = [FLAGS.cpr_prune_ratio] * self.nb_conv_layers
        if FLAGS.cpr_skip_frst_layer:
            prune_ratios[0] = 0.0
        if FLAGS.cpr_skip_last_layer:
            prune_ratios[-1] = 0.0

        # select channels for all the convolutional layers
        for idx_layer, (prune_ratio, conv_info) in enumerate(
                zip(prune_ratios, self.conv_info_list)):
            # skip if no pruning is required
            if prune_ratio == 0.0:
                continue
            if self.is_primary_worker('global'):
                tf.logging.info('layer #%d: pr = %.2f (target)' % (idx_layer, prune_ratio))
                tf.logging.info('kernel shape = {}'.format(conv_info['conv_krnl_prnd'].shape))

            # extract the current layer's information
            conv_krnl_full = self.sess_train.run(conv_info['conv_krnl_full'])
            conv_krnl_prnd = self.sess_train.run(conv_info['conv_krnl_prnd'])
            conv_krnl_prnd_ph = conv_info['conv_krnl_prnd_ph']
            update_op = conv_info['update_op']
            input_full_tf = conv_info['input_full']
            input_prnd_tf = conv_info['input_prnd']
            output_full_tf = conv_info['output_full']
            output_prnd_tf = conv_info['output_prnd']
            strides = conv_info['strides']
            padding = conv_info['padding']
            nb_chns_input = conv_krnl_prnd.shape[2]

            # sample inputs & outputs through multiple mini-batches
            nb_iters_smpl = int(math.ceil(float(FLAGS.cpr_nb_smpl_insts) / FLAGS.batch_size))
            inputs_list = [[] for __ in range(nb_chns_input)]
            outputs_list = []
            for __ in range(nb_iters_smpl):
                inputs_full, inputs_prnd, outputs_full, outputs_prnd = \
                  self.sess_train.run([input_full_tf, input_prnd_tf, output_full_tf, output_prnd_tf])
                inputs_smpl, outputs_smpl = self.__smpl_inputs_n_outputs(
                    conv_krnl_full, conv_krnl_prnd, inputs_full, inputs_prnd,
                    outputs_full, outputs_prnd, strides, padding)
                for idx_chn_input in range(nb_chns_input):
                    inputs_list[idx_chn_input] += [inputs_smpl[idx_chn_input]]
                outputs_list += [outputs_smpl]
            inputs_np_list = [np.vstack(x) for x in inputs_list]
            outputs_np = np.vstack(outputs_list)

            # choose channels via solving the sparsity-constrained regression problem
            conv_krnl_prnd = self.__solve_sparse_regression(
                inputs_np_list, outputs_np, conv_krnl_prnd, prune_ratio)
            self.sess_train.run(update_op, feed_dict={conv_krnl_prnd_ph: conv_krnl_prnd})

            # evaluate the channel pruned model
            if FLAGS.cpr_eval_per_layer:
                if self.is_primary_worker('global'):
                    self.__save_model(is_train=True)
                    self.evaluate()
                self.auto_barrier()

        # evaluate the final channel pruned model
        if not FLAGS.cpr_eval_per_layer:
            if self.is_primary_worker('global'):
                self.__save_model(is_train=True)
                self.evaluate()
            self.auto_barrier()

    def __smpl_inputs_n_outputs(self, conv_krnl_full, conv_krnl_prnd,
                                inputs_full, inputs_prnd, outputs_full,
                                outputs_prnd, strides, padding):
        """Sample inputs & outputs of sub-regions from full feature maps.

        Args:
        * conv_krnl_full: full model's kernel (kh x kw x c_i x c_o)
        * conv_krnl_prnd: pruned model's kernel (kh x kw x c_i x c_o)
        * inputs_full / inputs_prnd: conv inputs of the full / pruned model (N x H x W x c_i)
        * outputs_full / outputs_prnd: conv outputs of the full / pruned model (N x H' x W' x c_o)
        * strides: convolution strides; entries [1] and [2] are the vertical / horizontal strides
        * padding: 'VALID', otherwise treated as TensorFlow's 'SAME' padding

        Returns:
        * inputs_smpl: list of c_i arrays, one per input channel, each (N * crops) x (kh * kw),
            sampled from the pruned model's inputs
        * outputs_smpl: (N * crops) x c_o array sampled from the full model's outputs
        """

        # obtain parameters
        bs = inputs_full.shape[0]
        kh, kw = conv_krnl_full.shape[0], conv_krnl_full.shape[1]
        ih, iw, ic = inputs_full.shape[1], inputs_full.shape[2], inputs_full.shape[3]
        oh, ow, oc = outputs_full.shape[1], outputs_full.shape[2], outputs_full.shape[3]

        # compute the zero-padding applied by the convolution
        if padding == 'VALID':
            pad_top, pad_btm, pad_lft, pad_rgt = 0, 0, 0, 0
        else:
            # TensorFlow's 'SAME' padding: the total amount depends on the stride and
            # any odd extra row/column goes to the bottom/right side
            pad_h = max((oh - 1) * strides[1] + kh - ih, 0)
            pad_w = max((ow - 1) * strides[2] + kw - iw, 0)
            pad_top, pad_lft = pad_h // 2, pad_w // 2
            pad_btm, pad_rgt = pad_h - pad_top, pad_w - pad_lft

        # perform zero-padding on input feature maps
        if pad_top + pad_btm + pad_lft + pad_rgt == 0:
            inputs_full_pad = inputs_full
            inputs_prnd_pad = inputs_prnd
        else:
            pad_width = ((0, 0), (pad_top, pad_btm), (pad_lft, pad_rgt), (0, 0))
            inputs_full_pad = np.pad(inputs_full, pad_width, 'constant')
            inputs_prnd_pad = np.pad(inputs_prnd, pad_width, 'constant')

        # sample inputs & outputs of sub-regions (one random output location per crop)
        inputs_smpl_full_list = []
        inputs_smpl_prnd_list = []
        outputs_smpl_full_list = []
        outputs_smpl_prnd_list = []
        for __ in range(FLAGS.cpr_nb_smpl_crops):
            idx_oh = np.random.randint(oh)
            idx_ow = np.random.randint(ow)
            idx_ih_low = idx_oh * strides[1]
            idx_ih_hgh = idx_ih_low + kh
            idx_iw_low = idx_ow * strides[2]
            idx_iw_hgh = idx_iw_low + kw
            inputs_smpl_full_list += [
                inputs_full_pad[:, idx_ih_low:idx_ih_hgh, idx_iw_low:idx_iw_hgh, :]]
            inputs_smpl_prnd_list += [
                inputs_prnd_pad[:, idx_ih_low:idx_ih_hgh, idx_iw_low:idx_iw_hgh, :]]
            outputs_smpl_full_list += [np.reshape(outputs_full[:, idx_oh, idx_ow, :], [bs, -1])]
            outputs_smpl_prnd_list += [np.reshape(outputs_prnd[:, idx_oh, idx_ow, :], [bs, -1])]

        # concatenate samples into a single np.array
        inputs_smpl_full = np.concatenate(inputs_smpl_full_list, axis=0)
        inputs_smpl_prnd = np.concatenate(inputs_smpl_prnd_list, axis=0)
        outputs_smpl_full = np.vstack(outputs_smpl_full_list)
        outputs_smpl_prnd = np.vstack(outputs_smpl_prnd_list)

        # validate inputs & outputs: the sampled patches must reproduce the conv outputs
        wei_mat_full = np.reshape(conv_krnl_full, [-1, oc])
        wei_mat_prnd = np.reshape(conv_krnl_prnd, [-1, oc])
        preds_smpl_full = np.matmul(
            np.reshape(inputs_smpl_full, [-1, kh * kw * ic]), wei_mat_full)
        preds_smpl_prnd = np.matmul(
            np.reshape(inputs_smpl_prnd, [-1, kh * kw * ic]), wei_mat_prnd)
        err_full = norm(outputs_smpl_full - preds_smpl_full)**2 / outputs_smpl_full.shape[0]
        err_prnd = norm(outputs_smpl_prnd - preds_smpl_prnd)**2 / outputs_smpl_prnd.shape[0]
        assert err_full < 1e-10, 'unable to recover output feature maps - full (%e)' % err_full
        assert err_prnd < 1e-10, 'unable to recover output feature maps - prnd (%e)' % err_prnd

        # split the pruned model's sampled inputs into one array per input channel
        inputs_smpl = np.split(inputs_smpl_prnd, ic, axis=3)
        for idx in range(ic):
            inputs_smpl[idx] = np.reshape(inputs_smpl[idx], [-1, kh * kw])
        outputs_smpl = outputs_smpl_full

        return inputs_smpl, outputs_smpl

    def __solve_sparse_regression(self, inputs_np_list, outputs_np, conv_krnl,
                                  prune_ratio):
        """Solve the sparsity-constrained regression problem.

        A LASSO problem over per-channel masks selects the input channels to keep;
        its penalty gamma is searched so that exactly int(c_i * (1 - prune_ratio))
        channels survive (or, if that count cannot be hit, at most that many). The
        kept weights are then re-fitted by least squares.

        Args:
        * inputs_np_list: list of input feature maps (one per input channel, N x k^2)
        * outputs_np: output feature maps (N x c_o)
        * conv_krnl: initial convolutional kernel (k * k * c_i * c_o)
        * prune_ratio: pruning ratio

        Returns:
        * conv_krnl: updated convolutional kernel (k * k * c_i * c_o)
        """

        # obtain parameters
        bs = outputs_np.shape[0]
        kh, kw, ic, oc = conv_krnl.shape[0], conv_krnl.shape[1], \
            conv_krnl.shape[2], conv_krnl.shape[3]

        # compute the feature matrix & response vector
        rspn_vec_np = np.reshape(outputs_np, [-1, 1])  # N' x 1 (N' = N * c_o)
        feat_mat_np = np.zeros((rspn_vec_np.shape[0], ic))  # N' x c_i
        for idx in range(ic):
            wei_mat = np.reshape(conv_krnl[:, :, idx, :], [kh * kw, oc])
            feat_mat_np[:, idx] = np.matmul(inputs_np_list[idx], wei_mat).ravel()

        # compute <X^T * X> & <X^T * y> in advance
        xt_x_np = np.matmul(feat_mat_np.T, feat_mat_np) / bs
        xt_y_np = np.matmul(feat_mat_np.T, rspn_vec_np) / bs
        mask_np_init = np.ones((ic, 1))

        # solve the LASSO problem for a given penalty, starting from an all-ones mask
        def _solve_lasso(x):
            self.sess_train.run(self.meta_lasso['init_op'],
                                feed_dict={
                                    self.meta_lasso['xt_x_ph']: xt_x_np,
                                    self.meta_lasso['xt_y_ph']: xt_y_np,
                                    self.meta_lasso['mask_ph']: mask_np_init,
                                })
            for __ in range(FLAGS.cpr_ista_nb_iters):
                self.sess_train.run(self.meta_lasso['train_op'],
                                    feed_dict={self.meta_lasso['gamma']: x})
            mask_np = self.sess_train.run(self.meta_lasso['mask'])
            nb_chns_nnz = np.count_nonzero(mask_np)
            tf.logging.info('x = %e -> nb_chns_nnz = %d' % (x, nb_chns_nnz))
            return mask_np, nb_chns_nnz

        # determine <gamma>: first double it until at most the target number of
        # channels survive, then bisect for an exact match
        val = 0.1
        nb_chns_nnz_target = int(ic * (1.0 - prune_ratio))
        while True:
            mask_np, nb_chns_nnz = _solve_lasso(val)
            if nb_chns_nnz > nb_chns_nnz_target:
                val *= 2.0
            else:
                break
        lbnd = val / 2.0
        ubnd = val
        mask_np_ubnd = mask_np  # latest solution with no more than the target channels
        for __ in range(self._NB_ITERS_BSRCH_MAX):
            val = (lbnd + ubnd) / 2.0
            mask_np, nb_chns_nnz = _solve_lasso(val)
            if nb_chns_nnz < nb_chns_nnz_target:
                ubnd = val
                mask_np_ubnd = mask_np
            elif nb_chns_nnz > nb_chns_nnz_target:
                lbnd = val
            else:
                break
        else:
            # the exact channel count is unreachable (e.g. several channels vanish at
            # the same gamma); keep the solution that respects the pruning ratio
            tf.logging.warning('unable to hit %d non-zero channels exactly; '
                               'falling back to gamma = %e' % (nb_chns_nnz_target, ubnd))
            val = ubnd
            mask_np = mask_np_ubnd
        tf.logging.info('gamma-final: %e' % val)

        # construct a least-square regression problem over the kept channels
        rspn_mat_np = outputs_np
        bnry_vec_np = (mask_np > 0.0)
        inputs_np_list_msk = [bnry_vec_np[idx] * inputs_np_list[idx] for idx in range(ic)]
        # N x k^2 x c_i -> N x (k^2 * c_i), matching the kernel's (kh, kw, c_i) order
        feat_mat_np = np.reshape(
            np.concatenate([np.expand_dims(x, axis=-1) for x in inputs_np_list_msk], axis=-1),
            [bs, -1])
        wei_mat_np = self.sess_train.run(self.meta_lstsq['wei_mat'],
                                         feed_dict={
                                             self.meta_lstsq['feat_mat_ph']: feat_mat_np,
                                             self.meta_lstsq['rspn_mat_ph']: rspn_mat_np,
                                         })
        conv_krnl = np.reshape(wei_mat_np, conv_krnl.shape) * \
            np.reshape(bnry_vec_np, [1, 1, -1, 1])

        return conv_krnl

    def __save_model(self, is_train):
        """Save the current model for training or evaluation.

        Args:
        * is_train: whether to save a model for training
        """

        if is_train:
            save_path = self.saver_prnd_train.save(self.sess_train,
                                                   FLAGS.cpr_save_path,
                                                   self.global_step)
        else:
            save_path = self.saver_prnd_eval.save(self.sess_eval, FLAGS.cpr_save_path_eval)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Both graphs restore from the directory of FLAGS.cpr_save_path.

        Args:
        * is_train: whether to restore a model for training
        """

        save_path = tf.train.latest_checkpoint(os.path.dirname(FLAGS.cpr_save_path))
        if is_train:
            self.saver_prnd_train.restore(self.sess_train, save_path)
        else:
            self.saver_prnd_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Monitor the training progress.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed (aggregated over all workers)
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # display monitored statistics
        log_str = ' | '.join(
            ['%s = %.4e' % (name, value) for name, value in zip(self.log_op_names, log_rslt)])
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' % (idx_iter + 1, log_str, speed))
예제 #9
0
class ChannelPrunedRmtLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Channel pruning learner - remastered."""
    def __init__(self, sm_writer, model_helper):
        """Create the learner and build its training, evaluation & pruning graphs.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """

        super(ChannelPrunedRmtLearner, self).__init__(sm_writer, model_helper)

        # name scopes of the full (pre-trained) and the channel-pruned models
        self.model_scope_full = 'model'
        self.model_scope_prnd = 'pruned_model'

        # a pre-trained model is mandatory: only the local primary worker
        # downloads it, everybody else waits at the barrier
        if self.is_primary_worker('local'):
            self.download_model()
        self.auto_barrier()
        model_files = os.listdir('./models')
        tf.logging.info('model files: ' + ', '.join(model_files))

        # the distillation helper must exist before the graphs using it are built
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper,
                                                 self.mpi_comm)

        # training graph, evaluation graph, then the channel pruning graph
        for build_graph in (self.__build_train, self.__build_eval,
                            self.__build_prune):
            build_graph()

    def train(self):
        """Obtain a channel-pruned model, fine-tune it and periodically checkpoint it.

        Unless FLAGS.cpr_warm_start is set, channels are selected first (which
        writes a warm-start checkpoint); the warm-start checkpoint found next to
        FLAGS.cpr_save_path_ws is then restored and fine-tuned for
        self.nb_iters_train iterations.
        """

        # channel selection (skipped when a pre-pruned warm-start model exists)
        if not FLAGS.cpr_warm_start:
            t_start = timer()
            self.__choose_channels()
            tf.logging.info('time (channel selection): %.2f (s)' %
                            (timer() - t_start))
        ws_path = tf.train.latest_checkpoint(
            os.path.dirname(FLAGS.cpr_save_path_ws))
        self.saver_prnd_train.restore(self.sess_train, ws_path)
        tf.logging.info('model restored from ' + ws_path)

        # init_op must run after the restore: it (re-)initializes the global
        # step, the optimizer's variables and the pruning masks, whose initial
        # values are derived from the (just restored) convolutional kernels
        self.sess_train.run(self.init_op)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # baseline: checkpoint & evaluate before any fine-tuning
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.evaluate()
        self.auto_barrier()

        # fine-tune the model with the chosen channels only
        t_last = timer()
        for idx_iter in range(self.nb_iters_train):
            step = idx_iter + 1
            if step % FLAGS.summ_step == 0:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    self.__monitor_progress(summary, log_rslt, idx_iter,
                                            timer() - t_last)
                    t_last = timer()
            else:
                self.sess_train.run(self.train_op)

            if self.is_primary_worker('global') and step % FLAGS.save_step == 0:
                self.__save_model(is_train=True)
                self.evaluate()
            self.auto_barrier()

        # final model: save the training checkpoint, load it into the
        # evaluation graph, save that graph's checkpoint too, then evaluate
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Load the latest checkpoint into the evaluation graph and log its metrics.

        Runs ceil(FLAGS.nb_smpls_eval / FLAGS.batch_size_eval) mini-batches,
        streams the model outputs through self.dump_n_eval, and logs the
        mean of every evaluation operation over all mini-batches.
        """

        self.__restore_model(is_train=False)
        nb_iters = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))

        self.dump_n_eval(outputs=None, action='init')
        for idx_iter in range(nb_iters):
            if (idx_iter + 1) % 100 == 0:
                tf.logging.info('process the %d-th mini-batch for evaluation' %
                                (idx_iter + 1))
            rslt, outputs = self.sess_eval.run(
                [self.eval_op, self.outputs_eval])
            eval_rslts[idx_iter] = rslt
            self.dump_n_eval(outputs=outputs, action='dump')
        self.dump_n_eval(outputs=None, action='eval')

        # one column per evaluation operation; report its average
        for name, column in zip(self.eval_op_names, eval_rslts.T):
            tf.logging.info('%s = %.4e' % (name, np.mean(column)))

    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph.

        Creates a dedicated graph & session (self.sess_train) for fine-tuning
        the channel-pruned model. Besides the model itself, it creates:
        * self.saver_prnd_train: saver over the pruned model's variables and
          the global step
        * self.masks: one non-trainable mask per convolutional kernel, whose
          initial value is 1 for input channels with non-zero kernel weights
          and 0 otherwise; kernel gradients are multiplied by these masks
        * self.train_op / self.summary_op / self.log_op (+ self.log_op_names)
        * self.init_op: initializes ONLY the global step, the masks and the
          optimizer's variables; the model weights are expected to be restored
          from a checkpoint before it is run
        * self.bcast_op (multi-GPU only)
        """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            # expose a single GPU: the local rank under multi-GPU training, else GPU 0
            config.gpu_options.visible_device_list = \
                str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model (teacher logits for the distillation loss)
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images)
                # collected before the masks below are created, so (presumably,
                # depending on get_vars_by_scope) the masks are not part of it
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.global_step = tf.train.get_or_create_global_step()
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'] +
                                                       [self.global_step])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits_prnd,
                                               self.vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule (also determines the number of training iterations)
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                    self.global_step)

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(self.vars_prnd['conv_krnl'])
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_conv_krnl', pr_conv_krnl)

                # create masks and corresponding operations for channel pruning
                self.masks = []
                for idx, var in enumerate(self.vars_prnd['conv_krnl']):
                    tf.logging.info(
                        'creating a pruning mask for {} of size {}'.format(
                            var.name, var.shape))
                    # drop the leading scope component of the kernel's name; the
                    # mask is created inside the same variable scope
                    mask_name = '/'.join(var.name.split('/')[1:]).replace(
                        ':0', '_mask')
                    # squared L2-norm per input channel (axis 2 of a Conv2D
                    # kernel, which TF lays out as kh x kw x c_in x c_out)
                    var_norm = tf.reduce_sum(tf.square(var),
                                             axis=[0, 1, 3],
                                             keepdims=True)
                    # 1 for channels still in use, 0 for already-pruned channels
                    mask_init = tf.cast(var_norm > 0.0, tf.float32)
                    mask = tf.get_variable(mask_name,
                                           initializer=mask_init,
                                           trainable=False)
                    self.masks += [mask]

                # optimizer & gradients
                optimizer_base = tf.train.MomentumOptimizer(
                    lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(
                    loss, self.vars_prnd['trainable'])
                # zero out gradients of pruned input channels
                grads_pruned = self.__calc_grads_pruned(grads_origin)
                # run the pruned model's UPDATE_OPS (e.g. batch-norm statistics) with each step
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope=self.model_scope_prnd)
                with tf.control_dependencies(update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads_pruned, global_step=self.global_step)

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            # model weights are NOT initialized here; they come from a checkpoint
            self.init_op = tf.group(
                tf.variables_initializer([self.global_step] + self.masks +
                                         optimizer_base.variables()))
            self.log_op = [lrn_rate, loss, pr_trainable, pr_conv_krnl] + list(
                metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_krn'] + list(
                metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

    def __build_eval(self):
        """Build the evaluation graph.

        Creates a dedicated graph & session (self.sess_eval) holding the
        channel-pruned model built with forward_eval(), plus a saver
        (self.saver_prnd_eval) over its variables and global step. Sets
        self.eval_op / self.eval_op_names (loss, pruning ratios & extra
        metrics) and self.outputs_eval (the model's logits), and registers the
        input images / logits in the 'images_final' / 'logits_final'
        collections.
        """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            # expose a single GPU: the local rank under multi-GPU training, else GPU 0
            config.gpu_options.visible_device_list = \
                str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model (teacher logits for the distillation loss)
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits = self.forward_eval(images)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                global_step = tf.train.get_or_create_global_step()
                self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'] +
                                                      [global_step])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits,
                                               vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(vars_prnd['conv_krnl'])

                # TF operations for evaluation (names are aligned with the ops list)
                self.eval_op = [loss, pr_trainable, pr_conv_krnl] + list(
                    metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_krn'] + list(
                    metrics.keys())
                self.outputs_eval = logits

            # add input & output tensors to certain collections
            tf.add_to_collection('images_final', images)
            tf.add_to_collection('logits_final', logits)

    def __build_prune(self):
        """Build the channel pruning graph.

        Creates a dedicated graph & session (self.sess_prune) containing:
        * the full model (scope self.model_scope_full), restored from the latest
          checkpoint in the directory of FLAGS.save_path
        * the channel-pruned model (scope self.model_scope_prnd), initialized
          from the full model's weights by self.init_op_prune
        Both models read their images from placeholders (self.images_prune_ph)
        so that cached numpy mini-batches can be fed to them. Also builds the
        per-layer Conv2D information list and the meta LASSO / least-square
        problems used for channel selection.
        """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            # expose a single GPU: the local rank under multi-GPU training, else GPU 0
            config.gpu_options.visible_device_list = \
                str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                # placeholders mirroring the iterator's output (a single tensor
                # or a dict of tensors), so fetched batches can be fed back in
                if not isinstance(images, dict):
                    images_ph = tf.placeholder(tf.float32,
                                               shape=images.shape,
                                               name='images_ph')
                else:
                    images_ph = {}
                    for key, value in images.items():
                        images_ph[key] = tf.placeholder(value.dtype,
                                                        shape=value.shape,
                                                        name=(key + '_ph'))

            # restore a pre-trained model as full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images_ph)
                vars_full = get_vars_by_scope(self.model_scope_full)
                saver_full = tf.train.Saver(vars_full['all'])
                saver_full.restore(
                    sess,
                    tf.train.latest_checkpoint(os.path.dirname(
                        FLAGS.save_path)))

            # restore a pre-trained model as channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images_ph)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                global_step = tf.train.get_or_create_global_step()
                saver_prnd = tf.train.Saver(vars_prnd['all'] + [global_step])

                # loss & extra evaluation metrics
                # NOTE(review): `labels` comes from the iterator, not from a
                # placeholder, and loss / metrics are not used anywhere else in
                # this method - confirm whether they are still needed
                loss, metrics = self.calc_loss(labels, logits_prnd,
                                               vars_prnd['trainable'])

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(vars_prnd['conv_krnl'])

                # use full model's weights to initialize channel-pruned model
                # (pairs variables by position: assumes both scopes list their
                # variables in the same order)
                init_ops = [global_step.initializer]
                for var_full, var_prnd in zip(vars_full['all'],
                                              vars_prnd['all']):
                    init_ops += [var_prnd.assign(var_full)]
                self.init_op_prune = tf.group(init_ops)

            # build a list of Conv2D operation's information
            self.conv_info_list = self.__build_conv_info_list(
                vars_prnd['conv_krnl'])

            # build meta LASSO/least-square optimization problems
            self.meta_lasso = self.__build_meta_lasso()
            self.meta_lstsq = self.__build_meta_lstsq()

            # TF operations for logging & summarizing
            self.sess_prune = sess
            self.images_prune = images
            self.images_prune_ph = images_ph
            self.saver_prune = saver_prnd
            self.pr_trn_prune = pr_trainable
            self.pr_krn_prune = pr_conv_krnl

    def __build_conv_info_list(self, conv_krnls_prnd):
        """Pair up the full and channel-pruned models' Conv2D operations.

        Conv2D operations are located by name (suffix '/Conv2D') in each
        model's scope and paired by position (assumed to be the same layer
        order in both models).

        Args:
        * conv_krnls_prnd: convolutional kernel variables of the channel-pruned
          model, one per Conv2D operation in the same order

        Returns:
        * conv_info_list: one dict per layer holding both models' kernels,
          inputs & outputs, a placeholder plus assign op to overwrite the
          pruned kernel, and the layer's strides & padding
        """

        conv2d_pattern = re.compile(r'/Conv2D$')
        ops_full = get_ops_by_scope_n_pattern(self.model_scope_full,
                                              conv2d_pattern)
        ops_prnd = get_ops_by_scope_n_pattern(self.model_scope_prnd,
                                              conv2d_pattern)

        conv_info_list = []
        for idx_layer, op_pair in enumerate(zip(ops_full, ops_prnd)):
            op_full, op_prnd = op_pair
            krnl_var = conv_krnls_prnd[idx_layer]
            # placeholder + assign op used to write back an updated kernel
            krnl_ph = tf.placeholder(tf.float32,
                                     shape=krnl_var.shape,
                                     name='conv_krnl_prnd_ph_%d' % idx_layer)
            # Conv2D inputs: [0] = feature maps, [1] = filter
            conv_info_list.append(dict(
                conv_krnl_full=op_full.inputs[1],
                conv_krnl_prnd=op_prnd.inputs[1],
                conv_krnl_prnd_ph=krnl_ph,
                update_op=krnl_var.assign(krnl_ph),
                input_full=op_full.inputs[0],
                input_prnd=op_prnd.inputs[0],
                output_full=op_full.outputs[0],
                output_prnd=op_prnd.outputs[0],
                strides=op_full.get_attr('strides'),
                padding=op_full.get_attr('padding').decode('utf-8'),
            ))

        return conv_info_list

    def __build_meta_lasso(self):
        """Build a meta LASSO optimization problem.

        The problem's data live in variables that are (re-)initialized from
        placeholders by 'init_op', so one graph can solve LASSO problems of
        varying sizes (the variables use validate_shape=False). Each run of
        'train_op' performs one ISTA (proximal gradient) step on the mask m:

            m <- soft_threshold(m - lr * (xt_x * m - xt_y), gamma * lr)

        with lr = FLAGS.cpr_ista_lrn_rate. Assuming xt_x = X^T X and
        xt_y = X^T y (as the names suggest), this iterates towards the
        minimizer of 0.5 * ||X m - y||^2 + gamma * ||m||_1.

        Returns:
        * meta_lasso: dict of placeholders ('xt_x_ph', 'xt_y_ph', 'mask_ph',
          'gamma'), variables ('xt_x', 'xt_y', 'mask') and operations
          ('init_op', 'train_op'); 'xy_y' is kept as a legacy alias of 'xt_y'
        """

        # build a meta LASSO optimization problem
        with tf.variable_scope('meta_lasso'):
            # create placeholders to customize the LASSO problem
            xt_x_ph = tf.placeholder(tf.float32, name='xt_x_ph')
            xt_y_ph = tf.placeholder(tf.float32, name='xt_y_ph')
            mask_ph = tf.placeholder(tf.float32, name='mask_ph')
            gamma = tf.placeholder(tf.float32, shape=[], name='gamma')

            # create variables
            xt_x = tf.get_variable('xt_x',
                                   initializer=xt_x_ph,
                                   trainable=False,
                                   validate_shape=False)
            xt_y = tf.get_variable('xt_y',
                                   initializer=xt_y_ph,
                                   trainable=False,
                                   validate_shape=False)
            mask = tf.get_variable('mask',
                                   initializer=mask_ph,
                                   trainable=True,
                                   validate_shape=False)

            # TF operations
            def prox_mapping(x, thres):
                # soft-thresholding: proximal operator of thres * ||x||_1
                return tf.where(
                    x > thres, x - thres,
                    tf.where(x < -thres, x + thres, tf.zeros_like(x)))

            # gradient step on the quadratic term, then the proximal step
            mask_gd = mask - FLAGS.cpr_ista_lrn_rate * (tf.matmul(xt_x, mask) -
                                                        xt_y)
            train_op = mask.assign(
                prox_mapping(mask_gd, gamma * FLAGS.cpr_ista_lrn_rate))
            init_op = tf.variables_initializer([xt_x, xt_y, mask])

        # pack placeholders, variables, and TF operations into dict
        meta_lasso = {
            'xt_x_ph': xt_x_ph,
            'xt_y_ph': xt_y_ph,
            'mask_ph': mask_ph,
            'gamma': gamma,
            'xt_x': xt_x,
            'xt_y': xt_y,
            'xy_y': xt_y,  # legacy (misspelled) key, kept for backward compatibility
            'mask': mask,
            'init_op': init_op,
            'train_op': train_op,
        }

        return meta_lasso

    def __build_meta_lstsq(self):
        """Build a meta least-square optimization problem.

        The problem's data (x_mat, y_mat), the weights (w_mat) and the two
        moment accumulators (gacc1, gacc2) are variables (re-)initialized from
        placeholders by 'init_op' (validate_shape=False, so any size works);
        'train_step' is reset to 0. The objective is

            loss_reg + loss_dcy = ||x_mat * w_mat - y_mat||^2 / (2 * N)
                                  + FLAGS.loss_w_dcy * ||w_mat||^2 / 2

        with N = number of rows of x_mat. Each run of 'train_op' performs one
        Adam step (beta1 = 0.9, beta2 = 0.999, bias-corrected step size) with
        base learning rate FLAGS.cpr_lstsq_lrn_rate, using the closed-form
        gradient of the objective.

        Returns:
        * meta_lstsq: dict of placeholders, the 'w_mat' variable, both loss
          terms, and the 'init_op' / 'train_op' operations
        """

        # build a meta least-square optimization problem
        # (Adam hyper-parameters)
        beta1 = 0.9
        beta2 = 0.999
        epsilon = 1e-8
        with tf.variable_scope('meta_lstsq'):
            # create placeholders to customize the least-square problem
            x_mat_ph = tf.placeholder(tf.float32, name='x_mat_ph')
            y_mat_ph = tf.placeholder(tf.float32, name='y_mat_ph')
            w_mat_ph = tf.placeholder(tf.float32, name='w_mat_ph')
            gacc1_ph = tf.placeholder(tf.float32, name='gacc1_ph')
            gacc2_ph = tf.placeholder(tf.float32, name='gacc2_ph')

            # create variables
            x_mat = tf.get_variable('x_mat',
                                    initializer=x_mat_ph,
                                    validate_shape=False)
            y_mat = tf.get_variable('y_mat',
                                    initializer=y_mat_ph,
                                    validate_shape=False)
            w_mat = tf.get_variable('w_mat',
                                    initializer=w_mat_ph,
                                    validate_shape=False)
            gacc1 = tf.get_variable('gacc1',
                                    initializer=gacc1_ph,
                                    validate_shape=False)
            gacc2 = tf.get_variable('gacc2',
                                    initializer=gacc2_ph,
                                    validate_shape=False)
            train_step = tf.get_variable('train_step',
                                         shape=[],
                                         initializer=tf.zeros_initializer)

            # TF operations
            nb_smpls = tf.cast(tf.shape(x_mat)[0], tf.float32)
            loss_reg = tf.nn.l2_loss(tf.matmul(x_mat, w_mat) -
                                     y_mat) / nb_smpls
            loss_dcy = FLAGS.loss_w_dcy * tf.nn.l2_loss(w_mat)
            # analytic gradient of (loss_reg + loss_dcy) w.r.t. w_mat
            grad = tf.matmul(tf.transpose(x_mat),
                             tf.matmul(x_mat, w_mat) -
                             y_mat) / nb_smpls + FLAGS.loss_w_dcy * w_mat
            # first / second moment estimates and step counter
            update_ops = [
                gacc1.assign(beta1 * gacc1 + (1.0 - beta1) * grad),
                gacc2.assign(beta2 * gacc2 + (1.0 - beta2) * grad**2),
                train_step.assign_add(tf.ones([]))
            ]
            # NOTE(review): in TF1 graph mode, a (ref) Variable used directly
            # inside tf.control_dependencies may be read through a snapshot
            # created outside the block, i.e. not necessarily after update_ops;
            # train_step == 0 would make the bias correction below divide by
            # zero. Consider reading gacc1 / gacc2 / train_step via
            # .read_value() here - confirm.
            with tf.control_dependencies(update_ops):
                # bias-corrected step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t)
                lrn_rate = FLAGS.cpr_lstsq_lrn_rate \
                           * tf.sqrt(1.0 - tf.pow(beta2, train_step)) / (1.0 - tf.pow(beta1, train_step))
                train_op = w_mat.assign_add(-lrn_rate * gacc1 /
                                            (tf.sqrt(gacc2) + epsilon))
            init_op = tf.variables_initializer(
                [x_mat, y_mat, w_mat, gacc1, gacc2, train_step])

        # pack placeholders and variables into dict
        meta_lstsq = {
            'x_mat_ph': x_mat_ph,
            'y_mat_ph': y_mat_ph,
            'w_mat_ph': w_mat_ph,
            'gacc1_ph': gacc1_ph,
            'gacc2_ph': gacc2_ph,
            'w_mat': w_mat,
            'loss_reg': loss_reg,
            'loss_dcy': loss_dcy,
            'init_op': init_op,
            'train_op': train_op,
        }

        return meta_lstsq

    def __calc_grads_pruned(self, grads_origin):
        """Calculate the mask-pruned gradients.

        The gradient of every convolutional kernel in self.vars_prnd['conv_krnl']
        is multiplied by that kernel's pruning mask (self.masks, same order),
        so kernel entries whose mask is 0 receive no gradient update. All other
        (gradient, variable) pairs, and pairs whose gradient is None, are
        passed through unchanged.

        Args:
        * grads_origin: list of (gradient, variable) pairs, e.g. as returned by
          Optimizer.compute_gradients()

        Returns:
        * grads_pruned: list of (gradient, variable) pairs, same order as the input
        """

        # map kernel names to mask indices once (O(1) lookups per gradient
        # instead of a linear list search); first occurrence wins, matching
        # list.index() semantics
        idx_mask_by_name = {}
        for idx, var in enumerate(self.vars_prnd['conv_krnl']):
            idx_mask_by_name.setdefault(var.name, idx)

        grads_pruned = []
        for grad_n_var in grads_origin:
            grad, var = grad_n_var
            idx_mask = idx_mask_by_name.get(var.name)
            if idx_mask is None or grad is None:
                # not a convolutional kernel, or no gradient flows to it
                grads_pruned.append(grad_n_var)
            else:
                grads_pruned.append((grad * self.masks[idx_mask], var))

        return grads_pruned

    def __choose_channels(self):  # pylint: disable=too-many-locals
        """Choose channels for all convolutional layers.

        Layers are processed one after another in the pruning graph:
        1. per-layer pruning ratios are set from FLAGS.cpr_prune_ratio, with 0.0
           for the first / last layer (if requested) and for every kernel whose
           name contains one of the comma-separated FLAGS.cpr_skip_op_names;
        2. a fixed set of mini-batches is fetched once and reused for all
           layers;
        3. for each layer, input patches of the channel-pruned model and output
           vectors of the full model are sampled, __solve_sparse_regression()
           computes an updated kernel, and it is written back into the
           channel-pruned model, so later layers see the already-pruned inputs;
        4. the primary worker saves the result to FLAGS.cpr_save_path_ws, which
           train() restores as its warm start.
        """

        # configure each layer's pruning ratio
        nb_layers = len(self.conv_info_list)
        prune_ratios = [FLAGS.cpr_prune_ratio] * nb_layers
        if FLAGS.cpr_skip_frst_layer:
            prune_ratios[0] = 0.0
        if FLAGS.cpr_skip_last_layer:
            prune_ratios[-1] = 0.0

        # skip channel pruning at certain layers
        # (substring match of each comma-separated name against the kernel's name)
        skip_names = FLAGS.cpr_skip_op_names.split(
            ',') if FLAGS.cpr_skip_op_names is not None else []
        for idx_layer in range(nb_layers):
            # NOTE(review): disabled heuristic kept from the original code
            # if self.conv_info_list[idx_layer]['input_full'].shape[2] == 8:
            #  prune_ratios[idx_layer] = 0.0
            conv_krnl_prnd_name = self.conv_info_list[idx_layer][
                'conv_krnl_prnd'].name
            for skip_name in skip_names:
                if skip_name in conv_krnl_prnd_name:
                    prune_ratios[idx_layer] = 0.0
                    tf.logging.info('skip %s since no pruning is required' %
                                    conv_krnl_prnd_name)
                    break

        # cache multiple mini-batches of images for channel selection
        def __build_feed_dict(images_np):
            # feed a cached batch into the image placeholder(s); handles both a
            # single image tensor and a dict of input tensors
            if not isinstance(self.images_prune, dict):
                feed_dict = {self.images_prune_ph: images_np}
            else:
                feed_dict = {}
                for key in self.images_prune:
                    feed_dict[self.images_prune_ph[key]] = images_np[key]
            return feed_dict

        # NOTE(review): under Python 2 with integer flags, `/` floors before
        # math.ceil is applied - confirm the targeted Python version
        nb_mbtcs = int(math.ceil(FLAGS.cpr_nb_smpls / FLAGS.batch_size))
        images_cached = []
        for __ in range(nb_mbtcs):
            images_cached += [self.sess_prune.run(self.images_prune)]

        # select channels for all the convolutional layers
        # (start from an exact copy of the full model's weights)
        self.sess_prune.run(self.init_op_prune)
        for idx_layer in range(nb_layers):
            # display the layer information
            prune_ratio = prune_ratios[idx_layer]
            conv_info = self.conv_info_list[idx_layer]
            if self.is_primary_worker('global'):
                tf.logging.info('layer #%d: pr = %.2f (target)' %
                                (idx_layer, prune_ratio))
                tf.logging.info('kernel name = {}'.format(
                    conv_info['conv_krnl_prnd'].name))
                tf.logging.info('kernel shape = {}'.format(
                    conv_info['conv_krnl_prnd'].shape))

            # extract the current layer's information
            conv_krnl_full = self.sess_prune.run(conv_info['conv_krnl_full'])
            conv_krnl_prnd = self.sess_prune.run(conv_info['conv_krnl_prnd'])
            conv_krnl_prnd_ph = conv_info['conv_krnl_prnd_ph']
            update_op = conv_info['update_op']
            input_full_tf = conv_info['input_full']
            input_prnd_tf = conv_info['input_prnd']
            output_full_tf = conv_info['output_full']
            output_prnd_tf = conv_info['output_prnd']
            strides = conv_info['strides']
            padding = conv_info['padding']
            # axis 2 of a Conv2D kernel is the input-channel axis
            nb_chns_input = conv_krnl_prnd.shape[2]

            # sample inputs & outputs through multiple mini-batches
            tf.logging.info(
                'sampling inputs & outputs through multiple mini-batches')
            time_beg = timer()
            nb_insts = 0  # number of sampled instances (for regression) collected so far
            nb_insts_min = FLAGS.cpr_nb_crops_per_smpl * FLAGS.cpr_nb_smpls  # minimal requirement
            inputs_list = [[] for __ in range(nb_chns_input)]
            outputs_list = []
            for idx_mbtc in range(nb_mbtcs):
                inputs_full, inputs_prnd, outputs_full, outputs_prnd = \
                    self.sess_prune.run([input_full_tf, input_prnd_tf, output_full_tf, output_prnd_tf],
                                        feed_dict=__build_feed_dict(images_cached[idx_mbtc]))
                inputs_smpl, outputs_smpl = self.__smpl_inputs_n_outputs(
                    conv_krnl_full, conv_krnl_prnd, inputs_full, inputs_prnd,
                    outputs_full, outputs_prnd, strides, padding)
                nb_insts += outputs_smpl.shape[0]
                for idx_chn_input in range(nb_chns_input):
                    inputs_list[idx_chn_input] += [inputs_smpl[idx_chn_input]]
                outputs_list += [outputs_smpl]
                if nb_insts > nb_insts_min:
                    break
            # keep exactly nb_insts_min random instances; np.random.choice
            # raises ValueError if fewer were collected (assumes every cached
            # mini-batch holds FLAGS.batch_size images)
            idxs_inst = np.random.choice(nb_insts,
                                         size=(nb_insts_min),
                                         replace=False)
            inputs_np_list = [np.vstack(x)[idxs_inst] for x in inputs_list]
            outputs_np = np.vstack(outputs_list)[idxs_inst]
            tf.logging.info('time elapsed (sampling): %.4f (s)' %
                            (timer() - time_beg))

            # choose channels via solving the sparsity-constrained regression problem
            tf.logging.info(
                'choosing channels via solving the sparsity-constrained regression problem'
            )
            time_beg = timer()
            conv_krnl_prnd = self.__solve_sparse_regression(
                inputs_np_list, outputs_np, conv_krnl_prnd, prune_ratio)
            # write the new kernel back so subsequent layers see pruned inputs
            self.sess_prune.run(update_op,
                                feed_dict={conv_krnl_prnd_ph: conv_krnl_prnd})
            tf.logging.info('time elapsed (selection): %.4f (s)' %
                            (timer() - time_beg))

            # compute the overall pruning ratios
            pr_trn, pr_krn = self.sess_prune.run(
                [self.pr_trn_prune, self.pr_krn_prune])
            tf.logging.info('pruning ratios: %e (trn) / %e (krn)' %
                            (pr_trn, pr_krn))

        # save the temporary model containing channel pruned weights
        if self.is_primary_worker('global'):
            save_path = self.saver_prune.save(self.sess_prune,
                                              FLAGS.cpr_save_path_ws)
            tf.logging.info('model saved to ' + save_path)
        self.auto_barrier()

    def __smpl_inputs_n_outputs(self, conv_krnl_full, conv_krnl_prnd,
                                inputs_full, inputs_prnd, outputs_full,
                                outputs_prnd, strides, padding):
        """Sample inputs & outputs of sub-regions from full feature maps.

        For each of FLAGS.cpr_nb_crops_per_smpl random output locations, the
        kh x kw input patch producing it is cut out of both models' input
        feature maps (zero-filled where the patch overlaps the padding),
        together with the matching output vectors. As a sanity check, both
        models' outputs are re-computed from the sampled patches.

        Args:
        * conv_krnl_full: full model's kernel, kh x kw x ic x oc
        * conv_krnl_prnd: channel-pruned model's kernel, kh x kw x ic x oc
        * inputs_full: full model's input feature maps, bs x ih x iw x ic
        * inputs_prnd: channel-pruned model's input feature maps, bs x ih x iw x ic
        * outputs_full: full model's output feature maps, bs x oh x ow x oc
        * outputs_prnd: channel-pruned model's output feature maps, bs x oh x ow x oc
        * strides: Conv2D 'strides' attribute (height / width strides at indices 1 / 2)
        * padding: Conv2D 'padding' attribute ('VALID'; anything else is treated as 'SAME')

        Returns:
        * inputs_smpl: list of ic arrays (one per input channel), each of size
          (bs * nb_crops) x (kh * kw), sampled from the channel-pruned model's inputs
        * outputs_smpl: (bs * nb_crops) x oc array sampled from the full model's outputs

        Raises:
        * AssertionError: if the sampled patches do not reproduce either
          model's Conv2D outputs (per-element squared error >= 1e-6)
        """

        # obtain parameters
        bs = inputs_full.shape[0]
        kh, kw = conv_krnl_full.shape[0], conv_krnl_full.shape[1]
        ih, iw, ic = inputs_full.shape[1], inputs_full.shape[2], inputs_full.shape[3]
        oh, ow, oc = outputs_full.shape[1], outputs_full.shape[2], outputs_full.shape[3]
        sh, sw = strides[1], strides[2]

        # only the top / left padding is needed: patches running past the
        # bottom / right border are clipped against ih / iw below, which is
        # equivalent to zero-padding on those sides
        if padding == 'VALID':
            pt, pl = 0, 0
        else:
            # ref link: https://www.tensorflow.org/api_guides/python/nn#Convolution
            ph = max(kh - (sh if ih % sh == 0 else ih % sh), 0)
            pw = max(kw - (sw if iw % sw == 0 else iw % sw), 0)
            # TF puts the odd extra padding row / column at the bottom / right
            pt = ph // 2
            pl = pw // 2

        # sample inputs & outputs of sub-regions
        inputs_smpl_full_list = []
        inputs_smpl_prnd_list = []
        outputs_smpl_full_list = []
        outputs_smpl_prnd_list = []
        for __ in range(FLAGS.cpr_nb_crops_per_smpl):
            idx_oh = np.random.randint(oh)
            idx_ow = np.random.randint(ow)
            # uncropped indices of input feature maps (may reach into the padding)
            idx_ih_low = idx_oh * sh - pt
            idx_ih_hgh = idx_ih_low + kh
            idx_iw_low = idx_ow * sw - pl
            idx_iw_hgh = idx_iw_low + kw
            # cropped indices within the kh x kw sampled patch
            idx_sh_low = max(-idx_ih_low, 0)
            idx_sh_hgh = kh - max(idx_ih_hgh - ih, 0)
            idx_sw_low = max(-idx_iw_low, 0)
            idx_sw_hgh = kw - max(idx_iw_hgh - iw, 0)
            # cropped indices of input feature maps
            idx_ih_low = max(idx_ih_low, 0)
            idx_ih_hgh = min(idx_ih_hgh, ih)
            idx_iw_low = max(idx_iw_low, 0)
            idx_iw_hgh = min(idx_iw_hgh, iw)
            # zero-initialized patches: entries falling into the padding stay 0
            inputs_smpl_full = np.zeros((bs, kh, kw, ic))
            inputs_smpl_prnd = np.zeros((bs, kh, kw, ic))
            inputs_smpl_full[:, idx_sh_low:idx_sh_hgh, idx_sw_low:idx_sw_hgh, :] = \
                inputs_full[:, idx_ih_low:idx_ih_hgh, idx_iw_low:idx_iw_hgh, :]
            inputs_smpl_prnd[:, idx_sh_low:idx_sh_hgh, idx_sw_low:idx_sw_hgh, :] = \
                inputs_prnd[:, idx_ih_low:idx_ih_hgh, idx_iw_low:idx_iw_hgh, :]
            inputs_smpl_full_list += [inputs_smpl_full]
            inputs_smpl_prnd_list += [inputs_smpl_prnd]
            outputs_smpl_full_list += [
                np.reshape(outputs_full[:, idx_oh, idx_ow, :], [bs, -1])
            ]
            outputs_smpl_prnd_list += [
                np.reshape(outputs_prnd[:, idx_oh, idx_ow, :], [bs, -1])
            ]

        # concatenate samples into a single np.array
        inputs_smpl_full = np.concatenate(inputs_smpl_full_list, axis=0)
        inputs_smpl_prnd = np.concatenate(inputs_smpl_prnd_list, axis=0)
        outputs_smpl_full = np.vstack(outputs_smpl_full_list)
        outputs_smpl_prnd = np.vstack(outputs_smpl_prnd_list)

        # split the pruned model's patches per input channel: ic x (N x kh*kw)
        inputs_smpl = [
            np.reshape(x, [-1, kh * kw])
            for x in np.split(inputs_smpl_prnd, ic, axis=3)
        ]
        outputs_smpl = outputs_smpl_full

        # validate inputs & outputs: patch (flattened as kh, kw, ic) times the
        # kernel (flattened the same way) must reproduce the Conv2D output
        wei_mat_full = np.reshape(conv_krnl_full, [-1, oc])
        wei_mat_prnd = np.reshape(conv_krnl_prnd, [-1, oc])
        preds_smpl_full = np.matmul(
            np.reshape(inputs_smpl_full, [-1, kh * kw * ic]), wei_mat_full)
        preds_smpl_prnd = np.matmul(
            np.reshape(inputs_smpl_prnd, [-1, kh * kw * ic]), wei_mat_prnd)
        err_full = norm(outputs_smpl_full -
                        preds_smpl_full)**2 / outputs_smpl_full.size
        err_prnd = norm(outputs_smpl_prnd -
                        preds_smpl_prnd)**2 / outputs_smpl_prnd.size
        assert err_full < 1e-6, 'unable to recover output feature maps - full (%e)' % err_full
        assert err_prnd < 1e-6, 'unable to recover output feature maps - prnd (%e)' % err_prnd

        return inputs_smpl, outputs_smpl

    def __solve_sparse_regression(self, inputs_np_list, outputs_np, conv_krnl,
                                  prune_ratio):
        """Solve the sparsity-constrained regression problem.

        The problem is solved in two stages:
        1. Channel selection: a LASSO problem over a per-input-channel mask is
           solved with the pre-built <self.meta_lasso> graph. Its penalty
           <gamma> is searched (first doubling an upper bound, then bisection)
           until the number of non-zero mask entries equals the target implied
           by <prune_ratio>, or the bisection interval falls below 1e-8.
        2. Reconstruction: with the surviving channels fixed, the kernel is
           re-fitted by running the pre-built <self.meta_lstsq> graph for
           FLAGS.cpr_lstsq_nb_iters steps. Kernel slices of pruned input
           channels are then zeroed.

        Args:
        * inputs_np_list: list of input feature maps (one per input channel, N x k^2)
        * outputs_np: output feature maps (N x c_o)
        * conv_krnl: initial convolutional kernel (k * k * c_i * c_o)
        * prune_ratio: pruning ratio (fraction of input channels to remove)

        Returns:
        * conv_krnl: updated convolutional kernel (k * k * c_i * c_o), whose
          slices along the input-channel axis are zero for pruned channels
        """

        # obtain parameters
        bs = outputs_np.shape[0]
        kh, kw, ic, oc = conv_krnl.shape[0], conv_krnl.shape[
            1], conv_krnl.shape[2], conv_krnl.shape[3]
        # number of input channels that should survive pruning
        nb_chns_nnz_target = int(ic * (1.0 - prune_ratio))
        tf.logging.info('[sparse regression]')
        tf.logging.info(
            '\tinputs: {} / outputs: {} / conv_krnl: {} / pr: {} / nnz: {}'.
            format(inputs_np_list[0].shape, outputs_np.shape, conv_krnl.shape,
                   prune_ratio, nb_chns_nnz_target))

        # compute the feature matrix & response vector
        tf.logging.info('computing the feature matrix & response vector')
        time_beg = timer()
        # secondary sub-sampling: keep about 10 * N / c_o instances (capped at N),
        # so the flattened problem below has roughly 10 * N rows (bs_rdc * c_o)
        bs_rdc = int(math.ceil(min(bs, bs / oc * 10.0)))
        tf.logging.info('secondary sampling: %d -> %d' % (bs, bs_rdc))
        idxs_inst = np.random.choice(bs, size=(bs_rdc), replace=False)
        rspn_vec_np = np.reshape(outputs_np[idxs_inst],
                                 [-1, 1])  # N' x 1 (N' = N * c_o)
        feat_mat_np = np.zeros((ic, bs_rdc * oc))  # c_i x N'
        for idx in range(ic):
            # row <idx>: contribution of input channel <idx> to every
            # (instance, output channel) response under the current kernel;
            # ravel() is row-major, matching the reshape of <rspn_vec_np>
            wei_mat = np.reshape(conv_krnl[:, :, idx, :], [kh * kw, oc])
            feat_mat_np[idx] = np.matmul(inputs_np_list[idx][idxs_inst],
                                         wei_mat).ravel()
        feat_mat_np = np.transpose(feat_mat_np)  # N' x c_i
        tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg))

        # compute <X^T * X> & <X^T * y> in advance
        tf.logging.info('computing <X^T * X> & <X^T * y> in advance')
        time_beg = timer()
        xt_x_np = np.matmul(feat_mat_np.T, feat_mat_np)  # c_i x c_i
        xt_y_np = np.matmul(feat_mat_np.T, rspn_vec_np)  # c_i x 1
        xt_x_norm = norm(
            xt_x_np
        )  # normalize <xt_x> to unit norm, and adjust <xt_y> correspondingly
        # NOTE(review): if every feature is zero, <xt_x_norm> is 0 and the two
        # divisions below produce NaN / inf; no guard is present here
        xt_x_np /= xt_x_norm
        xt_y_np /= xt_x_norm
        # a single random initial mask shared by every LASSO solve below
        mask_np_init = np.random.uniform(size=(ic, 1))
        tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg))

        # solve the LASSO problem
        def __solve_lasso(x):
            """Run the LASSO graph with penalty <x>; return (mask, #non-zeros)."""
            # re-initialize the solver state from the fixed problem data
            self.sess_prune.run(self.meta_lasso['init_op'],
                                feed_dict={
                                    self.meta_lasso['xt_x_ph']: xt_x_np,
                                    self.meta_lasso['xt_y_ph']: xt_y_np,
                                    self.meta_lasso['mask_ph']: mask_np_init,
                                })
            # presumably ISTA iterations (per the FLAGS name); the update rule
            # itself lives in the graph built elsewhere
            for __ in range(FLAGS.cpr_ista_nb_iters):
                self.sess_prune.run(self.meta_lasso['train_op'],
                                    feed_dict={self.meta_lasso['gamma']: x})
            mask_np = self.sess_prune.run(self.meta_lasso['mask'])
            nb_chns_nnz = np.count_nonzero(mask_np)
            tf.logging.info('x = %e -> nb_chns_nnz = %d' % (x, nb_chns_nnz))
            return mask_np, nb_chns_nnz

        # determine <gamma>'s upper bound
        # (double <gamma> until no more than the target number of channels survive;
        # NOTE(review): this loop does not terminate if no penalty ever reaches
        # the target)
        tf.logging.info('determining <gamma>\'s upper bound')
        time_beg = timer()
        ubnd = 0.1
        while True:
            mask_np, nb_chns_nnz = __solve_lasso(ubnd)
            if nb_chns_nnz <= nb_chns_nnz_target:
                break
            else:
                ubnd *= 2.0
        tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg))

        # determine <gamma> via binary search
        # (a larger <gamma> keeps fewer channels: too few -> shrink the upper
        # bound, too many -> raise the lower bound)
        # NOTE(review): if the search stops because the interval width drops
        # below 1e-8, <mask_np> is the result of the LAST probe, which may keep
        # more channels than <nb_chns_nnz_target>
        tf.logging.info('determining <gamma> via binary search')
        time_beg = timer()
        lbnd = 0.0
        while nb_chns_nnz != nb_chns_nnz_target and ubnd - lbnd > 1e-8:
            val = (lbnd + ubnd) / 2.0
            mask_np, nb_chns_nnz = __solve_lasso(val)
            if nb_chns_nnz < nb_chns_nnz_target:
                ubnd = val
            elif nb_chns_nnz > nb_chns_nnz_target:
                lbnd = val
            else:
                break
        tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg))

        # construct a least-square regression problem
        tf.logging.info('constructing a least-square regression problem')
        time_beg = timer()
        # 1.0 for kept input channels, 0.0 for pruned ones (c_i x 1)
        bnry_vec_np = (np.abs(mask_np) > 0.0).astype(np.float32)
        rspn_mat_np = outputs_np
        # N x k^2 x c_i, masked per input channel, flattened to N x (k^2 * c_i);
        # the (k^2, c_i) flattening order matches reshape(conv_krnl, [-1, c_o])
        feat_tns_np = np.concatenate(
            [np.expand_dims(x, axis=-1) for x in inputs_np_list], axis=-1)
        feat_mat_np = np.reshape(
            feat_tns_np * np.reshape(bnry_vec_np, [1, 1, -1]), [bs, -1])
        w_mat_np_init = np.reshape(conv_krnl, [-1, oc])
        # zero-initialized accumulators fed to the solver graph
        # (presumably optimizer moment estimates — TODO confirm in <meta_lstsq>)
        gacc1_np = np.zeros_like(w_mat_np_init)
        gacc2_np = np.zeros_like(w_mat_np_init)
        self.sess_prune.run(self.meta_lstsq['init_op'],
                            feed_dict={
                                self.meta_lstsq['x_mat_ph']: feat_mat_np,
                                self.meta_lstsq['y_mat_ph']: rspn_mat_np,
                                self.meta_lstsq['w_mat_ph']: w_mat_np_init,
                                self.meta_lstsq['gacc1_ph']: gacc1_np,
                                self.meta_lstsq['gacc2_ph']: gacc2_np,
                            })
        loss_reg, loss_dcy = self.sess_prune.run(
            [self.meta_lstsq['loss_reg'], self.meta_lstsq['loss_dcy']])
        tf.logging.info('losses: %e (reg) / %e (dcy)' % (loss_reg, loss_dcy))
        for __ in range(FLAGS.cpr_lstsq_nb_iters):
            self.sess_prune.run(self.meta_lstsq['train_op'])
        w_mat_np, loss_reg, loss_dcy = self.sess_prune.run([
            self.meta_lstsq['w_mat'], self.meta_lstsq['loss_reg'],
            self.meta_lstsq['loss_dcy']
        ])
        tf.logging.info('losses: %e (reg) / %e (dcy)' % (loss_reg, loss_dcy))
        # restore the kernel shape and zero the slices of pruned input channels
        conv_krnl = np.reshape(w_mat_np, conv_krnl.shape) * np.reshape(
            bnry_vec_np, [1, 1, -1, 1])
        tf.logging.info('time elapsed: %.4f (s)' % (timer() - time_beg))

        return conv_krnl

    def __save_model(self, is_train):
        """Write a checkpoint of the current (pruned) model.

        Args:
        * is_train: True saves the training graph to FLAGS.cpr_save_path,
          tagged with the global step; False saves the evaluation graph to
          FLAGS.cpr_save_path_eval
        """

        if not is_train:
            save_path = self.saver_prnd_eval.save(
                self.sess_eval, FLAGS.cpr_save_path_eval)
        else:
            save_path = self.saver_prnd_train.save(
                self.sess_train, FLAGS.cpr_save_path, self.global_step)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Load the most recent checkpoint found next to FLAGS.cpr_save_path.

        Args:
        * is_train: True restores into the training session; False restores
          into the evaluation session
        """

        ckpt_dir = os.path.dirname(FLAGS.cpr_save_path)
        save_path = tf.train.latest_checkpoint(ckpt_dir)
        saver, sess = ((self.saver_prnd_train, self.sess_train) if is_train
                       else (self.saver_prnd_eval, self.sess_eval))
        saver.restore(sess, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Report training progress to TensorBoard and the log.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """

        # TensorBoard summaries
        self.sm_writer.add_summary(summary, idx_iter)

        # throughput in images per second, scaled by the worker count in multi-GPU mode
        nb_workers = mgw.size() if FLAGS.enbl_multi_gpu else 1
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step * nb_workers

        # one "name = value" entry per monitored tensor
        pieces = ['%s = %.4e' % pair for pair in zip(self.log_op_names, log_rslt)]
        log_str = ' | '.join(pieces)
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' %
                        (idx_iter + 1, log_str, speed))
예제 #10
0
class FullPrecLearner(AbstractLearner):  # pylint: disable=too-many-instance-attributes
    """Full-precision learner (no model compression applied)."""
    def __init__(self,
                 sm_writer,
                 model_helper,
                 model_scope=None,
                 enbl_dst=None):
        """Constructor function.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        * model_scope: name scope in which to define the model
          ('quan_model' when None)
        * enbl_dst: whether to create a model with distillation loss
          (FLAGS.enbl_dst when None)
        """

        # class-independent initialization
        super(FullPrecLearner, self).__init__(sm_writer, model_helper)

        # over-ride the model scope and distillation loss switch; 'quan_model' is
        # the fallback so callers that do not pass a scope keep the same variable names
        self.model_scope = model_scope if model_scope is not None else 'quan_model'
        self.enbl_dst = enbl_dst if enbl_dst is not None else FLAGS.enbl_dst

        # class-dependent initialization
        if self.enbl_dst:
            self.helper_dst = DistillationHelper(sm_writer, model_helper,
                                                 self.mpi_comm)
        self.__build(is_train=True)
        self.__build(is_train=False)

    def train(self):
        """Train a model and periodically produce checkpoint files."""

        # initialization
        self.sess_train.run(self.init_op)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)

        # train the model through iterations and periodically save & evaluate the model
        time_prev = timer()
        for idx_iter in range(self.nb_iters_train):
            # train the model
            if (idx_iter + 1) % FLAGS.summ_step != 0:
                self.sess_train.run(self.train_op)
            else:
                __, summary, log_rslt = self.sess_train.run(
                    [self.train_op, self.summary_op, self.log_op])
                if self.is_primary_worker('global'):
                    time_step = timer() - time_prev
                    self.__monitor_progress(summary, log_rslt, idx_iter,
                                            time_step)
                    time_prev = timer()

            # save & evaluate the model at certain steps
            if self.is_primary_worker('global') and (idx_iter +
                                                     1) % FLAGS.save_step == 0:
                self.__save_model(is_train=True)
                self.evaluate()

        # save the final model
        if self.is_primary_worker('global'):
            self.__save_model(is_train=True)
            self.__restore_model(is_train=False)
            self.__save_model(is_train=False)
            self.evaluate()

    def evaluate(self):
        """Restore a model from the latest checkpoint files and then evaluate it.

        In factory mode (FLAGS.factory_mode), the evaluation graph is run for
        enough batches to cover a grid of tiles sized from the reference image
        FLAGS.data_dir_local/images/FLAGS.image_name. The outputs are stitched
        row by row into one image, which is written to out_example/output.jpg.

        Otherwise, the evaluation metrics are averaged over
        ceil(FLAGS.nb_smpls_eval / FLAGS.batch_size_eval) batches and the
        per-picture inference time is measured. A few example images
        (bicubic / input / output / label) are dumped to out_example/, and a
        summary line is appended to log.txt.
        """

        self.__restore_model(is_train=False)

        # every image dump below writes into this directory
        os.makedirs('out_example', exist_ok=True)

        if FLAGS.factory_mode:
            # NOTE(review): the reference image is only used for its size; the
            # pixels come from the evaluation input pipeline (factory_op), which
            # presumably yields this image's tiles in row-major order — TODO confirm
            tmp_image = scipy.misc.imread(FLAGS.data_dir_local + "/images/" +
                                          FLAGS.image_name)
            img_h, img_w = tmp_image.shape[:2]
            print(tmp_image.shape)
            size_low = FLAGS.input_size
            size_high = FLAGS.sr_scale * size_low

            nb_tiles_h = img_h // size_low
            nb_tiles_w = img_w // size_low
            nb_iters = int(
                np.ceil(float(nb_tiles_h * nb_tiles_w) / FLAGS.batch_size_eval))
            image = np.zeros([size_high * nb_tiles_h, size_high * nb_tiles_w, 3],
                             dtype=np.uint8)
            print(image.shape)
            print(nb_iters)

            outputs = []
            for _ in range(nb_iters):
                outputs.extend(self.sess_eval.run(self.factory_op)[0])
            print(np.array(outputs).shape)

            # stitch the output tiles back together, row by row
            index = 0
            for i in range(nb_tiles_h):
                for j in range(nb_tiles_w):
                    image[i * size_high:(i + 1) * size_high,
                          j * size_high:(j + 1) * size_high, :] = np.array(
                              outputs[index])
                    index += 1

            out = Image.fromarray(image, 'RGB')
            out.save('out_example/' + 'output.jpg')
            return

        # average the evaluation metrics over all batches
        nb_iters = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))
        eval_rslts = np.zeros((nb_iters, len(self.eval_op)))
        for idx_iter in range(nb_iters):
            eval_rslts[idx_iter] = self.sess_eval.run(self.eval_op)
        eval_means = [
            np.mean(eval_rslts[:, idx]) for idx in range(len(self.eval_op_names))
        ]
        for name, value in zip(self.eval_op_names, eval_means):
            tf.logging.info('%s = %.4e' % (name, value))

        # wall-clock time of a forward-only pass over the same number of batches
        t = time.time()
        for _ in range(nb_iters):
            self.sess_eval.run(self.time_op)
        t = time.time() - t

        # dump a few examples: bicubic upscaling, input, network output, label
        images, outputs, labels = self.sess_eval.run(self.out_op)
        output_size = FLAGS.sr_scale * FLAGS.input_size
        for i in range(min(8, FLAGS.batch_size_eval)):
            img_bic = scipy.misc.imresize(images[i],
                                          (output_size, output_size),
                                          'bicubic')
            img_bic = np.array(np.clip(img_bic, 0, 255), np.uint8)

            Image.fromarray(img_bic, 'RGB').save('out_example/' + str(i) +
                                                 'bic.jpg')
            Image.fromarray(images[i], 'RGB').save('out_example/' + str(i) +
                                                   'image.jpg')
            Image.fromarray(outputs[i], 'RGB').save('out_example/' + str(i) +
                                                    'output.jpg')
            Image.fromarray(labels[i], 'RGB').save('out_example/' + str(i) +
                                                   'label.jpg')

        tf.logging.info('time = %.4e' % (t / FLAGS.nb_smpls_eval))

        # append a one-line summary (a Python list repr) to log.txt
        log_items = ["full", self.model_name]
        for name, value in zip(self.eval_op_names, eval_means):
            log_items += [name + ": " + str(value)]
        log_items += ["eval_batch_size: " + str(FLAGS.batch_size_eval)]
        log_items += ["time/pic: " + str(t / FLAGS.nb_smpls_eval)]
        with open("log.txt", "a") as txt:
            txt.write(str(log_items))
            txt.write('\n')

    def __build(self, is_train):  # pylint: disable=too-many-locals
        """Build the training / evaluation graph.

        Args:
        * is_train: whether to create the training graph
        """

        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train(
                ) if is_train else self.build_dataset_eval()
                images, labels = iterator.get_next()
                tf.add_to_collection('images_final', images)

            # model definition - distilled model
            if self.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - primary model
            with tf.variable_scope(self.model_scope):
                # forward pass
                logits = self.forward_train(
                    images) if is_train else self.forward_eval(images)
                tf.add_to_collection('logits_final', logits)

                # loss & extra evalution metrics
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # optimizer & gradients
                if is_train:
                    self.global_step = tf.train.get_or_create_global_step()
                    lrn_rate, self.nb_iters_train = setup_lrn_rate(
                        self.global_step, self.model_name, self.dataset_name)
                    optimizer = tf.train.MomentumOptimizer(
                        lrn_rate, FLAGS.momentum)
                    if FLAGS.enbl_multi_gpu:
                        optimizer = mgw.DistributedOptimizer(optimizer)
                    grads = optimizer.compute_gradients(
                        loss, self.trainable_vars)

            # TF operations & model saver
            if is_train:
                self.sess_train = sess
                with tf.control_dependencies(self.update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)
                self.summary_op = tf.summary.merge_all()
                self.log_op = [lrn_rate, loss] + list(metrics.values())
                self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
                self.init_op = tf.variables_initializer(self.vars)
                if FLAGS.enbl_multi_gpu:
                    self.bcast_op = mgw.broadcast_global_variables(0)
                self.saver_train = tf.train.Saver(self.vars)
            else:
                self.sess_eval = sess

                # the network output can overshoot the [0, 255] pixel range, and
                # out-of-range float -> uint8 casts are not guaranteed to saturate,
                # so clip before converting to an image
                outputs_uint8 = tf.cast(tf.clip_by_value(logits, 0.0, 255.0),
                                        tf.uint8)
                self.factory_op = [outputs_uint8]
                self.time_op = [logits]
                self.out_op = [
                    tf.cast(images, tf.uint8),
                    outputs_uint8,
                    tf.cast(labels, tf.uint8)
                ]
                self.eval_op = [loss] + list(metrics.values())
                self.eval_op_names = ['loss'] + list(metrics.keys())
                self.saver_eval = tf.train.Saver(self.vars)

    def __save_model(self, is_train):
        """Save the model to checkpoint files for training or evaluation.

        Args:
        * is_train: whether to save a model for training
        """

        if is_train:
            save_path = self.saver_train.save(self.sess_train, FLAGS.save_path,
                                              self.global_step)
        else:
            save_path = self.saver_eval.save(self.sess_eval,
                                             FLAGS.save_path_eval)
        tf.logging.info('model saved to ' + save_path)

    def __restore_model(self, is_train):
        """Restore a model from the latest checkpoint files.

        Both sessions restore from the directory of FLAGS.save_path (the
        training checkpoints).

        Args:
        * is_train: whether to restore a model for training
        """

        save_path = tf.train.latest_checkpoint(os.path.dirname(
            FLAGS.save_path))
        if is_train:
            self.saver_train.restore(self.sess_train, save_path)
        else:
            self.saver_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, idx_iter, time_step):
        """Monitor the training progress.

        Args:
        * summary: summary protocol buffer
        * log_rslt: logging operations' results
        * idx_iter: index of the training iteration
        * time_step: time step between two summary operations
        """

        # write summaries for TensorBoard visualization
        self.sm_writer.add_summary(summary, idx_iter)

        # compute the training speed
        speed = FLAGS.batch_size * FLAGS.summ_step / time_step
        if FLAGS.enbl_multi_gpu:
            speed *= mgw.size()

        # display monitored statistics
        log_str = ' | '.join([
            '%s = %.4e' % (name, value)
            for name, value in zip(self.log_op_names, log_rslt)
        ])
        tf.logging.info('iter #%d: %s | speed = %.2f pics / sec' %
                        (idx_iter + 1, log_str, speed))
예제 #11
0
class UniformQuantLearner(AbstractLearner):
    # pylint: disable=too-many-instance-attributes
    """
    Uniform quantization for weights and activations
    """
    def __init__(self, sm_writer, model_helper):
        """Build the train/eval graphs, then search for the optimal bit-widths.

        Args:
        * sm_writer: TensorFlow's summary writer
        * model_helper: model helper with definitions of model & dataset
        """
        super(UniformQuantLearner, self).__init__(sm_writer, model_helper)

        # distillation helper, only when distillation is enabled
        if FLAGS.enbl_dst:
            self.helper_dst = DistillationHelper(
                sm_writer, model_helper, self.mpi_comm)

        # containers populated while the graphs are built
        self.ops, self.bit_placeholders, self.statistics = {}, {}, {}

        self.__build_train()
        self.__build_eval()

        # a pre-trained model is required; only the local primary worker fetches it
        if self.is_primary_worker('local'):
            self.download_model()
        self.auto_barrier()

        # search for the optimal weight / activation bit-width lists
        optimizer_args = (self.dataset_name, self.weights, self.statistics,
                          self.bit_placeholders, self.ops,
                          self.layerwise_tune_list, self.sess_train,
                          self.sess_eval, self.saver_train, self.saver_eval,
                          self.auto_barrier)
        bit_optimizer = BitOptimizer(*optimizer_args)
        self.optimal_w_bit_list, self.optimal_a_bit_list = bit_optimizer.run()
        self.auto_barrier()

    def train(self):
        """Fine-tune the quantized model under the optimal bit-width policy.

        The model is saved and evaluated every FLAGS.save_step iterations, and
        once more after the last iteration.
        """
        self.sess_train.run(self.ops['init'])

        nb_iters = self.finetune_steps
        if FLAGS.enbl_warm_start:
            # warm start from the latest checkpoint
            self.__restore_model(is_train=True)

        self.auto_barrier()
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.ops['bcast'])

        time_prev = timer()
        # the optimal bit-widths are fed into the training graph at every step
        feed_dict = {
            self.bit_placeholders['w_train']: self.optimal_w_bit_list,
            self.bit_placeholders['a_train']: self.optimal_a_bit_list
        }

        for idx_iter in range(nb_iters):
            step = idx_iter + 1
            if step % FLAGS.summ_step == 0:
                # training step that also fetches summaries and logged values
                fetches = [self.ops['train'], self.ops['summary'], self.ops['log']]
                _, summary, log_rslt = self.sess_train.run(fetches,
                                                           feed_dict=feed_dict)
                time_prev = self.__monitor_progress(summary, log_rslt,
                                                    time_prev, idx_iter)
            else:
                self.sess_train.run(self.ops['train'], feed_dict=feed_dict)

            # periodic checkpoint + evaluation
            if step % FLAGS.save_step == 0:
                self.__save_model()
                self.evaluate()
                self.auto_barrier()

        # final checkpoint + evaluation
        self.__save_model()
        self.evaluate()

    def evaluate(self):
        """Evaluate the latest checkpoint under the optimal bit-width policy.

        Only the primary worker evaluates; other workers return immediately.
        Logs the mean loss, the mean accuracy (except for the pose dataset), and
        the optimal weight bit-widths. It then optionally reports bucket storage
        and, for the pose dataset, computes mAP.
        """
        if not self.is_primary_worker():
            return
        is_openpose = self.dataset_name == 'coco2017-pose'

        self.__restore_model(is_train=False)
        nb_batches = int(
            np.ceil(float(FLAGS.nb_smpls_eval) / FLAGS.batch_size_eval))

        # the optimal bit-widths are fed into the evaluation graph
        feed_dict = {
            self.bit_placeholders['w_eval']: self.optimal_w_bit_list,
            self.bit_placeholders['a_eval']: self.optimal_a_bit_list
        }

        batch_rslts = [
            self.sess_eval.run(self.ops['eval'], feed_dict=feed_dict)
            for _ in range(nb_batches)
        ]
        losses = [rslt[0] for rslt in batch_rslts]
        accuracies = [rslt[1] for rslt in batch_rslts]

        tf.logging.info('loss: {}'.format(np.mean(np.array(losses))))
        if not is_openpose:
            tf.logging.info('accuracy: {}'.format(
                np.mean(np.array(accuracies))))
        tf.logging.info("Optimal Weight Quantization:{}".format(
            self.optimal_w_bit_list))

        if FLAGS.uql_use_buckets:
            bucket_storage = self.sess_eval.run(self.ops['bucket_storage'],
                                                feed_dict=feed_dict)
            self.__show_bucket_storage(bucket_storage)

        if not (is_openpose and FLAGS.calculate_map):
            return

        from examples.openpose_eval_helper import calculate_map
        graph = self.sess_eval.graph
        tensor_image = graph.get_tensor_by_name('model/MobilenetV2/input:0')
        tensor_output = graph.get_tensor_by_name(
            'model/Openpose/concat_stage7:0')

        def run_openpose(img):
            """Run the quantized pose network on <img> and return its output."""
            return self.sess_eval.run([tensor_output],
                                      feed_dict={
                                          tensor_image: img,
                                          **feed_dict
                                      })[0]

        calculate_map(run_openpose)

    def __build_train(self):
        """Build the training graph, session, savers and ops for quantized fine-tuning.

        Populates (among others) self.sess_train, self.weights,
        self.statistics['num_weights'], self.saver_train, self.ft_step,
        self.finetune_steps, self.saver_quant and the self.ops entries 'train',
        'summary', 'log', 'reset_ft_step', 'init' and 'bcast'.

        Raises:
        * ValueError: if self.dataset_name is not one of 'cifar_10',
          'ilsvrc_12' or 'coco2017-pose'
        """
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_train = tf.Session(config=config)

            # data input pipeline
            # (the batch dimension is pinned to FLAGS.batch_size)
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size, images.shape[1],
                                  images.shape[2], images.shape[3]))

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_train, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_train(images)

                # weight tensors to quantize: trainable variables whose name
                # contains 'kernel' or 'weight'
                self.weights = [
                    v for v in self.trainable_vars
                    if 'kernel' in v.name or 'weight' in v.name
                ]
                # drop the first and last entries (presumably the first and last
                # layers, assuming trainable_vars is in layer order — TODO confirm)
                if not FLAGS.uql_quantize_all_layers:
                    self.weights = self.weights[1:-1]
                # static element count of each selected weight tensor
                self.statistics['num_weights'] = \
                    [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

                self.__quantize_train_graph()

                # loss & accuracy
                # (the pose dataset has no accuracy; acc_top1 / acc_top5 are only
                # defined, and only used, for the classification datasets)
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                elif self.dataset_name == 'coco2017-pose':
                    total_loss = metrics['total_loss_all_layers']
                    total_loss_ll_paf = metrics['total_loss_last_layer_paf']
                    total_loss_ll_heat = metrics['total_loss_last_layer_heat']
                    total_loss_ll = metrics['total_loss_last_layer']
                else:
                    raise ValueError("Unrecognized dataset name")

                # <model_loss> keeps the task loss; `loss += ...` rebinds <loss>
                # to a new tensor that also includes the distillation term
                model_loss = loss
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss
                    tf.summary.scalar('dst_loss', dst_loss)
                tf.summary.scalar('model_loss', model_loss)
                tf.summary.scalar('loss', loss)
                if self.dataset_name == 'coco2017-pose':
                    tf.summary.scalar('total_loss_all_layers', total_loss)
                    tf.summary.scalar('total_loss_last_layer_paf',
                                      total_loss_ll_paf)
                    tf.summary.scalar('total_loss_last_layer_heat',
                                      total_loss_ll_heat)
                    tf.summary.scalar('total_loss_last_layer', total_loss_ll)
                else:
                    tf.summary.scalar('acc_top1', acc_top1)
                    tf.summary.scalar('acc_top5', acc_top5)

                # created BEFORE <finetune_step> and the optimizer variables exist,
                # so it cannot include them (exact contents depend on self.vars,
                # defined elsewhere)
                self.saver_train = tf.train.Saver(self.vars)

                # non-trainable step counter driving the learning-rate schedule
                self.ft_step = tf.get_variable('finetune_step',
                                               shape=[],
                                               dtype=tf.int32,
                                               trainable=False)

            # optimizer & gradients
            # piecewise-constant schedule: init_lr scaled by each decay rate
            init_lr, bnds, decay_rates, self.finetune_steps = setup_bnds_decay_rates(
                self.model_name, self.dataset_name)
            lrn_rate = tf.train.piecewise_constant(
                self.ft_step, [i for i in bnds],
                [init_lr * decay_rate for decay_rate in decay_rates])

            # optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
            optimizer = tf.train.AdamOptimizer(learning_rate=lrn_rate)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)
            grads = optimizer.compute_gradients(loss, self.trainable_vars)

            # sm write graph
            self.sm_writer.add_graph(self.sess_train.graph)

            # update ops run before every parameter update; <ft_step> is
            # incremented by apply_gradients
            with tf.control_dependencies(self.update_ops):
                self.ops['train'] = optimizer.apply_gradients(
                    grads, global_step=self.ft_step)
            self.ops['summary'] = tf.summary.merge_all()

            # tensors fetched for console logging (order must match the log names
            # used by the progress monitor, defined elsewhere)
            if FLAGS.enbl_dst:
                self.ops['log'] = [
                    lrn_rate,
                    dst_loss,
                    model_loss,
                    loss,
                ]
            else:
                self.ops['log'] = [
                    lrn_rate,
                    model_loss,
                    loss,
                ]

            if self.dataset_name == 'coco2017-pose':
                self.ops['log'] += [
                    total_loss, total_loss_ll_paf, total_loss_ll_heat,
                    total_loss_ll
                ]
            else:
                self.ops['log'] += [acc_top1, acc_top5]

            self.ops['reset_ft_step'] = tf.assign(
                self.ft_step, tf.constant(0, dtype=tf.int32))
            self.ops['init'] = tf.global_variables_initializer()
            self.ops['bcast'] = mgw.broadcast_global_variables(
                0) if FLAGS.enbl_multi_gpu else None
            # created after all variables of this graph exist
            self.saver_quant = tf.train.Saver(self.vars)

    def __build_eval(self):
        """Build the quantized evaluation graph, its session, ops and saver.

        Everything is created inside a fresh `tf.Graph`:
        * `self.sess_eval`: a session bound to that graph;
        * `self.images_eval`: the evaluation input images tensor;
        * quantization nodes, inserted via `self.__quantize_eval_graph()`;
        * `self.ops['eval']`: `[loss, <dataset-specific metrics>]`;
        * `self.saver_eval`: a `tf.train.Saver` over `self.vars`.

        Raises:
        * ValueError: if `self.dataset_name` is not one of 'cifar_10',
          'ilsvrc_12' or 'coco2017-pose'.
        """
        with tf.Graph().as_default():
            # TensorFlow session
            # create a TF session for the current graph; visible devices are
            # restricted to mgw.local_rank() under multi-GPU, otherwise device 0
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()
                # pin the static batch dimension to the *evaluation* batch size;
                # the tensor is rank-4 (layout, e.g. NHWC, is an assumption —
                # TODO confirm against build_dataset_eval)
                images.set_shape((FLAGS.batch_size_eval, images.shape[1],
                                  images.shape[2], images.shape[3]))
                self.images_eval = images

            # model definition - distilled model
            # logits from the distillation helper, computed only when
            # distillation is enabled (they feed helper_dst.calc_loss below)
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_eval(images)

                # insert quantization nodes; this runs after the forward pass
                # because it searches the graph for the ops to quantize
                self.__quantize_eval_graph()

                # loss & accuracy
                # NOTE(review): self.trainable_vars and self.vars are defined
                # outside this view; presumably they resolve against the current
                # (evaluation) default graph — confirm
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                # unpack dataset-specific metrics; cifar_10 has no top-5
                # accuracy, so a constant 0 stands in for it
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                elif self.dataset_name == 'coco2017-pose':
                    total_loss = metrics['total_loss_all_layers']
                    total_loss_ll_paf = metrics['total_loss_last_layer_paf']
                    total_loss_ll_heat = metrics['total_loss_last_layer_heat']
                    total_loss_ll = metrics['total_loss_last_layer']
                else:
                    raise ValueError("Unrecognized dataset name")

                # add the distillation loss on top of the model loss
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss

                # TF operations & model saver
                # self.ops['eval'] is [loss, <dataset-specific metrics...>]
                if self.dataset_name == 'coco2017-pose':
                    self.ops['eval'] = [
                        loss, total_loss, total_loss_ll_paf,
                        total_loss_ll_heat, total_loss_ll
                    ]
                else:
                    self.ops['eval'] = [loss, acc_top1, acc_top5]
                self.saver_eval = tf.train.Saver(self.vars)

    def __quantize_train_graph(self):
        """Insert uniform-quantization nodes into the training graph.

        Side effects:
        * `self.statistics['nb_matmuls']` / `['nb_activations']` are set to the
          number of matmul and activation ops found in the graph;
        * `self.bit_placeholders['w_train']` / `['a_train']` become int64
          placeholders holding one bit-width per matmul / activation op;
        * quantization ops are inserted for weights and activations;
        * `self.layerwise_tune_list` holds the layer-wise tuning ops when
          FLAGS.uql_enbl_rl_layerwise_tune is set, else `(None, None)`.
        """
        quantizer = UniformQuantization(self.sess_train, FLAGS.uql_bucket_size,
                                        FLAGS.uql_use_buckets,
                                        FLAGS.uql_bucket_type)

        # locate the weight (matmul) ops and the activation ops to quantize
        weight_ops = quantizer.search_matmul_op(FLAGS.uql_quantize_all_layers)
        activation_ops = quantizer.search_activation_op()

        stats = self.statistics
        stats['nb_matmuls'] = len(weight_ops)
        stats['nb_activations'] = len(activation_ops)

        # bit-width vectors, one entry per op, fed at run time
        placeholders = self.bit_placeholders
        placeholders['w_train'] = tf.placeholder(
            tf.int64, shape=[stats['nb_matmuls']], name="w_bit_list")
        placeholders['a_train'] = tf.placeholder(
            tf.int64, shape=[stats['nb_activations']], name="a_bit_list")

        # bind every op name to its slot in the matching bit-width vector
        weight_bits = self.__build_quant_dict(
            [op.name for op in weight_ops], placeholders['w_train'])
        activation_bits = self.__build_quant_dict(
            [op.name for op in activation_ops], placeholders['a_train'])

        quantizer.insert_quant_op_for_weights(weight_bits)
        quantizer.insert_quant_op_for_activations(activation_bits)

        # optional layer-wise fine-tuning (TODO: reportedly does not work well)
        if FLAGS.uql_enbl_rl_layerwise_tune:
            self.layerwise_tune_list = quantizer.get_layerwise_tune_op(
                self.weights)
        else:
            self.layerwise_tune_list = (None, None)

    def __quantize_eval_graph(self):
        """Insert uniform-quantization nodes into the evaluation graph.

        Must run after the training graph was quantized: the op counts recorded
        in `self.statistics` by that pass size the placeholders below, so the
        evaluation graph must expose the same number of ops.

        Side effects:
        * `self.bit_placeholders['w_eval']` / `['a_eval']` become int64
          placeholders holding one bit-width per matmul / activation op;
        * quantization ops are inserted for weights and activations;
        * `self.ops['bucket_storage']` is set to the quantizer's bucket storage.

        Raises:
        * RuntimeError: if the number of matmul or activation ops differs from
          the training graph.
        """
        uni_quant = UniformQuantization(self.sess_eval, FLAGS.uql_bucket_size,
                                        FLAGS.uql_use_buckets,
                                        FLAGS.uql_bucket_type)
        # find the weight (matmul) and activation ops to quantize
        matmul_ops = uni_quant.search_matmul_op(FLAGS.uql_quantize_all_layers)
        act_ops = uni_quant.search_activation_op()

        # explicit raises instead of `assert`: the check must not disappear
        # under `python -O`, otherwise a mismatch surfaces later as a confusing
        # placeholder-shape / indexing error
        if self.statistics['nb_matmuls'] != len(matmul_ops):
            raise RuntimeError(
                'the length of matmul_ops on train and eval graphs does not match')
        if self.statistics['nb_activations'] != len(act_ops):
            raise RuntimeError(
                'the length of act_ops on train and eval graphs does not match')

        matmul_op_names = [op.name for op in matmul_ops]
        act_op_names = [op.name for op in act_ops]

        # build the bit-width placeholders for eval (one entry per op)
        self.bit_placeholders['w_eval'] = tf.placeholder(
            tf.int64, shape=[self.statistics['nb_matmuls']], name="w_bit_list")
        self.bit_placeholders['a_eval'] = tf.placeholder(
            tf.int64,
            shape=[self.statistics['nb_activations']],
            name="a_bit_list")

        # bind every op name to its slot in the matching bit-width vector
        w_bit_dict_eval = self.__build_quant_dict(
            matmul_op_names, self.bit_placeholders['w_eval'])
        a_bit_dict_eval = self.__build_quant_dict(
            act_op_names, self.bit_placeholders['a_eval'])

        uni_quant.insert_quant_op_for_weights(w_bit_dict_eval)
        uni_quant.insert_quant_op_for_activations(a_bit_dict_eval)

        self.ops['bucket_storage'] = uni_quant.bucket_storage

    def __save_model(self):
        """Checkpoint the quantized training session (primary worker only).

        The checkpoint is written with `self.saver_quant` to
        FLAGS.uql_save_quant_model_path, using `self.ft_step` as the global
        step; other workers return without doing anything.
        """
        if self.is_primary_worker():
            ckpt_path = self.saver_quant.save(
                self.sess_train, FLAGS.uql_save_quant_model_path, self.ft_step)
            tf.logging.info('quantized model saved to ' + ckpt_path)

    def __restore_model(self, is_train):
        """Restore the latest checkpoint into the training or evaluation session.

        Args:
        * is_train: if True, restore `self.sess_train` with `self.saver_train`
          from the directory of FLAGS.save_path (and print that checkpoint's
          directory listing); otherwise restore `self.sess_eval` with
          `self.saver_eval` from the directory of
          FLAGS.uql_save_quant_model_path.

        Raises:
        * ValueError: if no checkpoint is found in the selected directory.
        """
        if is_train:
            ckpt_dir = os.path.dirname(FLAGS.save_path)
        else:
            ckpt_dir = os.path.dirname(FLAGS.uql_save_quant_model_path)

        # latest_checkpoint returns None when nothing is found; fail with a
        # clear message instead of a TypeError from os.path.dirname(None)
        save_path = tf.train.latest_checkpoint(ckpt_dir)
        if save_path is None:
            raise ValueError('no checkpoint found in %r' % ckpt_dir)

        if is_train:
            # list the checkpoint's directory to help diagnose restore issues
            save_dir = os.path.dirname(save_path)
            for item in os.listdir(save_dir):
                print('Print directory: ' + item)
            self.saver_train.restore(self.sess_train, save_path)
        else:
            self.saver_eval.restore(self.sess_eval, save_path)
        tf.logging.info('model restored from ' + save_path)

    def __monitor_progress(self, summary, log_rslt, time_prev, idx_iter):
        """Write TensorBoard summaries and log training statistics.

        Only the primary worker reports; every other worker gets None back.

        Args:
        * summary: serialized summary to hand to `self.sm_writer`
        * log_rslt: evaluated logging values, ordered as
          [lr, (dst_loss,) model_loss, loss, <dataset metrics...>] — presumably
          the result of running self.ops['log']
        * time_prev: `timer()` timestamp of the previous report
        * idx_iter: zero-based index of the current iteration

        Returns:
        * a fresh `timer()` timestamp on the primary worker, else None
        """
        if self.is_primary_worker():
            self.sm_writer.add_summary(summary, idx_iter)

            # throughput over the last FLAGS.summ_step iterations, scaled by
            # the number of workers under multi-GPU training
            nb_images = FLAGS.batch_size * FLAGS.summ_step
            speed = nb_images / (timer() - time_prev)
            speed = speed * mgw.size() if FLAGS.enbl_multi_gpu else speed

            step = idx_iter + 1
            is_pose = self.dataset_name == 'coco2017-pose'
            # NOTE: for cifar-10, acc_top5 is 0.
            if is_pose and FLAGS.enbl_dst:
                (lr, dst, mdl, tot, _all_layers, paf, heat,
                 last) = log_rslt[:8]
                tf.logging.info(
                    'iter #%d: lr = %e | dst_loss = %.4f | model_loss = %.4f | loss = %.4f | ll_paf = %.4f | ll_heat = %.4f | ll = %.4f | speed = %.2f pics / sec'
                    % (step, lr, dst, mdl, tot, paf, heat, last, speed))
            elif is_pose:
                (lr, mdl, tot, _all_layers, paf, heat, last) = log_rslt[:7]
                tf.logging.info(
                    'iter #%d: lr = %e | model_loss = %.4f | loss = %.4f | ll_paf = %.4f | ll_heat = %.4f | ll = %.4f | speed = %.2f pics / sec'
                    % (step, lr, mdl, tot, paf, heat, last, speed))
            elif FLAGS.enbl_dst:
                (lr, dst, mdl, tot, top1, top5) = log_rslt[:6]
                tf.logging.info(
                    'iter #%d: lr = %e | dst_loss = %.4f | model_loss = %.4f | loss = %.4f | acc_top1 = %.4f | acc_top5 = %.4f | speed = %.2f pics / sec'
                    % (step, lr, dst, mdl, tot, top1, top5, speed))
            else:
                (lr, mdl, tot, top1, top5) = log_rslt[:5]
                tf.logging.info(
                    'iter #%d: lr = %e | model_loss = %.4f | loss = %.4f | acc_top1 = %.4f | acc_top5 = %.4f | speed = %.2f pics / sec'
                    % (step, lr, mdl, tot, top1, top5, speed))

            return timer()
        return None

    def __show_bucket_storage(self, bucket_storage):
        """Log the bucket storage cost relative to the quantized weight storage.

        Args:
        * bucket_storage: bucket storage size in bits (logged as "bit" and
          converted to kb by dividing by 8 * 1024)

        The weight storage is sum(self.statistics['num_weights']) times the
        per-weight bit-width: FLAGS.uql_equivalent_bits when the RL agent is
        enabled, FLAGS.uql_weight_bits otherwise. When the weight storage is 0,
        the ratio is logged as nan instead of raising ZeroDivisionError.
        """
        if FLAGS.uql_enbl_rl_agent:
            bits_per_weight = FLAGS.uql_equivalent_bits
        else:
            bits_per_weight = FLAGS.uql_weight_bits
        weight_storage = sum(self.statistics['num_weights']) * bits_per_weight

        # a logging helper must not crash training when no weights were counted
        ratio = (bucket_storage * 1. / weight_storage
                 if weight_storage else float('nan'))
        bits_per_kb = 8. * 1024.
        tf.logging.info(
            'bucket storage: %d bit / %.3f kb | weight storage: %d bit / %.3f kb | ratio: %.3f'
            % (bucket_storage, bucket_storage / bits_per_kb, weight_storage,
               weight_storage / bits_per_kb, ratio))

    @staticmethod
    def __build_quant_dict(keys, values):
        """Pair every op name in `keys` with the same-position entry of `values`.

        Args:
        * keys: a list of op names
        * values: an indexable with at least len(keys) elements (in this file,
          a 1-D bit-width placeholder Tensor)

        Returns:
        * dict: {op_name: values[i]} where i is op_name's position in keys; a
          repeated name keeps the entry of its last occurrence
        """
        # integer indexing rather than zip(): iterating a graph-mode Tensor may
        # be unsupported depending on the TF version, while indexing works
        return {name: values[pos] for pos, name in enumerate(keys)}