Example #1
    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph."""

        with tf.Graph().as_default() as graph:
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - uniform quantized model - part 1
            with tf.variable_scope(self.model_scope_quan):
                logits_quan = self.forward_train(images)
                if not isinstance(logits_quan, dict):
                    outputs = tf.nn.softmax(logits_quan)
                else:
                    outputs = tf.nn.softmax(logits_quan['cls_pred'])
                tf.contrib.quantize.experimental_create_training_graph(
                    weight_bits=FLAGS.uqtf_weight_bits,
                    activation_bits=FLAGS.uqtf_activation_bits,
                    quant_delay=FLAGS.uqtf_quant_delay,
                    freeze_bn_delay=FLAGS.uqtf_freeze_bn_delay,
                    scope=self.model_scope_quan)
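                # manually insert quantization ops for nodes the automatic rewriter leaves unquantized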
                for node_name in self.unquant_node_names:
                    insert_quant_op(graph, node_name, is_train=True)
                self.vars_quan = get_vars_by_scope(self.model_scope_quan)
                self.global_step = tf.train.get_or_create_global_step()
                self.saver_quan_train = tf.train.Saver(self.vars_quan['all'] +
                                                       [self.global_step])

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - uniform quantized model - part 2
            with tf.variable_scope(self.model_scope_quan):
                # loss & extra evaluation metrics
                loss_bsc, metrics = self.calc_loss(labels, logits_quan,
                                                   self.vars_quan['trainable'])
                if not FLAGS.enbl_dst:
                    loss_fnl = loss_bsc
                else:
                    loss_fnl = loss_bsc + self.helper_dst.calc_loss(
                        logits_quan, logits_dst)
                tf.summary.scalar('loss_bsc', loss_bsc)
                tf.summary.scalar('loss_fnl', loss_fnl)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                    self.global_step)
                lrn_rate *= FLAGS.uqtf_lrn_rate_dcy

                # decrease the learning rate by a constant factor
                # if self.dataset_name == 'cifar_10':
                #  lrn_rate *= 1e-3
                # elif self.dataset_name == 'ilsvrc_12':
                #  lrn_rate *= 1e-4
                # else:
                #  raise ValueError('unrecognized dataset\'s name: ' + self.dataset_name)

                # obtain the full list of trainable variables & update operations
                self.vars_all = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_quan)
                self.trainable_vars_all = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.model_scope_quan)
                self.update_ops_all = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_quan)

                # TF operations for initializing the uniform quantized model
                init_ops = []
                with tf.control_dependencies(
                    [tf.variables_initializer(self.vars_all)]):
                    for var_full, var_quan in zip(self.vars_full['all'],
                                                  self.vars_quan['all']):
                        init_ops += [var_quan.assign(var_full)]
                init_ops += [self.global_step.initializer]
                self.init_op = tf.group(init_ops)

                # TF operations for fine-tuning
                # optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
                optimizer_base = tf.train.AdamOptimizer(lrn_rate)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads = optimizer.compute_gradients(loss_fnl,
                                                    self.trainable_vars_all)
                with tf.control_dependencies(self.update_ops_all):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)
                self.init_opt_op = tf.variables_initializer(
                    optimizer_base.variables())

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss_fnl] + list(metrics.values())
            self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
Example #2
  def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
    """Build the training graph."""

    with tf.Graph().as_default():
      # create a TF session for the current graph
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      sess = tf.Session(config=config)

      # data input pipeline
      with tf.variable_scope(self.data_scope):
        iterator = self.build_dataset_train()
        images, labels = iterator.get_next()

      # model definition - distilled model
      if FLAGS.enbl_dst:
        logits_dst = self.helper_dst.calc_logits(sess, images)

      # model definition - full model
      with tf.variable_scope(self.model_scope_full):
        __ = self.forward_train(images)
        self.vars_full = get_vars_by_scope(self.model_scope_full)
        self.saver_full = tf.train.Saver(self.vars_full['all'])
        self.save_path_full = FLAGS.save_path

      # model definition - channel-pruned model
      with tf.variable_scope(self.model_scope_prnd):
        logits_prnd = self.forward_train(images)
        self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
        self.maskable_var_names = [var.name for var in self.vars_prnd['maskable']]
        self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])

        # loss & extra evaluation metrics
        loss_bsc, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
        if not FLAGS.enbl_dst:
          loss_fnl = loss_bsc
        else:
          loss_fnl = loss_bsc + self.helper_dst.calc_loss(logits_prnd, logits_dst)
        tf.summary.scalar('loss_bsc', loss_bsc)
        tf.summary.scalar('loss_fnl', loss_fnl)
        for key, value in metrics.items():
          tf.summary.scalar(key, value)

        # learning rate schedule
        self.global_step = tf.train.get_or_create_global_step()
        lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)

        # overall pruning ratios of trainable & maskable variables
        pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
        pr_maskable = calc_prune_ratio(self.vars_prnd['maskable'])
        tf.summary.scalar('pr_trainable', pr_trainable)
        tf.summary.scalar('pr_maskable', pr_maskable)

        # create masks and corresponding operations for channel pruning
        self.masks = []
        self.mask_deltas = []
        self.mask_init_ops = []
        self.mask_updt_ops = []
        self.prune_ops = []
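        # for each maskable variable: reset its mask to zeros, grow it via a
        # placeholder-fed delta, and prune by multiplying the variable by its mask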
        for idx, var in enumerate(self.vars_prnd['maskable']):
          name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
          self.masks += [tf.get_variable(name, initializer=tf.ones(var.shape), trainable=False)]
          name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask_delta')
          self.mask_deltas += [tf.placeholder(tf.float32, shape=var.shape, name=name)]
          self.mask_init_ops += [self.masks[idx].assign(tf.zeros(var.shape))]
          self.mask_updt_ops += [self.masks[idx].assign_add(self.mask_deltas[idx])]
          self.prune_ops += [var.assign(var * self.masks[idx])]

        # build extra losses for regression & discrimination
        self.reg_losses, self.dis_losses, self.idxs_layer_to_block = \
          self.__build_extra_losses(labels)
        self.dis_losses += [loss_bsc]  # append discrimination-aware loss for the last block
        self.nb_layers = len(self.reg_losses)
        self.nb_blocks = len(self.dis_losses)
        for idx, reg_loss in enumerate(self.reg_losses):
          tf.summary.scalar('reg_loss_%d' % idx, reg_loss)
        for idx, dis_loss in enumerate(self.dis_losses):
          tf.summary.scalar('dis_loss_%d' % idx, dis_loss)

        # obtain the full list of trainable variables & update operations
        self.vars_all = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_prnd)
        self.trainable_vars_all = tf.get_collection(
          tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.model_scope_prnd)
        self.update_ops_all = tf.get_collection(
          tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)

        # TF operations for initializing the channel-pruned model
        init_ops = []
        with tf.control_dependencies([tf.variables_initializer(self.vars_all)]):
          for var_full, var_prnd in zip(self.vars_full['all'], self.vars_prnd['all']):
            init_ops += [var_prnd.assign(var_full)]
        self.init_op = tf.group(init_ops)

        # TF operations for layer-wise, block-wise, and whole-network fine-tuning
        self.layer_train_ops, self.layer_init_opt_ops, self.grad_norms = self.__build_layer_ops()
        self.block_train_ops, self.block_init_opt_ops = self.__build_block_ops()
        self.train_op, self.init_opt_op = self.__build_network_ops(loss_fnl, lrn_rate)

      # TF operations for logging & summarizing
      self.sess_train = sess
      self.summary_op = tf.summary.merge_all()
      self.log_op = [lrn_rate, loss_fnl, pr_trainable, pr_maskable] + list(metrics.values())
      self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(metrics.keys())
      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
Example #3
  def __build(self, is_train):  # pylint: disable=too-many-locals
    """Build the training / evaluation graph.

    Args:
    * is_train: whether to create the training graph
    """

    with tf.Graph().as_default():
      # TensorFlow session
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      sess = tf.Session(config=config)

      # data input pipeline
      with tf.variable_scope(self.data_scope):
        iterator = self.build_dataset_train() if is_train else self.build_dataset_eval()
        images, labels = iterator.get_next()
        tf.add_to_collection('images_final', images)

      # model definition - distilled model
      if self.enbl_dst:
        logits_dst = self.helper_dst.calc_logits(sess, images)

      # model definition - primary model
      with tf.variable_scope(self.model_scope):
        # forward pass
        logits = self.forward_train(images) if is_train else self.forward_eval(images)
        tf.add_to_collection('logits_final', logits)

        # loss & extra evaluation metrics
        loss, metrics = self.calc_loss(labels, logits, self.trainable_vars)
        if self.enbl_dst:
          loss += self.helper_dst.calc_loss(logits, logits_dst)
        tf.summary.scalar('loss', loss)
        for key, value in metrics.items():
          tf.summary.scalar(key, value)

        # optimizer & gradients
        if is_train:
          self.global_step = tf.train.get_or_create_global_step()
          lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)
          optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
          if FLAGS.enbl_multi_gpu:
            optimizer = mgw.DistributedOptimizer(optimizer)
          grads = optimizer.compute_gradients(loss, self.trainable_vars)

      # TF operations & model saver
      if is_train:
        self.sess_train = sess
        with tf.control_dependencies(self.update_ops):
          self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
        self.summary_op = tf.summary.merge_all()
        self.log_op = [lrn_rate, loss] + list(metrics.values())
        self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
        self.init_op = tf.variables_initializer(self.vars)
        if FLAGS.enbl_multi_gpu:
          self.bcast_op = mgw.broadcast_global_variables(0)
        self.saver_train = tf.train.Saver(self.vars)
      else:
        self.sess_eval = sess
        self.eval_op = [loss] + list(metrics.values())
        self.eval_op_names = ['loss'] + list(metrics.keys())
        self.outputs_eval = logits
        self.saver_eval = tf.train.Saver(self.vars)
Example #4
    def __build_train(self):
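        """Build the training graph."""
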
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_train = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size, images.shape[1],
                                  images.shape[2], images.shape[3]))

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_train, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_train(images)

                self.saver_train = tf.train.Saver(self.vars)

                self.weights = [
                    v for v in self.trainable_vars
                    if 'kernel' in v.name or 'weight' in v.name
                ]
                if not FLAGS.nuql_quantize_all_layers:
                    self.weights = self.weights[1:-1]
                self.statistics['num_weights'] = \
                    [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

                self.__quantize_train_graph()

                # loss & accuracy
                # Strictly speaking, clusters should not be included in regularization.
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics['acc_top5']
                else:
                    raise ValueError("Unrecognized dataset name")

                model_loss = loss
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss
                    tf.summary.scalar('dst_loss', dst_loss)
                tf.summary.scalar('model_loss', model_loss)
                tf.summary.scalar('loss', loss)
                tf.summary.scalar('acc_top1', acc_top1)
                tf.summary.scalar('acc_top5', acc_top5)

                self.ft_step = tf.get_variable('finetune_step',
                                               shape=[],
                                               dtype=tf.int32,
                                               trainable=False)

            # optimizer & gradients
            init_lr, bnds, decay_rates, self.finetune_steps = \
                setup_bnds_decay_rates(self.model_name, self.dataset_name)
            lrn_rate = tf.train.piecewise_constant(
                self.ft_step, list(bnds),
                [init_lr * decay_rate for decay_rate in decay_rates])

            # optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
            optimizer = tf.train.AdamOptimizer(lrn_rate)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)

            # separate cluster variables from the remaining trainable variables
            clusters = [v for v in self.trainable_vars if 'clusters' in v.name]
            rest_trainable_vars = [v for v in self.trainable_vars
                                   if v not in clusters]

            # determine the list of variables to optimize
            if FLAGS.nuql_opt_mode in ['cluster', 'both']:
                if FLAGS.nuql_opt_mode == 'both':
                    optimizable_vars = self.trainable_vars
                else:
                    optimizable_vars = clusters
                if FLAGS.nuql_enbl_rl_agent:
                    optimizer_fintune = tf.train.GradientDescentOptimizer(
                        lrn_rate)
                    if FLAGS.enbl_multi_gpu:
                        optimizer_fintune = mgw.DistributedOptimizer(
                            optimizer_fintune)
                    grads_fintune = optimizer_fintune.compute_gradients(
                        loss, var_list=optimizable_vars)

            elif FLAGS.nuql_opt_mode == 'weights':
                optimizable_vars = rest_trainable_vars
            else:
                raise ValueError("Unknown optimization mode")

            grads = optimizer.compute_gradients(loss,
                                                var_list=optimizable_vars)

            # write the graph to the summary writer
            self.sm_writer.add_graph(self.sess_train.graph)

            # define the ops
            with tf.control_dependencies(self.update_ops):
                self.ops['train'] = optimizer.apply_gradients(
                    grads, global_step=self.ft_step)
                if FLAGS.nuql_opt_mode in ['both', 'cluster'] and FLAGS.nuql_enbl_rl_agent:
                    self.ops['rl_fintune'] = optimizer_fintune.apply_gradients(
                        grads_fintune, global_step=self.ft_step)
                else:
                    self.ops['rl_fintune'] = self.ops['train']
            self.ops['summary'] = tf.summary.merge_all()
            if FLAGS.enbl_dst:
                self.ops['log'] = [
                    lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5
                ]
            else:
                self.ops['log'] = [
                    lrn_rate, model_loss, loss, acc_top1, acc_top5
                ]

            # NOTE: run the 'non_cluster_init' op first, then 'cluster_init'
            cluster_global_vars = [
                v for v in self.vars if 'clusters' in v.name
            ]
            # note: the optimizer's beta1_power & beta2_power variables are included here
            non_cluster_global_vars = [
                v for v in tf.global_variables()
                if v not in cluster_global_vars
            ]

            self.ops['non_cluster_init'] = tf.variables_initializer(
                var_list=non_cluster_global_vars)
            self.ops['cluster_init'] = tf.variables_initializer(
                var_list=cluster_global_vars)
            self.ops['bcast'] = mgw.broadcast_global_variables(0) \
                if FLAGS.enbl_multi_gpu else None
            self.ops['reset_ft_step'] = tf.assign(self.ft_step, \
                tf.constant(0, dtype=tf.int32))

            self.saver_quant = tf.train.Saver(self.vars)
Example #5
    def __build_train(self):  # pylint: disable=too-many-locals
        """Build the training graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            if FLAGS.enbl_multi_gpu:
                config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
            else:
                config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - weight-sparsified model
            with tf.variable_scope(self.model_scope):
                # loss & extra evaluation metrics
                logits = self.forward_train(images)
                self.maskable_var_names = [
                    var.name for var in self.maskable_vars
                ]
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = setup_lrn_rate(
                    self.global_step, self.model_name, self.dataset_name)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.trainable_vars)
                pr_maskable = calc_prune_ratio(self.maskable_vars)
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_maskable', pr_maskable)

                # build masks and corresponding operations for weight sparsification
                self.masks, self.prune_op = self.__build_masks()

                # optimizer & gradients
                optimizer_base = tf.train.MomentumOptimizer(
                    lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(
                    loss, self.trainable_vars)
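                # apply the pruning masks to the gradients so masked-out weights stay at zero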
                grads_pruned = self.__calc_grads_pruned(grads_origin)

            # TF operations & model saver
            self.sess_train = sess
            with tf.control_dependencies(self.update_ops):
                self.train_op = optimizer.apply_gradients(
                    grads_pruned, global_step=self.global_step)
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_maskable] + list(
                metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(
                metrics.keys())
            self.init_op = tf.variables_initializer(self.vars)
            self.init_opt_op = tf.variables_initializer(
                optimizer_base.variables())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
            self.saver_train = tf.train.Saver(self.vars)
Example #6
    def __build_train(self):
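        """Build the training graph."""
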
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_train = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size, images.shape[1],
                                  images.shape[2], images.shape[3]))

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(
                    self.sess_train, images)

            # model definition
            with tf.variable_scope(self.model_scope, reuse=tf.AUTO_REUSE):
                # forward pass
                logits = self.forward_train(images)

                self.weights = [
                    v for v in self.trainable_vars
                    if 'kernel' in v.name or 'weight' in v.name
                ]
                if not FLAGS.uql_quantize_all_layers:
                    self.weights = self.weights[1:-1]
                self.statistics['num_weights'] = \
                    [tf.reshape(v, [-1]).shape[0].value for v in self.weights]

                self.__quantize_train_graph()

                # loss & accuracy
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.dataset_name == 'cifar_10':
                    acc_top1, acc_top5 = metrics['accuracy'], tf.constant(0.)
                elif self.dataset_name == 'ilsvrc_12':
                    acc_top1, acc_top5 = metrics['acc_top1'], metrics['acc_top5']
                else:
                    raise ValueError("Unrecognized dataset name")

                model_loss = loss
                if FLAGS.enbl_dst:
                    dst_loss = self.helper_dst.calc_loss(logits, logits_dst)
                    loss += dst_loss
                    tf.summary.scalar('dst_loss', dst_loss)
                tf.summary.scalar('model_loss', model_loss)
                tf.summary.scalar('loss', loss)
                tf.summary.scalar('acc_top1', acc_top1)
                tf.summary.scalar('acc_top5', acc_top5)

                self.saver_train = tf.train.Saver(self.vars)

                self.ft_step = tf.get_variable('finetune_step',
                                               shape=[],
                                               dtype=tf.int32,
                                               trainable=False)

            # optimizer & gradients
            init_lr, bnds, decay_rates, self.finetune_steps = \
                setup_bnds_decay_rates(self.model_name, self.dataset_name)
            lrn_rate = tf.train.piecewise_constant(
                self.ft_step, list(bnds),
                [init_lr * decay_rate for decay_rate in decay_rates])

            # optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
            optimizer = tf.train.AdamOptimizer(learning_rate=lrn_rate)
            if FLAGS.enbl_multi_gpu:
                optimizer = mgw.DistributedOptimizer(optimizer)
            grads = optimizer.compute_gradients(loss, self.trainable_vars)

            # write the graph to the summary writer
            self.sm_writer.add_graph(self.sess_train.graph)

            with tf.control_dependencies(self.update_ops):
                self.ops['train'] = optimizer.apply_gradients(
                    grads, global_step=self.ft_step)
            self.ops['summary'] = tf.summary.merge_all()

            if FLAGS.enbl_dst:
                self.ops['log'] = [
                    lrn_rate, dst_loss, model_loss, loss, acc_top1, acc_top5
                ]
            else:
                self.ops['log'] = [
                    lrn_rate, model_loss, loss, acc_top1, acc_top5
                ]

            self.ops['reset_ft_step'] = tf.assign(
                self.ft_step, tf.constant(0, dtype=tf.int32))
            self.ops['init'] = tf.global_variables_initializer()
            self.ops['bcast'] = mgw.broadcast_global_variables(
                0) if FLAGS.enbl_multi_gpu else None
            self.saver_quant = tf.train.Saver(self.vars)
Example #7
    def __build_train(self, model_helper):  # pylint: disable=too-many-locals
        """Build the training graph for the 'optimal' protocol.

        Args:
        * model_helper: model helper with definitions of model & dataset
        """

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator, __ = model_helper.build_dataset_train(
                    enbl_trn_val_split=True)
                images, labels = iterator.get_next()

            # model definition - full-precision network
            with tf.variable_scope(self.model_scope_full):
                logits = model_helper.forward_eval(
                    images)  # DO NOT USE forward_train() HERE!!!
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - weight sparsified network
            with tf.variable_scope(self.model_scope_prnd):
                # forward pass & variables' saver
                logits = model_helper.forward_eval(
                    images)  # DO NOT USE forward_train() HERE!!!
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.maskable_var_names = [
                    var.name for var in self.vars_prnd['maskable']
                ]
                loss, __ = model_helper.calc_loss(labels, logits,
                                                  self.vars_prnd['trainable'])
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])
                self.save_path_prnd = FLAGS.save_path.replace(
                    'models', 'models_pruned')

                # build masks for variable pruning
                self.masks, self.pr_all, self.pr_assign_op = self.__build_masks()

            # create operations for initializing the weight sparsified network
            init_ops = []
            for var_full, var_prnd in zip(self.vars_full['all'],
                                          self.vars_prnd['all']):
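                # copy non-maskable variables as-is; apply the pruning mask to maskable ones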
                if var_full not in self.vars_full['maskable']:
                    init_ops += [var_prnd.assign(var_full)]
                else:
                    idx = self.vars_full['maskable'].index(var_full)
                    init_ops += [var_prnd.assign(var_full * self.masks[idx])]
            self.init_op = tf.group(init_ops)

            # build operations for layerwise regression & network fine-tuning
            self.rg_init_op, self.rg_train_ops = self.__build_layer_rg_ops()
            self.ft_init_op, self.ft_train_op = self.__build_network_ft_ops(
                loss)
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)

            # create RL helper & agent on the primary worker
            if is_primary_worker('global'):
                self.rl_helper, self.agent = self.__build_rl_helper_n_agent(
                    sess)

            # TF operations
            self.sess_train = sess
Example #8
    def build_graph(self, is_train):
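        """Build the training / evaluation graph."""
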
        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.dataset_train.build() if is_train else self.dataset_eval.build()
                images, labels = iterator.get_next()
                if not isinstance(images, dict):
                    tf.add_to_collection('images_final', images)
                else:
                    tf.add_to_collection('images_final', images['image'])

            # model definition - primary model
            with tf.variable_scope(self.model_scope):
                # forward pass
                logits = self.hrnet.forward_train(
                    images) if is_train else self.hrnet.forward_eval(images)
                if not isinstance(logits, dict):
                    tf.add_to_collection('logits_final', logits)
                else:
                    for value in logits.values():
                        tf.add_to_collection('logits_final', value)

                # loss & extra evaluation metrics
                loss, metrics = self.hrnet.calc_loss(labels, logits,
                                                     self.trainable_vars)

                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # optimizer & gradients
                if is_train:
                    self.global_step = tf.train.get_or_create_global_step()
                    lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                        self.global_step)

                    optimizer = tf.train.MomentumOptimizer(
                        lrn_rate, self.hrnet.cfg['COMMON']['momentum'])
                    if FLAGS.enbl_multi_gpu:
                        optimizer = mgw.DistributedOptimizer(optimizer)
                    grads = optimizer.compute_gradients(
                        loss, self.trainable_vars)

            # TF operations & model saver
            if is_train:
                self.sess_train = sess

                with tf.control_dependencies(self.update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)

                self.summary_op = tf.summary.merge_all()
                self.sm_writer = tf.summary.FileWriter(logdir=self.log_path)
                self.log_op = [lrn_rate, loss] + list(metrics.values())
                self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
                self.init_op = tf.variables_initializer(self.vars)
                if FLAGS.enbl_multi_gpu:
                    self.bcast_op = mgw.broadcast_global_variables(0)
                self.saver_train = tf.train.Saver(self.vars)
            else:
                self.sess_eval = sess
                self.eval_op = [loss] + list(metrics.values())
                self.eval_op_names = ['loss'] + list(metrics.keys())
                self.saver_eval = tf.train.Saver(self.vars)
Example #9
  def __build_pruned_train_model(self, path=None, finetune=False): # pylint: disable=too-many-locals
    '''Build a training model from the pruned model.'''
    if path is None:
      path = FLAGS.save_path

    with tf.Graph().as_default():
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(
        mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      self.sess_train = tf.Session(config=config)
      self.saver_train = tf.train.import_meta_graph(path + '.meta')
      self.saver_train.restore(self.sess_train, path)
      logits = tf.get_collection('logits')[0]
      train_images = tf.get_collection('train_images')[0]
      train_labels = tf.get_collection('train_labels')[0]
      mem_images = tf.get_collection('mem_images')[0]
      mem_labels = tf.get_collection('mem_labels')[0]

      self.sess_train.close()

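      # graph surgery: reroute the dataset input tensors so that training reads
      # from the in-memory tensors; a fresh session is created after the edit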
      graph_editor.reroute_ts(train_images, mem_images)
      graph_editor.reroute_ts(train_labels, mem_labels)

      self.sess_train = tf.Session(config=config)
      self.saver_train.restore(self.sess_train, path)

      trainable_vars = self.trainable_vars
      loss, accuracy = self.calc_loss(train_labels, logits, trainable_vars)
      self.accuracy_keys = list(accuracy.keys())

      if FLAGS.enbl_dst:
        logits_dst = self.learner_dst.calc_logits(self.sess_train, train_images)
        loss += self.learner_dst.calc_loss(logits, logits_dst)

      tf.summary.scalar('loss', loss)
      for key in accuracy.keys():
        tf.summary.scalar(key, accuracy[key])
      self.summary_op = tf.summary.merge_all()

      global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, trainable=False)
      self.global_step = global_step
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)

      if finetune and not FLAGS.cp_retrain:
        mom_optimizer = tf.train.AdamOptimizer(FLAGS.cp_lrn_rate_ft)
        self.log_op = [tf.constant(FLAGS.cp_lrn_rate_ft), loss] + list(accuracy.values())
      else:
        mom_optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        self.log_op = [lrn_rate, loss] + list(accuracy.values())

      if FLAGS.enbl_multi_gpu:
        optimizer = mgw.DistributedOptimizer(mom_optimizer)
      else:
        optimizer = mom_optimizer
      grads_origin = optimizer.compute_gradients(loss, trainable_vars)
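      # mask the gradients so that pruned weights remain zero during training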
      grads_pruned, masks = self.__calc_grads_pruned(grads_origin)


      with tf.control_dependencies(self.update_ops):
        self.train_op = optimizer.apply_gradients(grads_pruned, global_step=global_step)

      self.sm_writer.add_graph(tf.get_default_graph())
      self.train_init_op = \
        tf.variables_initializer(mom_optimizer.variables() + [global_step] + masks)

      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
Example #10
    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True  # pylint: disable=no-member
            config.gpu_options.visible_device_list = \
              str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images)
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.conv_krnl_var_names = [
                    var.name for var in self.vars_prnd['conv_krnl']
                ]
                self.global_step = tf.train.get_or_create_global_step()
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'] +
                                                       [self.global_step])

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits_prnd,
                                               self.vars_prnd['trainable'])
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                lrn_rate, self.nb_iters_train = self.setup_lrn_rate(
                    self.global_step)

                # calculate pruning ratios
                pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
                pr_conv_krnl = calc_prune_ratio(self.vars_prnd['conv_krnl'])
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_conv_krnl', pr_conv_krnl)

                # create masks and corresponding operations for channel pruning
                self.masks = []
                mask_updt_ops = []  # update the masks based on the convolutional kernels' values
                for idx, var in enumerate(self.vars_prnd['conv_krnl']):
                    tf.logging.info(
                        'creating a pruning mask for {} of size {}'.format(
                            var.name, var.shape))
                    mask_name = '/'.join(var.name.split('/')[1:]).replace(
                        ':0', '_mask')
                    mask_shape = [1, 1, var.shape[2], 1]  # 1 x 1 x c_in x 1
                    mask = tf.get_variable(mask_name,
                                           initializer=tf.ones(mask_shape),
                                           trainable=False)
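                    # squared L2 norm over each input channel; a zero norm marks a pruned channel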
                    var_norm = tf.reduce_sum(tf.square(var),
                                             axis=[0, 1, 3],
                                             keepdims=True)
                    self.masks += [mask]
                    mask_updt_ops += [
                        mask.assign(tf.cast(var_norm > 0.0, tf.float32))
                    ]
                self.mask_updt_op = tf.group(mask_updt_ops)

                # build operations for channel selection
                self.__build_chn_select_ops()

                # optimizer & gradients
                optimizer_base = tf.train.MomentumOptimizer(
                    lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(
                    loss, self.vars_prnd['trainable'])
                grads_pruned = self.__calc_grads_pruned(grads_origin)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope=self.model_scope_prnd)
                with tf.control_dependencies(update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads_pruned, global_step=self.global_step)

                # TF operations for initializing the channel-pruned model
                init_ops = []
                for var_full, var_prnd in zip(self.vars_full['all'],
                                              self.vars_prnd['all']):
                    init_ops += [var_prnd.assign(var_full)]
                init_ops += [self.global_step.initializer]  # initialize the global step
                init_ops += [tf.variables_initializer(optimizer_base.variables())]
                self.init_op = tf.group(init_ops)

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_conv_krnl] + list(
                metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_krn'] + list(
                metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)