Example 1
    def __build_train(self):  # pylint: disable=too-many-locals
        """Build the training graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            if FLAGS.enbl_multi_gpu:
                config.gpu_options.visible_device_list = str(mgw.local_rank())  # pylint: disable=no-member
            else:
                config.gpu_options.visible_device_list = '0'  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - weight-sparsified model
            with tf.variable_scope(self.model_scope):
                # loss & extra evaluation metrics
                logits = self.forward_train(images)
                self.maskable_var_names = [
                    var.name for var in self.maskable_vars
                ]
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if FLAGS.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = setup_lrn_rate(
                    self.global_step, self.model_name, self.dataset_name)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.trainable_vars)
                pr_maskable = calc_prune_ratio(self.maskable_vars)
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_maskable', pr_maskable)

                # build masks and corresponding operations for weight sparsification
                self.masks, self.prune_op = self.__build_masks()

                # optimizer & gradients
                optimizer_base = tf.train.MomentumOptimizer(
                    lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads_origin = optimizer.compute_gradients(
                    loss, self.trainable_vars)
                grads_pruned = self.__calc_grads_pruned(grads_origin)

            # TF operations & model saver
            self.sess_train = sess
            with tf.control_dependencies(self.update_ops):
                self.train_op = optimizer.apply_gradients(
                    grads_pruned, global_step=self.global_step)
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss, pr_trainable, pr_maskable] + list(
                metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + list(
                metrics.keys())
            self.init_op = tf.variables_initializer(self.vars)
            self.init_opt_op = tf.variables_initializer(
                optimizer_base.variables())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
            self.saver_train = tf.train.Saver(self.vars)
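
Example 1 calls self.__calc_grads_pruned without showing it. The sketch below is an assumption about that helper's shape, not the source's implementation: it presumes self.masks is index-aligned with self.maskable_var_names, and it zeroes the gradients of masked-out weights so that pruned connections stay pruned during fine-tuning.

    def __calc_grads_pruned(self, grads_origin):
        """Sketch (assumed, not from the source): mask gradients of prunable variables."""
        grads_pruned = []
        for grad, var in grads_origin:
            if var.name not in self.maskable_var_names:
                grads_pruned += [(grad, var)]
            else:
                # the element-wise mask keeps weights that were pruned to zero at zero
                idx = self.maskable_var_names.index(var.name)
                grads_pruned += [(grad * self.masks[idx], var)]
        return grads_pruned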
Example 2
    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - channel-pruned model
            with tf.variable_scope(self.model_scope_prnd):
                logits_prnd = self.forward_train(images)
                self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
                self.maskable_var_names = [
                    var.name for var in self.vars_prnd['maskable']
                ]
                self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'])

                # loss & extra evaluation metrics
                loss_bsc, metrics = self.calc_loss(labels, logits_prnd,
                                                   self.vars_prnd['trainable'])
                if not FLAGS.enbl_dst:
                    loss_fnl = loss_bsc
                else:
                    loss_fnl = loss_bsc + self.helper_dst.calc_loss(
                        logits_prnd, logits_dst)
                tf.summary.scalar('loss_bsc', loss_bsc)
                tf.summary.scalar('loss_fnl', loss_fnl)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = setup_lrn_rate(
                    self.global_step, self.model_name, self.dataset_name)

                # overall pruning ratios of trainable & maskable variables
                pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
                pr_maskable = calc_prune_ratio(self.vars_prnd['maskable'])
                tf.summary.scalar('pr_trainable', pr_trainable)
                tf.summary.scalar('pr_maskable', pr_maskable)

                # create masks and corresponding operations for channel pruning
                self.masks = []
                self.mask_deltas = []
                self.mask_init_ops = []
                self.mask_updt_ops = []
                self.prune_ops = []
                for idx, var in enumerate(self.vars_prnd['maskable']):
                    name = '/'.join(var.name.split('/')[1:]).replace(
                        ':0', '_mask')
                    self.masks += [
                        tf.get_variable(name,
                                        initializer=tf.ones(var.shape),
                                        trainable=False)
                    ]
                    name = '/'.join(var.name.split('/')[1:]).replace(
                        ':0', '_mask_delta')
                    self.mask_deltas += [
                        tf.placeholder(tf.float32, shape=var.shape, name=name)
                    ]
                    self.mask_init_ops += [
                        self.masks[idx].assign(tf.zeros(var.shape))
                    ]
                    self.mask_updt_ops += [
                        self.masks[idx].assign_add(self.mask_deltas[idx])
                    ]
                    self.prune_ops += [var.assign(var * self.masks[idx])]

                # build extra losses for regression & discrimination
                self.reg_losses, self.dis_losses, self.idxs_layer_to_block = \
                    self.__build_extra_losses(labels)
                # append the discrimination-aware loss for the last block
                self.dis_losses += [loss_bsc]
                self.nb_layers = len(self.reg_losses)
                self.nb_blocks = len(self.dis_losses)
                for idx, reg_loss in enumerate(self.reg_losses):
                    tf.summary.scalar('reg_loss_%d' % idx, reg_loss)
                for idx, dis_loss in enumerate(self.dis_losses):
                    tf.summary.scalar('dis_loss_%d' % idx, dis_loss)

                # obtain the full list of trainable variables & update operations
                self.vars_all = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_prnd)
                self.trainable_vars_all = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.model_scope_prnd)
                self.update_ops_all = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)

                # TF operations for initializing the channel-pruned model
                init_ops = []
                with tf.control_dependencies(
                    [tf.variables_initializer(self.vars_all)]):
                    for var_full, var_prnd in zip(self.vars_full['all'],
                                                  self.vars_prnd['all']):
                        init_ops += [var_prnd.assign(var_full)]
                self.init_op = tf.group(init_ops)

                # TF operations for layer-wise, block-wise, and whole-network fine-tuning
                self.layer_train_ops, self.layer_init_opt_ops, self.grad_norms = \
                    self.__build_layer_ops()
                self.block_train_ops, self.block_init_opt_ops = \
                    self.__build_block_ops()
                self.train_op, self.init_opt_op = self.__build_network_ops(
                    loss_fnl, lrn_rate)

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss_fnl, pr_trainable, pr_maskable] + \
                list(metrics.values())
            self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_msk'] + \
                list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
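
The mask bookkeeping in Example 2 is meant to be driven from outside the graph: masks start at zero, channels are re-admitted by feeding 0/1 deltas into the placeholders, and prune_ops bakes the masks into the weights. A usage sketch under that reading (the learner handle and select_channels are hypothetical, not from the source):

    # zero out every mask, then re-admit the selected channels layer by layer
    sess = learner.sess_train
    sess.run(learner.mask_init_ops)
    for idx in range(len(learner.mask_updt_ops)):
        mask_delta_np = select_channels(idx)  # hypothetical 0/1 array matching var.shape
        sess.run(learner.mask_updt_ops[idx],
                 feed_dict={learner.mask_deltas[idx]: mask_delta_np})
    sess.run(learner.prune_ops)  # multiply each weight tensor by its mask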
Example 3
    def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
        """Build the training graph."""

        with tf.Graph().as_default():
            # create a TF session for the current graph
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train()
                images, labels = iterator.get_next()
                images.set_shape((FLAGS.batch_size, images.shape[1],
                                  images.shape[2], images.shape[3]))

            # model definition - uniform quantized model - part 1
            with tf.variable_scope(self.model_scope_quan):
                logits_quan = self.forward_train(images)
                tf.contrib.quantize.experimental_create_training_graph(
                    weight_bits=FLAGS.uqtf_weight_bits,
                    activation_bits=FLAGS.uqtf_activation_bits,
                    quant_delay=FLAGS.uqtf_quant_delay,
                    freeze_bn_delay=FLAGS.uqtf_freeze_bn_delay,
                    scope=self.model_scope_quan)
                self.vars_quan = get_vars_by_scope(self.model_scope_quan)
                self.saver_quan_train = tf.train.Saver(self.vars_quan['all'])

            # model definition - distilled model
            if FLAGS.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - full model
            with tf.variable_scope(self.model_scope_full):
                __ = self.forward_train(images)
                self.vars_full = get_vars_by_scope(self.model_scope_full)
                self.saver_full = tf.train.Saver(self.vars_full['all'])
                self.save_path_full = FLAGS.save_path

            # model definition - uniform quantized model - part 2
            with tf.variable_scope(self.model_scope_quan):
                # loss & extra evaluation metrics
                loss_bsc, metrics = self.calc_loss(labels, logits_quan,
                                                   self.vars_quan['trainable'])
                if not FLAGS.enbl_dst:
                    loss_fnl = loss_bsc
                else:
                    loss_fnl = loss_bsc + self.helper_dst.calc_loss(
                        logits_quan, logits_dst)
                tf.summary.scalar('loss_bsc', loss_bsc)
                tf.summary.scalar('loss_fnl', loss_fnl)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # learning rate schedule
                self.global_step = tf.train.get_or_create_global_step()
                lrn_rate, self.nb_iters_train = setup_lrn_rate(
                    self.global_step, self.model_name, self.dataset_name)
                lrn_rate *= FLAGS.uqtf_lrn_rate_dcy

                # decrease the learning rate by a constant factor
                #if self.dataset_name == 'cifar_10':
                #  lrn_rate *= 1e-3
                #elif self.dataset_name == 'ilsvrc_12':
                #  lrn_rate *= 1e-4
                #else:
                #  raise ValueError('unrecognized dataset\'s name: ' + self.dataset_name)

                # obtain the full list of trainable variables & update operations
                self.vars_all = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.model_scope_quan)
                self.trainable_vars_all = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.model_scope_quan)
                self.update_ops_all = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_quan)

                # TF operations for initializing the uniform quantized model
                init_ops = []
                with tf.control_dependencies(
                    [tf.variables_initializer(self.vars_all)]):
                    for var_full, var_quan in zip(self.vars_full['all'],
                                                  self.vars_quan['all']):
                        init_ops += [var_quan.assign(var_full)]
                self.init_op = tf.group(init_ops)

                # TF operations for fine-tuning
                optimizer_base = tf.train.MomentumOptimizer(
                    lrn_rate, FLAGS.momentum)
                if not FLAGS.enbl_multi_gpu:
                    optimizer = optimizer_base
                else:
                    optimizer = mgw.DistributedOptimizer(optimizer_base)
                grads = optimizer.compute_gradients(loss_fnl,
                                                    self.trainable_vars_all)
                with tf.control_dependencies(self.update_ops_all):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)
                self.init_opt_op = tf.variables_initializer(
                    optimizer_base.variables())

            # TF operations for logging & summarizing
            self.sess_train = sess
            self.summary_op = tf.summary.merge_all()
            self.log_op = [lrn_rate, loss_fnl] + list(metrics.values())
            self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
            if FLAGS.enbl_multi_gpu:
                self.bcast_op = mgw.broadcast_global_variables(0)
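
Example 3 only builds the training-time fake-quantization graph. At evaluation or export time the same rewriting is normally applied with the eval variant of the contrib API; a sketch of that step, assuming the learner's forward_eval pass from Example 5 (this code is not in the source):

            # inside the evaluation graph, after the forward pass
            with tf.variable_scope(self.model_scope_quan):
                logits = self.forward_eval(images)
                tf.contrib.quantize.experimental_create_eval_graph(
                    weight_bits=FLAGS.uqtf_weight_bits,
                    activation_bits=FLAGS.uqtf_activation_bits,
                    scope=self.model_scope_quan)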
Example 4
  def __build_pruned_train_model(self, path=None, finetune=False): # pylint: disable=too-many-locals
    ''' Build a training model from the pruned model. '''
    if path is None:
      path = FLAGS.save_path

    with tf.Graph().as_default():
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = \
        str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      self.sess_train = tf.Session(config=config)
      self.saver_train = tf.train.import_meta_graph(path + '.meta')
      self.saver_train.restore(self.sess_train, path)
      logits = tf.get_collection('logits')[0]
      train_images = tf.get_collection('train_images')[0]
      train_labels = tf.get_collection('train_labels')[0]
      mem_images = tf.get_collection('mem_images')[0]
      mem_labels = tf.get_collection('mem_labels')[0]

      # close the session before rewiring the graph; it is recreated and restored below
      self.sess_train.close()

      graph_editor.reroute_ts(train_images, mem_images)
      graph_editor.reroute_ts(train_labels, mem_labels)

      self.sess_train = tf.Session(config=config)
      self.saver_train.restore(self.sess_train, path)

      trainable_vars = self.trainable_vars
      loss, accuracy = self.calc_loss(train_labels, logits, trainable_vars)
      self.accuracy_keys = list(accuracy.keys())

      if FLAGS.enbl_dst:
        logits_dst = self.learner_dst.calc_logits(self.sess_train, train_images)
        loss += self.learner_dst.calc_loss(logits, logits_dst)

      tf.summary.scalar('loss', loss)
      for key in accuracy.keys():
        tf.summary.scalar(key, accuracy[key])
      self.summary_op = tf.summary.merge_all()

      global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, trainable=False)
      self.global_step = global_step
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)

      if finetune and not FLAGS.cp_retrain:
        mom_optimizer = tf.train.AdamOptimizer(FLAGS.cp_lrn_rate_ft)
        self.log_op = [tf.constant(FLAGS.cp_lrn_rate_ft), loss, list(accuracy.values())]
      else:
        mom_optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        self.log_op = [lrn_rate, loss, list(accuracy.values())]

      if FLAGS.enbl_multi_gpu:
        optimizer = mgw.DistributedOptimizer(mom_optimizer)
      else:
        optimizer = mom_optimizer
      grads_origin = optimizer.compute_gradients(loss, trainable_vars)
      grads_pruned, masks = self.__calc_grads_pruned(grads_origin)

      with tf.control_dependencies(self.update_ops):
        self.train_op = optimizer.apply_gradients(grads_pruned, global_step=global_step)

      self.sm_writer.add_graph(tf.get_default_graph())
      self.train_init_op = \
        tf.variables_initializer(mom_optimizer.variables() + [global_step] + masks)

      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
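
The reroute_ts calls above decide which tensors feed the restored graph. A toy illustration of the contrib graph_editor semantics this relies on (my reading of the API, not code from the source): reroute_ts(ts0, ts1) makes every consumer of ts1 read from ts0 instead.

  from tensorflow.contrib import graph_editor

  a = tf.constant(1.0)
  b = tf.constant(2.0)
  c = a + b                                         # the add op consumes a and b
  graph_editor.reroute_ts([tf.constant(5.0)], [a])  # c now reads 5.0 instead of a
  with tf.Session() as sess:
    print(sess.run(c))                              # prints 7.0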
Example 5
    def __build(self, is_train):  # pylint: disable=too-many-locals
        """Build the training / evaluation graph.

        Args:
        * is_train: whether to create the training graph
        """

        with tf.Graph().as_default():
            # TensorFlow session
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
            sess = tf.Session(config=config)

            # data input pipeline
            with tf.variable_scope(self.data_scope):
                iterator = self.build_dataset_train() if is_train \
                    else self.build_dataset_eval()
                images, labels = iterator.get_next()
                tf.add_to_collection('images_final', images)

            # model definition - distilled model
            if self.enbl_dst:
                logits_dst = self.helper_dst.calc_logits(sess, images)

            # model definition - primary model
            with tf.variable_scope(self.model_scope):
                # forward pass
                logits = self.forward_train(images) if is_train \
                    else self.forward_eval(images)
                # out = self.forward_eval(images)
                tf.add_to_collection('logits_final', logits)

                # loss & extra evaluation metrics
                loss, metrics = self.calc_loss(labels, logits,
                                               self.trainable_vars)
                if self.enbl_dst:
                    loss += self.helper_dst.calc_loss(logits, logits_dst)
                tf.summary.scalar('loss', loss)
                for key, value in metrics.items():
                    tf.summary.scalar(key, value)

                # optimizer & gradients
                if is_train:
                    self.global_step = tf.train.get_or_create_global_step()
                    lrn_rate, self.nb_iters_train = setup_lrn_rate(
                        self.global_step, self.model_name, self.dataset_name)
                    optimizer = tf.train.MomentumOptimizer(
                        lrn_rate, FLAGS.momentum)
                    # optimizer = tf.train.AdamOptimizer(lrn_rate)
                    if FLAGS.enbl_multi_gpu:
                        optimizer = mgw.DistributedOptimizer(optimizer)
                    grads = optimizer.compute_gradients(
                        loss, self.trainable_vars)

            # TF operations & model saver
            if is_train:
                self.sess_train = sess
                with tf.control_dependencies(self.update_ops):
                    self.train_op = optimizer.apply_gradients(
                        grads, global_step=self.global_step)
                self.summary_op = tf.summary.merge_all()
                self.log_op = [lrn_rate, loss] + list(metrics.values())
                self.log_op_names = ['lr', 'loss'] + list(metrics.keys())
                self.init_op = tf.variables_initializer(self.vars)
                if FLAGS.enbl_multi_gpu:
                    self.bcast_op = mgw.broadcast_global_variables(0)
                self.saver_train = tf.train.Saver(self.vars)
            else:
                self.sess_eval = sess

                self.factory_op = [tf.cast(logits, tf.uint8)]
                self.time_op = [logits]
                self.out_op = [
                    tf.cast(images, tf.uint8),
                    tf.cast(logits, tf.uint8),
                    tf.cast(labels, tf.uint8)
                ]
                self.eval_op = [loss] + list(metrics.values())
                self.eval_op_names = ['loss'] + list(metrics.keys())
                self.saver_eval = tf.train.Saver(self.vars)
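
A short driver for the evaluation branch above (a usage sketch; the learner handle and ckpt_path are hypothetical):

    # assumes the graph was built with is_train=False
    learner.saver_eval.restore(learner.sess_eval, ckpt_path)
    results = learner.sess_eval.run(learner.eval_op)
    for name, value in zip(learner.eval_op_names, results):
        print('%s: %.4e' % (name, value))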
Example 6
  def __build_train(self):  # pylint: disable=too-many-locals,too-many-statements
    """Build the training graph."""

    with tf.Graph().as_default():
      # create a TF session for the current graph
      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True  # pylint: disable=no-member
      config.gpu_options.visible_device_list = \
        str(mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)  # pylint: disable=no-member
      sess = tf.Session(config=config)

      # data input pipeline
      with tf.variable_scope(self.data_scope):
        iterator = self.build_dataset_train()
        images, labels = iterator.get_next()

      # model definition - distilled model
      if FLAGS.enbl_dst:
        logits_dst = self.helper_dst.calc_logits(sess, images)

      # model definition - channel-pruned model
      with tf.variable_scope(self.model_scope_prnd):
        logits_prnd = self.forward_train(images)
        self.vars_prnd = get_vars_by_scope(self.model_scope_prnd)
        self.global_step = tf.train.get_or_create_global_step()
        self.saver_prnd_train = tf.train.Saver(self.vars_prnd['all'] + [self.global_step])

        # loss & extra evaluation metrics
        loss, metrics = self.calc_loss(labels, logits_prnd, self.vars_prnd['trainable'])
        if FLAGS.enbl_dst:
          loss += self.helper_dst.calc_loss(logits_prnd, logits_dst)
        tf.summary.scalar('loss', loss)
        for key, value in metrics.items():
          tf.summary.scalar(key, value)

        # learning rate schedule
        # lrn_rate, self.nb_iters_train = self.setup_lrn_rate(self.global_step)
        lrn_rate, self.nb_iters_train = setup_lrn_rate(
          self.global_step, self.model_name, self.dataset_name)

        # calculate pruning ratios
        pr_trainable = calc_prune_ratio(self.vars_prnd['trainable'])
        pr_conv_krnl = calc_prune_ratio(self.vars_prnd['conv_krnl'])
        tf.summary.scalar('pr_trainable', pr_trainable)
        tf.summary.scalar('pr_conv_krnl', pr_conv_krnl)

        # create masks and corresponding operations for channel pruning
        self.masks = []
        for idx, var in enumerate(self.vars_prnd['conv_krnl']):
          tf.logging.info('creating a pruning mask for {} of size {}'.format(var.name, var.shape))
          mask_name = '/'.join(var.name.split('/')[1:]).replace(':0', '_mask')
          var_norm = tf.reduce_sum(tf.square(var), axis=[0, 1, 3], keepdims=True)
          mask_init = tf.cast(var_norm > 0.0, tf.float32)
          mask = tf.get_variable(mask_name, initializer=mask_init, trainable=False)
          self.masks += [mask]

        # optimizer & gradients
        optimizer_base = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        if not FLAGS.enbl_multi_gpu:
          optimizer = optimizer_base
        else:
          optimizer = mgw.DistributedOptimizer(optimizer_base)
        grads_origin = optimizer.compute_gradients(loss, self.vars_prnd['trainable'])
        grads_pruned = self.__calc_grads_pruned(grads_origin)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.model_scope_prnd)
        with tf.control_dependencies(update_ops):
          self.train_op = optimizer.apply_gradients(grads_pruned, global_step=self.global_step)

      # TF operations for logging & summarizing
      self.sess_train = sess
      self.summary_op = tf.summary.merge_all()
      self.init_op = tf.group(
        tf.variables_initializer([self.global_step] + self.masks + optimizer_base.variables()))
      self.log_op = [lrn_rate, loss, pr_trainable, pr_conv_krnl] + list(metrics.values())
      self.log_op_names = ['lr', 'loss', 'pr_trn', 'pr_krn'] + list(metrics.keys())
      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
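
The mask initialization in Example 6 sums the squared kernel over axes [0, 1, 3]; for kernels laid out as [kh, kw, c_in, c_out] that leaves one indicator per input channel, so channels already zeroed in the restored checkpoint stay pruned. A NumPy check of that reduction (toy data, assumed layout):

  import numpy as np

  krnl = np.random.randn(3, 3, 4, 8).astype(np.float32)
  krnl[:, :, 1, :] = 0.0                        # input channel 1 is fully pruned
  norm = np.square(krnl).sum(axis=(0, 1, 3), keepdims=True)
  mask = (norm > 0.0).astype(np.float32)        # shape (1, 1, 4, 1)
  print(mask.ravel())                           # [1. 0. 1. 1.]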