Example #1
    def __build_eval(self):
        """Build the evaluation graph."""

        with tf.Graph().as_default() as graph:
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - uniform quantized model - part 1
            with tf.variable_scope(self.model_scope_quan):
                logits = self.forward_eval(images)
                if not isinstance(logits, dict):
                    outputs = tf.nn.softmax(logits)
                else:
                    outputs = tf.nn.softmax(logits['cls_pred'])
                tf.contrib.quantize.experimental_create_eval_graph(
                    weight_bits=FLAGS.uqtf_weight_bits,
                    activation_bits=FLAGS.uqtf_activation_bits,
                    scope=self.model_scope_quan)
                for node_name in self.unquant_node_names:
                    insert_quant_op(graph, node_name, is_train=False)
                vars_quan = get_vars_by_scope(self.model_scope_quan)

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition - uniform quantized model - part 2
            with tf.variable_scope(self.model_scope_quan):
                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits,
                                               vars_quan['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # TF operations for evaluation
                vars_quan = get_vars_by_scope(self.model_scope_quan)
                self.eval_op = [loss] + list(metrics.values())
                self.eval_op_names = ['loss'] + list(metrics.keys())
                self.outputs_eval = logits
                self.saver_quan_eval = tf.train.Saver(vars_quan['all'])

            # add input & output tensors to certain collections
            if not isinstance(images, dict):
                tf.add_to_collection('images_final', images)
            else:
                tf.add_to_collection('images_final', images['image'])
            if not isinstance(logits, dict):
                tf.add_to_collection('logits_final', logits)
            else:
                tf.add_to_collection('logits_final', logits['cls_pred'])
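A minimal usage sketch of how the evaluation ops built above are typically consumed (the checkpoint flag, batch count, and `np` import are assumptions, not part of the original snippet):

    # hypothetical driver: restore the quantized model, then average metrics over batches
    import numpy as np
    self.saver_quan_eval.restore(self.sess_eval, FLAGS.uqtf_save_path_eval)  # assumed flag
    results = [self.sess_eval.run(self.eval_op) for __ in range(100)]  # assumed 100 batches
    for name, value in zip(self.eval_op_names, np.mean(np.array(results), axis=0)):
        tf.logging.info('%s = %.4e' % (name, value))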
Example #2
    def __build_eval(self, model_helper):
        """Build the evaluation graph for the 'optimal' protocol.

    Args:
    * model_helper: model helper with definitions of model & dataset
    """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                __, iterator = model_helper.build_dataset_train(
                    enbl_trn_val_split=True)
                images, labels = iterator.get_next()

            # model definition - weight sparsified network
            with tf.variable_scope(self.model_scope_prnd):
                logits = model_helper.forward_eval(images)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.loss_eval, self.metrics_eval = \
                    model_helper.calc_loss(labels, logits, vars_prnd['trainable'])
                self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'])
Example #3
    def __build_eval(self):
        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size_eval, images.shape[1],
                                  images.shape[2], images.shape[3]))
                self.images_eval = images

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_eval(images)

                self.__quantize_eval_graph()

                # loss & accuracy
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                elif self.dataset_name == 'coco2017-pose':
                    total_loss = metrics['total_loss_all_layers']
                    total_loss_ll_paf = metrics['total_loss_last_layer_paf']
                    total_loss_ll_heat = metrics['total_loss_last_layer_heat']
                    total_loss_ll = metrics['total_loss_last_layer']
                else:
                    raise ValueError("Unrecognized dataset name")

                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss

                # TF operations & model saver
                if self.dataset_name == 'coco2017-pose':
                    self.ops['eval'] = [
                        loss, total_loss, total_loss_ll_paf,
                        total_loss_ll_heat, total_loss_ll
                    ]
                else:
                    self.ops['eval'] = [loss, acc_top1, acc_top5]
                self.saver_eval = tf.train.Saver(self.vars)
Example #4
    def __build_eval(self):
        """Build the evaluation graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                # loss & extra evaluation metrics
                logits = self.forward_eval(images)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                loss, metrics = self.calc_loss(labels, logits,
                                               vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
                pr_maskable = calc_prune_ratio(vars_prnd['maskable'])

                # TF operations for evaluation
                self.factory_op = [tf.cast(logits, tf.uint8)]
                self.time_op = [logits]
                self.out_op = [
                    tf.cast(images, tf.uint8),
                    tf.cast(logits, tf.uint8),
                    tf.cast(labels, tf.uint8)
                ]
                self.eval_op = [loss, pr_trainable, pr_maskable] + list(
                    metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list(
                    metrics.keys())
                self.saver_prnd_eval = tf.train.Saver(vars_prnd['all'])

            # add input & output tensors to certain collections
            tf.add_to_collection('images_final', images)
            tf.add_to_collection('logits_final', logits)
Example #5
File: utils.py Project: rhkdqo93/AC
def create_session():
  """Create a TensorFlow session.

  Return:
  * sess: TensorFlow session
  """

  # create a TensorFlow session
  config = tf.ConfigProto()
  config.gpu_options.visible_device_list = str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
  config.gpu_options.allow_growth = True  # pylint: disable=no-member
  sess = tf.Session(config=config)

  return sess
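A minimal usage sketch (the toy graph is an assumption; `create_session` itself only pins GPU visibility and enables memory growth):

  with tf.Graph().as_default():
    x = tf.get_variable('x', shape=[], initializer=tf.zeros_initializer())
    sess = create_session()
    sess.run(tf.global_variables_initializer())
    print(sess.run(x))  # 0.0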
Example #6
File: misc_utils.py Project: rhkdqo93/AC
def is_primary_worker(scope='global'):
    """Check whether is the primary worker of all nodes (global) or the current node (local).

  Args:
  * scope: check scope ('global' OR 'local')

  Returns:
  * flag: whether is the primary worker
  """

    if scope == 'global':
        return True if not FLAGS.enbl_multi_gpu else mgw.rank() == 0
    elif scope == 'local':
        return True if not FLAGS.enbl_multi_gpu else mgw.local_rank() == 0
    else:
        raise ValueError('unrecognized worker scope: ' + scope)
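A typical usage sketch: gate side effects so they run on exactly one worker (the `saver`, `sess`, and dataset-preparation helper are hypothetical):

    if is_primary_worker('global'):
        saver.save(sess, FLAGS.save_path)  # only the global rank-0 worker writes checkpoints
    if is_primary_worker('local'):
        prepare_local_dataset()  # hypothetical: once per node, shared by that node's workers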
Example #7
    def __build_eval(self):
        """Build the evaluation graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            if FLAGS.enbl_multi_gpu:
                config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
            else:
                config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
            self.sess_eval = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_eval()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_eval, images)

            # model definition - weight-sparsified model
            with tf.variable_scope(self.model_scope):
                # loss & extra evaluation metrics
                logits = self.forward_eval(images)
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.trainable_vars)
                pr_maskable = calc_prune_ratio(self.maskable_vars)

                # TF operations for evaluation
                self.eval_op = [loss, pr_trainable, pr_maskable] + list(
                    metrics.values())
                self.eval_op_names = ['loss', 'pr_trn', 'pr_msk'] + list(
                    metrics.keys())
                self.saver_eval = tf.train.Saver(self.vars)
Example #8
    def __build_pruned_evaluate_model(self, path=None):
        ''' build an evaluation model from the pruned model '''
        # early break for non-primary workers
        if not self.__is_primary_worker():
            return

        if path is None:
            path = FLAGS.save_path

        if not tf.train.checkpoint_exists(path):
            return

        with tf.Graph().as_default():
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)
            self.saver_eval = tf.train.import_meta_graph(path + '.meta')
            self.saver_eval.restore(self.sess_eval, path)
            eval_logits = tf.get_collection('logits')[0]
            tf.add_to_collection('logits_final', eval_logits)
            eval_images = tf.get_collection('eval_images')[0]
            tf.add_to_collection('images_final', eval_images)
            eval_labels = tf.get_collection('eval_labels')[0]
            mem_images = tf.get_collection('mem_images')[0]
            mem_labels = tf.get_collection('mem_labels')[0]

            self.sess_eval.close()

            graph_editor.reroute_ts(eval_images, mem_images)
            graph_editor.reroute_ts(eval_labels, mem_labels)

            self.sess_eval = tf.Session(config=config)
            self.saver_eval.restore(self.sess_eval, path)
            trainable_vars = self.trainable_vars
            loss, accuracy = self.calc_loss(eval_labels, eval_logits,
                                            trainable_vars)
            self.eval_op = [loss] + list(accuracy.values())
            self.sm_writer.add_graph(self.sess_eval.graph)
Example #9
    def __build_minimal(self, model_helper):
        """Build the minimal graph for 'uniform' & 'heurist' protocols.

    Args:
    * model_helper: model helper with definitions of model & dataset
    """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            self.sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = model_helper.build_dataset_train()
                images, __ = iterator.get_next()

            # model definition - full-precision network
            with tf.variable_scope(self.model_scope_full):
                __ = model_helper.forward_eval(
                    images)  # DO NOT USE forward_train() HERE!!!
                self.vars_full = get_vars_by_scope(self.model_scope_full)
Example #10
    def __build_train(self):
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(mgw.local_rank() \
                if FLAGS.enbl_multi_gpu else 0)
            self.sess_train = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size, images.shape[1],
                                  images.shape[2], images.shape[3]))

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_train, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_train(images)

                self.weights = [
                    v for v in self.trainable_vars
                    if 'kernel' in v.name or 'weight' in v.name
                ]
                if not FLAGS.uql_quantize_all_layers:
                    self.weights = self.weights[1:-1]
                self.statistics['num_weights'] = \
                    [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

                self.__quantize_train_graph()

                # loss & accuracy
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                else:
                    raise ValueError("Unrecognized dataset name")

                model_loss = loss
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss
                    tf.summary.scalar('dst_loss', dst_loss)
                tf.summary.scalar('model_loss', model_loss)
                tf.summary.scalar('loss', loss)
                tf.summary.scalar('acc_top1', acc_top1)
                tf.summary.scalar('acc_top5', acc_top5)

                self.saver_train = tf.train.Saver(self.vars)

                self.ft_step = tf.get_variable('finetune_step',
                                               shape=[],
                                               dtype=tf.int32,
                                               trainable=False)

            # optimizer & gradients
            init_lr, bnds, decay_rates, self.finetune_steps = \
                setup_bnds_decay_rates(self.model_name, self.dataset_name)
            lrn_rate = tf.train.piecewise_constant(
                self.ft_step, [i for i in bnds],
                [init_lr * decay_rate for decay_rate in decay_rates])

            # optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
            optimizer = tf.train.AdamOptimizer(learning_rate=lrn_rate)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)
            grads = optimizer.compute_gradients(loss, self.trainable_vars)

            # sm write graph
            self.sm_writer.add_graph(self.sess_train.graph)

            with tf.control_dependencies(self.update_ops):
                self.ops['train'] = optimizer.apply_gradients(
                    grads, global_step=self.ft_step)
            self.ops['summary'] = tf.summary.merge_all()

            if FLAGS.enbl_dst:
                self.ops['log'] = [
                    lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5
                ]
            else:
                self.ops['log'] = [
                    lrn_rate, model_loss, loss, acc_top1, acc_top5
                ]

            self.ops['reset_ft_step'] = tf.assign(
                self.ft_step, tf.constant(0, dtype=tf.int32))
            self.ops['init'] = tf.global_variables_initializer()
            self.ops['bcast'] = mgw.broadcast_global_variables(
                0) if FLAGS.enbl_multi_gpu else None
            self.saver_quant = tf.train.Saver(self.vars)
Example #11
def export_tflite_model(input_coll, output_coll, images_shape, images_name):
    """Export a *.tflite model from checkpoint files.

  Args:
  * input_coll: input collection's name
  * output_coll: output collection's name
  * images_shape: input images' shape (used to create the placeholder)
  * images_name: name of the input images' tensor to be remapped

  Returns:
  * unquant_node_name: unquantized activation node name (None if not found)
  """

    # remove previously generated *.pb & *.tflite models
    model_dir = os.path.dirname(FLAGS.uqtf_save_path_probe_eval)
    idx_worker = mgw.local_rank() if FLAGS.enbl_multi_gpu else 0
    pb_path = os.path.join(model_dir, 'model_%d.pb' % idx_worker)
    tflite_path = os.path.join(model_dir, 'model_%d.tflite' % idx_worker)
    if os.path.exists(pb_path):
        os.remove(pb_path)
    if os.path.exists(tflite_path):
        os.remove(tflite_path)

    # convert checkpoint files to a *.pb model
    images_name_ph = 'images'
    with tf.Graph().as_default() as graph:
        # create a TensorFlow session
        sess = create_session()

        # restore the graph with inputs replaced
        ckpt_path = tf.train.latest_checkpoint(model_dir)
        meta_path = ckpt_path + '.meta'
        images = tf.placeholder(tf.float32,
                                shape=images_shape,
                                name=images_name_ph)
        saver = tf.train.import_meta_graph(meta_path,
                                           input_map={images_name: images})
        saver.restore(sess, ckpt_path)

        # obtain input & output nodes
        net_inputs = tf.get_collection(input_coll)
        net_logits = tf.get_collection(output_coll)[0]
        net_outputs = [tf.nn.softmax(net_logits)]
        for node in net_inputs:
            tf.logging.info('inputs: {} / {}'.format(node.name, node.shape))
        for node in net_outputs:
            tf.logging.info('outputs: {} / {}'.format(node.name, node.shape))

        # write the original graph to a *.pb file
        graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(),
            [node.name.replace(':0', '') for node in net_outputs])
        tf.train.write_graph(graph_def,
                             model_dir,
                             os.path.basename(pb_path),
                             as_text=False)
        assert os.path.exists(pb_path), 'failed to generate a *.pb model'

    # convert the *.pb model to a *.tflite model and detect the unquantized activation node (if any)
    tf.logging.info(pb_path + ' -> ' + tflite_path)
    converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
        pb_path, [images_name_ph],
        [node.name.replace(':0', '') for node in net_outputs])
    converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
    converter.quantized_input_stats = {images_name_ph: (0., 1.)}
    unquant_node_name = None
    try:
        tflite_model = converter.convert()
        with open(tflite_path, 'wb') as o_file:
            o_file.write(tflite_model)
    except Exception as err:
        err_msg = str(err)
        flag_str = 'tensorflow/contrib/lite/toco/tooling_util.cc:1634]'
        for sub_line in err_msg.split('\\n'):
            if flag_str in sub_line:
                sub_strs = sub_line.replace(',', ' ').split()
                unquant_node_name = sub_strs[sub_strs.index(flag_str) +
                                             2] + ':0'
                break
        assert unquant_node_name is not None, 'unable to locate the unquantized node'

    return unquant_node_name
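A hedged sketch of how this probe can be driven (the input shape, tensor name, and collection names are assumptions based on the surrounding examples):

    unquant_node_name = export_tflite_model(
        'images_final', 'logits_final',
        images_shape=[1, 224, 224, 3],  # assumed NHWC input shape
        images_name='images:0')         # assumed name of the original input tensor
    if unquant_node_name is not None:
        tf.logging.info('unquantized activation node: ' + unquant_node_name)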
Example #12
  def __build(self, is_train):  # pylint: disable=too-many-locals
    """Build the training / evaluation graph.

    Args:
    * is_train: whether to create the training graph
    """

    with tf.Graph().as_default():
      # TensorFlow session
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      sess = tf.Session(config=config)

      # data input pipeline
      with tf.variable_scope(self.data_scope):
        iterator = self.build_dataset_train() if is_train else self.build_dataset_eval()
        images, labels = iterator.get_next()
        tf.add_to_collection('images_final', images)

      # model definition - distilled model
      if self.enbl_dst:
        logits_dst = self.helper_dst.calc_logits(sess, images)

      # model definition - primary model
      with tf.variable_scope(self.model_scope):
        # forward pass
        logits = self.forward_train(images) if is_train else self.forward_eval(images)
        tf.add_to_collection('logits_final', logits)

        # loss & extra evalution metrics
        loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
        if self.enbl_dst:
          loss += self.helper_dst.calc_loss(logits, logits_dst)
        tf.summary.scalar('loss', loss)
        for key, value in metrics.items():
          tf.summary.scalar(key, value)

        # optimizer & gradients
        if is_train:
          self.global_step = tf.train.get_or_create_global_step()
          lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)
          optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
          if FLAGS.enbl_multi_gpu:
            optimizer = mgw.DistributedOptimizer(optimizer)
          grads = optimizer.compute_gradients(loss, self.trainable_vars)

      # TF operations & model saver
      if is_train:
        self.sess_train = sess
        with tf.control_dependencies(self.update_ops):
          self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
        self.summary_op = tf.summary.merge_all()
        self.log_op = [lrn_rate, loss] + list(metrics.values())
        self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
        self.init_op = tf.variables_initializer(self.vars)
        if FLAGS.enbl_multi_gpu:
          self.bcast_op = mgw.broadcast_global_variables(0)
        self.saver_train = tf.train.Saver(self.vars)
      else:
        self.sess_eval = sess
        self.eval_op = [loss] + list(metrics.values())
        self.eval_op_names = ['loss'] + list(metrics.keys())
        self.outputs_eval = logits
        self.saver_eval = tf.train.Saver(self.vars)
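A brief usage sketch (hypothetical call sites): the same builder creates either graph, each with its own session, ops, and saver:

      self.__build(is_train=True)   # populates sess_train / train_op / summary_op / saver_train
      self.__build(is_train=False)  # populates sess_eval / eval_op / eval_op_names / saver_eval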
Example #13
  def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
    """Build the training graph."""

    with tf.Graph().as_default():
      # create a TF session for the current graph
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      sess = tf.Session(config=config)

      # data input pipeline
      with tf.variable_scope(self.data_scope):
        iterator = self.build_dataset_train()
        images, labels = iterator.get_next()

      # model definition - distilled model
      if FLAGS.enbl_dst:
        logits_dst = self.helper_dst.calc_logits(sess, images)

      # model definition - full model
      with tf.variable_scope(self.model_scope_full):
        __ = self.forward_train(images)
        self.vars_full = get_vars_by_scope(self.model_scope_full)
        self.saver_full = tf.train.Saver(self.vars_full['all'])
        self.save_path_full = FLAGS.save_path

      # model definition - channel-pruned model
      with tf.variable_scope(self.model_scope_prnd):
        logits_prnd = self.forward_train(images)
        self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
        self.maskable_var_names = [var.name for var in self.vars_prnd['maskable']]
        self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])

        # loss & extra evaluation metrics
        loss_bsc, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
        if not FLAGS.enbl_dst:
          loss_fnl = loss_bsc
        else:
          loss_fnl = loss_bsc + self.helper_dst.calc_loss(logits_prnd, logits_dst)
        tf.summary.scalar('loss_bsc', loss_bsc)
        tf.summary.scalar('loss_fnl', loss_fnl)
        for key, value in metrics.items():
          tf.summary.scalar(key, value)

        # learning rate schedule
        self.global_step = tf.train.get_or_create_global_step()
        lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)

        # overall pruning ratios of trainable & maskable variables
        pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
        pr_maskable = calc_prune_ratio(self.vars_prnd['maskable'])
        tf.summary.scalar('pr_trainable', pr_trainable)
        tf.summary.scalar('pr_maskable', pr_maskable)

        # create masks and corresponding operations for channel pruning
        self.masks = []
        self.mask_deltas = []
        self.mask_init_ops = []
        self.mask_updt_ops = []
        self.prune_ops = []
        for idx, var in enumerate(self.vars_prnd['maskable']):
          name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
          self.masks += [tf.get_variable(name, initializer=tf.ones(var.shape), trainable=False)]
          name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask_delta')
          self.mask_deltas += [tf.placeholder(tf.float32, shape=var.shape, name=name)]
          self.mask_init_ops += [self.masks[idx].assign(tf.zeros(var.shape))]
          self.mask_updt_ops += [self.masks[idx].assign_add(self.mask_deltas[idx])]
          self.prune_ops += [var.assign(var * self.masks[idx])]

        # build extra losses for regression & discrimination
        self.reg_losses, self.dis_losses, self.idxs_layer_to_block = \
          self.__build_extra_losses(labels)
        self.dis_losses += [loss_bsc]  # append discrimination-aware loss for the last block
        self.nb_layers = len(self.reg_losses)
        self.nb_blocks = len(self.dis_losses)
        for idx, reg_loss in enumerate(self.reg_losses):
          tf.summary.scalar('reg_loss_%d' % idx, reg_loss)
        for idx, dis_loss in enumerate(self.dis_losses):
          tf.summary.scalar('dis_loss_%d' % idx, dis_loss)

        # obtain the full list of trainable variables & update operations
        self.vars_all = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_prnd)
        self.trainable_vars_all = tf.get_collection(
          tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_prnd)
        self.update_ops_all = tf.get_collection(
          tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)

        # TF operations for initializing the channel-pruned model
        init_ops = []
        with tf.control_dependencies([tf.variables_initializer(self.vars_all)]):
          for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
            init_ops += [var_prnd.assign(var_full)]
        self.init_op = tf.group(init_ops)

        # TF operations for layer-wise, block-wise, and whole-network fine-tuning
        self.layer_train_ops, self.layer_init_opt_ops, self.grad_norms = self.__build_layer_ops()
        self.block_train_ops, self.block_init_opt_ops = self.__build_block_ops()
        self.train_op, self.init_opt_op = self.__build_network_ops(loss_fnl, lrn_rate)

      # TF operations for logging & summarizing
      self.sess_train = sess
      self.summary_op = tf.summary.merge_all()
      self.log_op = [lrn_rate, loss_fnl, pr_trainable, pr_maskable] + list(metrics.values())
      self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
Example #14
  def __build_pruned_train_model(self, path=None, finetune=False): # pylint: disable=too-many-locals
    ''' build a training model from the pruned model '''
    if path is None:
      path = FLAGS.save_path

    with tf.Graph().as_default():
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(# pylint: disable=no-member
        mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
      self.sess_train = tf.Session(config=config)
      self.saver_train = tf.train.import_meta_graph(path + '.meta')
      self.saver_train.restore(self.sess_train, path)
      logits = tf.get_collection('logits')[0]
      train_images = tf.get_collection('train_images')[0]
      train_labels = tf.get_collection('train_labels')[0]
      mem_images = tf.get_collection('mem_images')[0]
      mem_labels = tf.get_collection('mem_labels')[0]

      self.sess_train.close()

      graph_editor.reroute_ts(train_images, mem_images)
      graph_editor.reroute_ts(train_labels, mem_labels)

      self.sess_train = tf.Session(config=config)
      self.saver_train.restore(self.sess_train, path)

      trainable_vars = self.trainable_vars
      loss, accuracy = self.calc_loss(train_labels, logits, trainable_vars)
      self.accuracy_keys = list(accuracy.keys())

      if FLAGS.enbl_dst:
        logits_dst = self.learner_dst.calc_logits(self.sess_train, train_images)
        loss += self.learner_dst.calc_loss(logits, logits_dst)

      tf.summary.scalar('loss', loss)
      for key in accuracy.keys():
        tf.summary.scalar(key, accuracy[key])
      self.summary_op = tf.summary.merge_all()

      global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, trainable=False)
      self.global_step = global_step
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)

      if finetune and not FLAGS.cp_retrain:
        mom_optimizer = tf.train.AdamOptimizer(FLAGS.cp_lrn_rate_ft)
        self.log_op = [tf.constant(FLAGS.cp_lrn_rate_ft), loss, list(accuracy.values())]
      else:
        mom_optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        self.log_op = [lrn_rate, loss, list(accuracy.values())]

      if FLAGS.enbl_multi_gpu:
        optimizer = mgw.DistributedOptimizer(mom_optimizer)
      else:
        optimizer = mom_optimizer
      grads_origin = optimizer.compute_gradients(loss, trainable_vars)
      grads_pruned, masks = self.__calc_grads_pruned(grads_origin)


      with tf.control_dependencies(self.update_ops):
        self.train_op = optimizer.apply_gradients(grads_pruned, global_step=global_step)

      self.sm_writer.add_graph(tf.get_default_graph())
      self.train_init_op = \
        tf.initialize_variables(mom_optimizer.variables() + [global_step] + masks)

      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
Example #15
    def __build_train(self):  # pylint: disable=too-many-locals
        """Build the training graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            if FLAGS.enbl_multi_gpu:
                config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
            else:
                config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - weight-sparsified model
            with tf.variable_scope(self.model_scope):
                # loss & extra evaluation metrics
                logits = self.forward_train(images)
                self.maskable_var_names = [
                    var.name for var in self.maskable_vars
                ]
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = setup_lrn_rate(
                    self.global_step, self.model_name, self.dataset_name)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.trainable_vars)
                pr_maskable = calc_prune_ratio(self.maskable_vars)
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_maskable', pr_maskable)

                # build masks and corresponding operations for weight sparsification
                self.masks, self.prune_op = self.__build_masks()

                # optimizer & gradients
                optimizer_base = tf.train.MomentumOptimizer(
                    lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(
                    loss, self.trainable_vars)
                grads_pruned = self.__calc_grads_pruned(grads_origin)

            # TF operations & model saver
            self.sess_train = sess
            with tf.control_dependencies(self.update_ops):
                self.train_op = optimizer.apply_gradients(
                    grads_pruned, global_step=self.global_step)
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_maskable] + list(
                metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(
                metrics.keys())
            self.init_op = tf.variables_initializer(self.vars)
            self.init_opt_op = tf.variables_initializer(
                optimizer_base.variables())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
            self.saver_train = tf.train.Saver(self.vars)
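The snippet above calls `self.__calc_grads_pruned(grads_origin)` without showing it; a plausible sketch, assuming each maskable variable has a same-shaped mask in `self.masks` (this mirrors the masking pattern used elsewhere in these examples, not a verbatim copy):

    def __calc_grads_pruned(self, grads_origin):
        """Sketch: zero out gradients of pruned weights via element-wise masks."""
        grads_pruned = []
        for grad, var in grads_origin:
            if var.name not in self.maskable_var_names:
                grads_pruned += [(grad, var)]
            else:
                idx = self.maskable_var_names.index(var.name)
                grads_pruned += [(grad * self.masks[idx], var)]
        return grads_pruned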
Example #16
    def __build_train(self):
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(mgw.local_rank() \
                if FLAGS.enbl_multi_gpu else 0)
            self.sess_train = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size, images.shape[1], images.shape[2], \
                    images.shape[3]))

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_train, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_train(images)

                self.saver_train = tf.train.Saver(self.vars)

                self.weights = [
                    v for v in self.trainable_vars
                    if 'kernel' in v.name or 'weight' in v.name
                ]
                if not FLAGS.nuql_quantize_all_layers:
                    self.weights = self.weights[1:-1]
                self.statistics['num_weights'] = \
                    [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

                self.__quantize_train_graph()

                # loss & accuracy
                # Strictly speaking, clusters should not be included in regularization.
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics[
                        'acc_top5']
                else:
                    raise ValueError("Unrecognized dataset name")

                model_loss = loss
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss
                    tf.summary.scalar('dst_loss', dst_loss)
                tf.summary.scalar('model_loss', model_loss)
                tf.summary.scalar('loss', loss)
                tf.summary.scalar('acc_top1', acc_top1)
                tf.summary.scalar('acc_top5', acc_top5)

                self.ft_step = tf.get_variable('finetune_step',
                                               shape=[],
                                               dtype=tf.int32,
                                               trainable=False)

            # optimizer & gradients
            init_lr, bnds, decay_rates, self.finetune_steps = \
                setup_bnds_decay_rates(self.model_name, self.dataset_name)
            lrn_rate = tf.train.piecewise_constant(self.ft_step, [i for i in bnds], \
                [init_lr * decay_rate for decay_rate in decay_rates])

            # optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
            optimizer = tf.train.AdamOptimizer(lrn_rate)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)

            # separate cluster variables from the remaining trainable variables
            clusters = [v for v in self.trainable_vars if 'clusters' in v.name]
            rest_trainable_vars = [v for v in self.trainable_vars \
                if v not in clusters]

            # determine the list of variables to optimize
            if FLAGS.nuql_opt_mode in ['cluster', 'both']:
                if FLAGS.nuql_opt_mode == 'both':
                    optimizable_vars = self.trainable_vars
                else:
                    optimizable_vars = clusters
                if FLAGS.nuql_enbl_rl_agent:
                    optimizer_fintune = tf.train.GradientDescentOptimizer(
                        lrn_rate)
                    if FLAGS.enbl_multi_gpu:
                        optimizer_fintune = mgw.DistributedOptimizer(
                            optimizer_fintune)
                    grads_fintune = optimizer_fintune.compute_gradients(
                        loss, var_list=optimizable_vars)

            elif FLAGS.nuql_opt_mode == 'weights':
                optimizable_vars = rest_trainable_vars
            else:
                raise ValueError("Unknown optimization mode")

            grads = optimizer.compute_gradients(loss,
                                                var_list=optimizable_vars)

            # sm write graph
            self.sm_writer.add_graph(self.sess_train.graph)

            # define the ops
            with tf.control_dependencies(self.update_ops):
                self.ops['train'] = optimizer.apply_gradients(
                    grads, global_step=self.ft_step)
                if FLAGS.nuql_opt_mode in ['both', 'cluster'
                                           ] and FLAGS.nuql_enbl_rl_agent:
                    self.ops['rl_fintune'] = optimizer_fintune.apply_gradients(
                        grads_fintune, global_step=self.ft_step)
                else:
                    self.ops['rl_fintune'] = self.ops['train']
            self.ops['summary'] = tf.summary.merge_all()
            if FLAGS.enbl_dst:
                self.ops['log'] = [
                    lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5
                ]
            else:
                self.ops['log'] = [
                    lrn_rate, model_loss, loss, acc_top1, acc_top5
                ]

            # NOTE: first run non_cluster_init_op, then cluster_init_op
            cluster_global_vars = [
                v for v in self.vars if 'clusters' in v.name
            ]
            # pay attention to beta1_power, beta2_power.
            non_cluster_global_vars = [
                v for v in tf.global_variables()
                if v not in cluster_global_vars
            ]

            self.ops['non_cluster_init'] = tf.variables_initializer(
                var_list=non_cluster_global_vars)
            self.ops['cluster_init'] = tf.variables_initializer(
                var_list=cluster_global_vars)
            self.ops['bcast'] = mgw.broadcast_global_variables(0) \
                if FLAGS.enbl_multi_gpu else None
            self.ops['reset_ft_step'] = tf.assign(self.ft_step, \
                tf.constant(0, dtype=tf.int32))

            self.saver_quant = tf.train.Saver(self.vars)
Example #17
    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            config.gpu_options.visible_device_list = \
              str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images)
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.conv_krnl_var_names = [
                    var.name for var in self.vars_prnd['conv_krnl']
                ]
                self.global_step = tf.train.get_or_create_global_step()
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'] +
                                                       [self.global_step])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits_prnd,
                                               self.vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                    self.global_step)

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(self.vars_prnd['conv_krnl'])
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_conv_krnl', pr_conv_krnl)

                # create masks and corresponding operations for channel pruning
                self.masks = []
                mask_updt_ops = []  # update each mask based on its conv kernel's values
                for idx, var in enumerate(self.vars_prnd['conv_krnl']):
                    tf.logging.info(
                        'creating a pruning mask for {} of size {}'.format(
                            var.name, var.shape))
                    mask_name = '/'.join(var.name.split('/')[1:]).replace(
                        ':0', '_mask')
                    mask_shape = [1, 1, var.shape[2], 1]  # 1 x 1 x c_in x 1
                    mask = tf.get_variable(mask_name,
                                           initializer=tf.ones(mask_shape),
                                           trainable=False)
                    var_norm = tf.reduce_sum(tf.square(var),
                                             axis=[0, 1, 3],
                                             keepdims=True)
                    self.masks += [mask]
                    mask_updt_ops += [
                        mask.assign(tf.cast(var_norm > 0.0, tf.float32))
                    ]
                self.mask_updt_op = tf.group(mask_updt_ops)

                # build operations for channel selection
                self.__build_chn_select_ops()

                # optimizer & gradients
                optimizer_base = tf.train.MomentumOptimizer(
                    lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(
                    loss, self.vars_prnd['trainable'])
                grads_pruned = self.__calc_grads_pruned(grads_origin)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope=self.model_scope_prnd)
                with tf.control_dependencies(update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads_pruned, global_step=self.global_step)

                # TF operations for initializing the channel-pruned model
                init_ops = []
                for var_full, var_prnd in zip(self.vars_full['all'],
                                              self.vars_prnd['all']):
                    init_ops += [var_prnd.assign(var_full)]
                init_ops += [self.global_step.initializer]  # initialize the global step
                init_ops += [
                    tf.variables_initializer(optimizer_base.variables())
                ]
                self.init_op = tf.group(init_ops)

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_conv_krnl] + list(
                metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_krn'] + list(
                metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
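To make the mask bookkeeping concrete, a small NumPy check of the shape logic (toy sizes are assumptions; a 3x3 kernel with 4 input and 8 output channels):

    import numpy as np
    kernel = np.zeros((3, 3, 4, 8), dtype=np.float32)  # kh x kw x c_in x c_out
    kernel[:, :, [0, 2], :] = 1.0                      # input channels 1 & 3 stay pruned
    var_norm = np.sum(np.square(kernel), axis=(0, 1, 3), keepdims=True)  # 1 x 1 x 4 x 1
    mask = (var_norm > 0.0).astype(np.float32)
    print(mask.ravel())  # [1. 0. 1. 0.]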
Example #18
  def __init__(self,
               dataset_name,
               weights,
               statistics,
               bit_placeholders,
               ops,
               layerwise_tune_list,
               sess_train,
               sess_eval,
               saver_train,
               saver_eval,
               barrier_fn):
    """ By passing the ops in the learner, we do not need to build the graph
    again for training and testing.

    Args:
    * dataset_name: a string that indicates which dataset to use
    * weights: a list of Tensors, the weights of networks to quantize
    * statistics: a dict, recording the number of weights, activations e.t.c.
    * bit_placeholders: a dict of placeholder Tensors, the input of bits
    * ops: a dict of ops, including trian_op, eval_op e.t.c.
    * layerwise_tune_list: a tuple, in which [0] records the layerwise op and
                          [1] records the layerwise l2_norm
    * sess_train: a session for train
    * sess_eval: a session for eval
    * saver_train: a Tensorflow Saver for the training graph
    * saver_eval: a Tensorflow Saver for the eval graph
    * barrier_fn: a function that implements barrier
    """
    self.dataset_name = dataset_name
    self.weights = weights
    self.statistics = statistics
    self.bit_placeholders = bit_placeholders
    self.ops = ops
    self.layerwise_tune_ops, self.layerwise_diff = \
        layerwise_tune_list[0], layerwise_tune_list[1]
    self.sess_train = sess_train
    self.sess_eval = sess_eval
    self.saver_train = saver_train
    self.saver_eval = saver_eval
    self.auto_barrier = barrier_fn

    self.total_num_weights = sum(self.statistics['num_weights'])
    self.total_bits = self.total_num_weights * FLAGS.uql_equivalent_bits

    self.w_rl_helper = RLHelper(self.sess_train,
                                self.total_bits,
                                self.statistics['num_weights'],
                                self.weights,
                                random_layers=FLAGS.uql_enbl_random_layers)

    self.mgw_size = int(mgw.size()) if FLAGS.enbl_multi_gpu else 1
    self.tune_global_steps = int(FLAGS.uql_tune_global_steps / self.mgw_size)
    self.tune_global_disp_steps = int(FLAGS.uql_tune_disp_steps / self.mgw_size)

    # build the RL training graph
    with tf.Graph().as_default():
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(mgw.local_rank() \
          if FLAGS.enbl_multi_gpu else 0)
      self.sess_rl = tf.Session(config=config)

      # train an RL agent through multiple roll-outs
      self.s_dims = self.w_rl_helper.s_dims
      self.a_dims = 1
      buff_size = len(self.weights) * int(FLAGS.uql_nb_rlouts // 4)
      self.agent = DdpgAgent(self.sess_rl,
                             self.s_dims,
                             self.a_dims,
                             FLAGS.uql_nb_rlouts,
                             buff_size,
                             a_min=0.,
                             a_max=FLAGS.uql_w_bit_max-FLAGS.uql_w_bit_min)
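The docstring's point (build the graphs once elsewhere, then hand the finished ops to this class) might look like the following construction sketch; the class name and every argument value are assumptions:

    learner = Learner(  # assumed class name for the __init__ shown above
        dataset_name='cifar_10', weights=weights, statistics=stats,
        bit_placeholders=bit_phs, ops=ops,
        layerwise_tune_list=(tune_ops, tune_diffs),
        sess_train=sess_train, sess_eval=sess_eval,
        saver_train=saver_train, saver_eval=saver_eval,
        barrier_fn=barrier)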
Example #19
    def __build_train(self, model_helper):  # pylint: disable=too-many-locals
        """Build the training graph for the 'optimal' protocol.

    Args:
    * model_helper: model helper with definitions of model & dataset
    """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator, __ = model_helper.build_dataset_train(
                    enbl_trn_val_split=True)
                images, labels = iterator.get_next()

            # model definition - full-precision network
            with tf.variable_scope(self.model_scope_full):
                logits = model_helper.forward_eval(
                    images)  # DO NOT USE forward_train() HERE!!!
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - weight sparsified network
            with tf.variable_scope(self.model_scope_prnd):
                # forward pass & variables' saver
                logits = model_helper.forward_eval(
                    images)  # DO NOT USE forward_train() HERE!!!
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.maskable_var_names = [
                    var.name for var in self.vars_prnd['maskable']
                ]
                loss, __ = model_helper.calc_loss(labels, logits,
                                                  self.vars_prnd['trainable'])
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])
                self.save_path_prnd = FLAGS.save_path.replace(
                    'models', 'models_pruned')

                # build masks for variable pruning
                self.masks, self.pr_all, self.pr_assign_op = \
                    self.__build_masks()

            # create operations for initializing the weight sparsified network
            init_ops = []
            for var_full, var_prnd in zip(self.vars_full['all'],
                                          self.vars_prnd['all']):
                if var_full not in self.vars_full['maskable']:
                    init_ops += [var_prnd.assign(var_full)]
                else:
                    idx = self.vars_full['maskable'].index(var_full)
                    init_ops += [var_prnd.assign(var_full * self.masks[idx])]
            self.init_op = tf.group(init_ops)

            # build operations for layerwise regression & network fine-tuning
            self.rg_init_op, self.rg_train_ops = self.__build_layer_rg_ops()
            self.ft_init_op, self.ft_train_op = self.__build_network_ft_ops(
                loss)
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

            # create RL helper & agent on the primary worker
            if is_primary_worker('global'):
                self.rl_helper, self.agent = self.__build_rl_helper_n_agent(
                    sess)

            # TF operations
            self.sess_train = sess
Example #20
    def build_graph(self, is_train):
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.dataset_train.build(
                ) if is_train else self.dataset_eval.build()
                images, labels = iterator.get_next()
                if not isinstance(images, dict):
                    tf.add_to_collection('images_final', images)
                else:
                    tf.add_to_collection('images_final', images['image'])

            # model definition - primary model
            with tf.variable_scope(self.model_scope):
                # forward pass
                logits = self.hrnet.forward_train(
                    images) if is_train else self.hrnet.forward_eval(images)
                if not isinstance(logits, dict):
                    tf.add_to_collection('logits_final', logits)
                else:
                    for value in logits.values():
                        tf.add_to_collection('logits_final', value)

                # loss & extra evaluation metrics
                loss, metrics = self.hrnet.calc_loss(labels, logits,
                                                     self.trainable_vars)

                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # optimizer & gradients
                if is_train:
                    self.global_step = tf.train.get_or_create_global_step()
                    lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                        self.global_step)

                    optimizer = tf.train.MomentumOptimizer(
                        lrn_rate, self.hrnet.cfg['COMMON']['momentum'])
                    if FLAGS.enbl_multi_gpu:
                        optimizer = mgw.DistributedOptimizer(optimizer)
                    grads = optimizer.compute_gradients(
                        loss, self.trainable_vars)

            # TF operations & model saver
            if is_train:
                self.sess_train = sess

                with tf.control_dependencies(self.update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)

                self.summary_op = tf.summary.merge_all()
                self.sm_writer = tf.summary.FileWriter(logdir=self.log_path)
                self.log_op = [lrn_rate, loss] + list(metrics.values())
                self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
                self.init_op = tf.variables_initializer(self.vars)
                if FLAGS.enbl_multi_gpu:
                    self.bcast_op = mgw.broadcast_global_variables(0)
                self.saver_train = tf.train.Saver(self.vars)
            else:
                self.sess_eval = sess
                self.eval_op = [loss] + list(metrics.values())
                self.eval_op_names = ['loss'] + list(metrics.keys())
                self.saver_eval = tf.train.Saver(self.vars)
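
build_graph() only wires up the session and the TF operations; running them is left to the caller. A hypothetical driver loop using the attributes assigned above (the iteration count comes from setup_lrn_rate() during graph construction; the logging interval of 100 is illustrative):

    # hypothetical training driver built on the attributes assigned above
    def train(self):
        self.build_graph(is_train=True)
        self.sess_train.run(self.init_op)
        if FLAGS.enbl_multi_gpu:
            self.sess_train.run(self.bcast_op)  # sync initial weights across workers
        for idx_iter in range(self.nb_iters_train):
            self.sess_train.run(self.train_op)
            if (idx_iter + 1) % 100 == 0:  # logging interval is illustrative
                log_rslt = self.sess_train.run(self.log_op)
                print(' | '.join('%s = %.4e' % (name, value)
                                 for name, value in zip(self.log_op_names, log_rslt)))
        self.saver_train.save(self.sess_train, FLAGS.save_path)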
Example #21
    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph."""

        with tf.Graph().as_default() as graph:
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - uniform quantized model - part 1
            with tf.variable_scope(self.model_scope_quan):
                logits_quan = self.forward_train(images)
                if not isinstance(logits_quan, dict):
                    outputs = tf.nn.softmax(logits_quan)
                else:
                    outputs = tf.nn.softmax(logits_quan['cls_pred'])
                tf.contrib.quantize.experimental_create_training_graph(
                    weight_bits=FLAGS.uqtf_weight_bits,
                    activation_bits=FLAGS.uqtf_activation_bits,
                    quant_delay=FLAGS.uqtf_quant_delay,
                    freeze_bn_delay=FLAGS.uqtf_freeze_bn_delay,
                    scope=self.model_scope_quan)
                for node_name in self.unquant_node_names:
                    insert_quant_op(graph, node_name, is_train=True)
                self.vars_quan = get_vars_by_scope(self.model_scope_quan)
                self.global_step = tf.train.get_or_create_global_step()
                self.saver_quan_train = tf.train.Saver(self.vars_quan['all'] +
                                                       [self.global_step])

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - uniform quantized model - part 2
            with tf.variable_scope(self.model_scope_quan):
                # loss & extra evaluation metrics
                loss_bsc, metrics = self.calc_loss(labels, logits_quan,
                                                   self.vars_quan['trainable'])
                if not FLAGS.enbl_dst:
                    loss_fnl = loss_bsc
                else:
                    loss_fnl = loss_bsc + self.helper_dst.calc_loss(
                        logits_quan, logits_dst)
                tf.summary.scalar('loss_bsc', loss_bsc)
                tf.summary.scalar('loss_fnl', loss_fnl)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                    self.global_step)
                lrn_rate *= FLAGS.uqtf_lrn_rate_dcy

                # decrease the learning rate by a constant factor
                # if self.dataset_name == 'cifar_10':
                #  lrn_rate *= 1e-3
                # elif self.dataset_name == 'ilsvrc_12':
                #  lrn_rate *= 1e-4
                # else:
                #  raise ValueError('unrecognized dataset\'s name: ' + self.dataset_name)

                # obtain the full list of trainable variables & update operations
                self.vars_all = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_quan)
                self.trainable_vars_all = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.model_scope_quan)
                self.update_ops_all = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_quan)

                # TF operations for initializing the uniform quantized model
                init_ops = []
                with tf.control_dependencies(
                    [tf.variables_initializer(self.vars_all)]):
                    for var_full, var_quan in zip(self.vars_full['all'],
                                                  self.vars_quan['all']):
                        init_ops += [var_quan.assign(var_full)]
                init_ops += [self.global_step.initializer]
                self.init_op = tf.group(init_ops)

                # TF operations for fine-tuning
                # optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
                optimizer_base = tf.train.AdamOptimizer(lrn_rate)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads = optimizer.compute_gradients(loss_fnl,
                                                    self.trainable_vars_all)
                with tf.control_dependencies(self.update_ops_all):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)
                self.init_opt_op = tf.variables_initializer(
                    optimizer_base.variables())

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss_fnl] + list(metrics.values())
            self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
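
experimental_create_training_graph() rewrites the default graph in place, wrapping weights and activations with fake-quantization ops; nothing changes at the Python level, which is why the code above can keep using logits_quan directly. A minimal self-contained sketch of the rewrite on a toy model (TF 1.x contrib API; the layer sizes are illustrative):

import tensorflow as tf

with tf.Graph().as_default() as graph:
    images = tf.placeholder(tf.float32, [None, 32, 32, 3])
    net = tf.layers.conv2d(images, 16, 3, activation=tf.nn.relu)
    logits = tf.layers.dense(tf.layers.flatten(net), 10)

    # in-place rewrite: inserts FakeQuantWithMinMaxVars ops around
    # weights & activations (8/8 bits here, no quantization delay)
    tf.contrib.quantize.experimental_create_training_graph(
        weight_bits=8, activation_bits=8, quant_delay=0)

    nb_fake_quant = sum('FakeQuant' in op.type for op in graph.get_operations())
    print('inserted %d fake-quant ops' % nb_fake_quant)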
Example #22
    def __build_prune(self):
        """Build the channel pruning graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            config.gpu_options.visible_device_list = \
                str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                if not isinstance(images, dict):
                    images_ph = tf.placeholder(tf.float32,
                                               shape=images.shape,
                                               name='images_ph')
                else:
                    images_ph = {}
                    for key, value in images.items():
                        images_ph[key] = tf.placeholder(value.dtype,
                                                        shape=value.shape,
                                                        name=(key + '_ph'))

            # restore a pre-trained model as full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images_ph)
                vars_full = get_vars_by_scope(self.model_scope_full)
                saver_full = tf.train.Saver(vars_full['all'])
                saver_full.restore(
                    sess,
                    tf.train.latest_checkpoint(os.path.dirname(
                        FLAGS.save_path)))

            # restore a pre-trained model as channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images_ph)
                vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                global_step = tf.train.get_or_create_global_step()
                saver_prnd = tf.train.Saver(vars_prnd['all'] + [global_step])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits_prnd,
                                               vars_prnd['trainable'])

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(vars_prnd['conv_krnl'])

                # use full model's weights to initialize channel-pruned model
                init_ops = [global_step.initializer]
                for var_full, var_prnd in zip(vars_full['all'],
                                              vars_prnd['all']):
                    init_ops += [var_prnd.assign(var_full)]
                self.init_op_prune = tf.group(init_ops)

            # build a list with each Conv2D operation's information
            self.conv_info_list = self.__build_conv_info_list(
                vars_prnd['conv_krnl'])

            # build meta LASSO/least-square optimization problems
            self.meta_lasso = self.__build_meta_lasso()
            self.meta_lstsq = self.__build_meta_lstsq()

            # TF operations for logging & summarizing
            self.sess_prune = sess
            self.images_prune = images
            self.images_prune_ph = images_ph
            self.saver_prune = saver_prnd
            self.pr_trn_prune = pr_trainable
            self.pr_krn_prune = pr_conv_krnl
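
The 'meta LASSO/least-square' problems built above follow the standard two-step channel-pruning recipe: a LASSO on per-channel coefficients selects which input channels to keep, then a least-squares fit reconstructs the kernel on the survivors. A NumPy sketch of the underlying math, with illustrative shapes and an sklearn solver standing in for PocketFlow's TF-based implementation:

import numpy as np
from sklearn.linear_model import Lasso

# X[:, i] holds channel i's contribution to the layer output; Y is the
# original output to reconstruct (sizes are illustrative)
nb_smpls, nb_chns = 256, 32
X = np.random.randn(nb_smpls, nb_chns)
Y = X @ np.random.randn(nb_chns)

# step 1: LASSO on per-channel coefficients - zeros mark pruned channels
beta = Lasso(alpha=0.1).fit(X, Y).coef_
kept = np.flatnonzero(beta)

# step 2: least-squares re-fit of the kernel on the surviving channels only
w_new, *_ = np.linalg.lstsq(X[:, kept], Y, rcond=None)
print('kept %d / %d channels' % (kept.size, nb_chns))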