Example #1
    def testWeightSpecificSparsity(self):
        param_list = [
            "begin_pruning_step=1", "pruning_frequency=1",
            "end_pruning_step=100", "target_sparsity=0.5",
            "weight_sparsity_map=[layer2/weights:0.75]", "threshold_decay=0.0"
        ]
        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        with variable_scope.variable_scope("layer1"):
            w1 = variables.Variable(math_ops.linspace(1.0, 100.0, 100),
                                    name="weights")
            _ = pruning.apply_mask(w1)
        with variable_scope.variable_scope("layer2"):
            w2 = variables.Variable(math_ops.linspace(1.0, 100.0, 100),
                                    name="weights")
            _ = pruning.apply_mask(w2)

        p = pruning.Pruning(pruning_hparams)
        mask_update_op = p.conditional_mask_update_op()
        increment_global_step = state_ops.assign_add(self.global_step, 1)

        with self.test_session() as session:
            variables.global_variables_initializer().run()
            for _ in range(110):
                session.run(mask_update_op)
                session.run(increment_global_step)

            self.assertAllEqual(session.run(pruning.get_weight_sparsity()),
                                [0.5, 0.75])
Example #2
 def testConditionalMaskUpdate(self):
     param_list = [
         "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
     ]
     test_spec = ",".join(param_list)
     pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
     weights = variables.Variable(math_ops.linspace(1.0, 100.0, 100),
                                  name="weights")
     masked_weights = pruning.apply_mask(weights)
     sparsity = variables.Variable(0.00, name="sparsity")
     # Set up pruning
     p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
     p._spec.threshold_decay = 0.0
     mask_update_op = p.conditional_mask_update_op()
     sparsity_val = math_ops.linspace(0.0, 0.9, 10)
     increment_global_step = state_ops.assign_add(self.global_step, 1)
     non_zero_count = []
     with self.test_session() as session:
         variables.global_variables_initializer().run()
         for i in range(10):
             session.run(state_ops.assign(sparsity, sparsity_val[i]))
             session.run(mask_update_op)
             session.run(increment_global_step)
             non_zero_count.append(np.count_nonzero(masked_weights.eval()))
     # Weights pruned at steps 0, 2, 4, and 6
     expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
     self.assertAllEqual(expected_non_zero_count, non_zero_count)
Example #3
    def testPerLayerBlockSparsity(self):
        param_list = [
            "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
            "block_pooling_function=AVG", "threshold_decay=0.0"
        ]

        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        with variable_scope.variable_scope("layer1"):
            w1 = constant_op.constant([[-0.1, 0.1], [-0.2, 0.2]],
                                      name="weights")
            pruning.apply_mask(w1)

        with variable_scope.variable_scope("layer2"):
            w2 = constant_op.constant(
                [[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]], name="weights")
            pruning.apply_mask(w2)

        sparsity = variables.VariableV1(0.5, name="sparsity")

        p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
        mask_update_op = p.mask_update_op()
        with self.cached_session() as session:
            variables.global_variables_initializer().run()
            session.run(mask_update_op)
            mask1_eval = session.run(pruning.get_masks()[0])
            mask2_eval = session.run(pruning.get_masks()[1])

            self.assertAllEqual(session.run(pruning.get_weight_sparsity()),
                                [0.5, 0.5])

            self.assertAllEqual(mask1_eval, [[0.0, 0.0], [1., 1.]])
            self.assertAllEqual(mask2_eval, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
Example #4
    def _setup_graph(self):
        '''Build pruning hparams from self.param_dict and create the conditional mask update op.'''
        default_dict = {
            'name': 'model_pruning',
            'begin_pruning_step': 0,
            'end_pruning_step': 34400,
            'target_sparsity': 0.31,
            'pruning_frequency': 344,
            'sparsity_function_begin_step': 0,
            'sparsity_function_end_step': 34400,
            'sparsity_function_exponent': 2,
        }
        for k, v in self.param_dict.items():
            if k in default_dict:
                default_dict[k] = v

        param_list = ['{}={}'.format(k, v) for k, v in default_dict.items()]
        # param_list = [
        #         "name=cifar10_pruning",
        #         "begin_pruning_step=1000",
        #         "end_pruning_step=20000",
        #         "target_sparsity=0.9",
        #         "sparsity_function_begin_step=1000",
        #         "sparsity_function_end_step=20000"
        # ]

        PRUNE_HPARAMS = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(PRUNE_HPARAMS)
        self.p = pruning.Pruning(pruning_hparams,
                                 global_step=get_global_step_var())
        self.p.add_pruning_summaries()
        self.mask_update_op = self.p.conditional_mask_update_op()
Example #5
 def testConditionalMaskUpdate(self):
   param_list = [
       "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
   ]
   test_spec = ",".join(param_list)
   pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
   weights = variables.Variable(
       math_ops.linspace(1.0, 100.0, 100), name="weights")
   masked_weights = pruning.apply_mask(weights)
   sparsity = variables.Variable(0.00, name="sparsity")
   # Set up pruning
   p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
   p._spec.threshold_decay = 0.0
   mask_update_op = p.conditional_mask_update_op()
   sparsity_val = math_ops.linspace(0.0, 0.9, 10)
   increment_global_step = state_ops.assign_add(self.global_step, 1)
   non_zero_count = []
   with self.test_session() as session:
     variables.global_variables_initializer().run()
     for i in range(10):
       session.run(state_ops.assign(sparsity, sparsity_val[i]))
       session.run(mask_update_op)
       session.run(increment_global_step)
       non_zero_count.append(np.count_nonzero(masked_weights.eval()))
    # Weights pruned at steps 0, 2, 4, and 6
   expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
   self.assertAllEqual(expected_non_zero_count, non_zero_count)
Example #6
  def testWeightSpecificSparsity(self):
    param_list = [
        "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100",
        "target_sparsity=0.5", "weight_sparsity_map=[layer2/weights:0.75]",
        "threshold_decay=0.0"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    with variable_scope.variable_scope("layer1"):
      w1 = variables.Variable(
          math_ops.linspace(1.0, 100.0, 100), name="weights")
      _ = pruning.apply_mask(w1)
    with variable_scope.variable_scope("layer2"):
      w2 = variables.Variable(
          math_ops.linspace(1.0, 100.0, 100), name="weights")
      _ = pruning.apply_mask(w2)

    p = pruning.Pruning(pruning_hparams)
    mask_update_op = p.conditional_mask_update_op()
    increment_global_step = state_ops.assign_add(self.global_step, 1)

    with self.cached_session() as session:
      variables.global_variables_initializer().run()
      for _ in range(110):
        session.run(mask_update_op)
        session.run(increment_global_step)

      self.assertAllEqual(
          session.run(pruning.get_weight_sparsity()), [0.5, 0.75])
Example #7
 def __init__(self,
              input_size,
              output_size,
              model_path: str,
              momentum=0.9,
              reg_str=0.0005,
              scope='ConvNet',
              pruning_start=int(10e4),
              pruning_end=int(10e5),
              pruning_freq=int(10),
              sparsity_start=0,
              sparsity_end=int(10e5),
              target_sparsity=0.0,
              dropout=0.5,
              initial_sparsity=0,
              wd=0.0):
     super(ConvNet, self).__init__(input_size=input_size,
                                   output_size=output_size,
                                   model_path=model_path)
     self.scope = scope
     self.momentum = momentum
     self.reg_str = reg_str
     self.dropout = dropout
     self.logger = get_logger(scope)
     self.wd = wd
     self.logger.info("creating graph...")
     with self.graph.as_default():
         self.global_step = tf.Variable(0, trainable=False)
         self._build_placeholders()
         self.logits = self._build_model()
         self.weights_matrices = pruning.get_masked_weights()
         self.sparsity = pruning.get_weight_sparsity()
         self.loss = self._loss()
         self.train_op = self._optimizer()
         self._create_metrics()
         self.saver = tf.train.Saver(var_list=tf.global_variables())
         self.hparams = pruning.get_pruning_hparams()\
             .parse('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                    ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                    'pruning_frequency={},initial_sparsity={},'
                    ' sparsity_function_exponent={}'.format(scope,
                                                            pruning_start,
                                                            pruning_end,
                                                            target_sparsity,
                                                            sparsity_start,
                                                            sparsity_end,
                                                            pruning_freq,
                                                            initial_sparsity,
                                                            3))
          # Note that the global step plays an important part in the pruning
          # mechanism: the higher the global step, the closer the sparsity gets
          # to its end value (sparsity_end).
         self.pruning_obj = pruning.Pruning(self.hparams,
                                            global_step=self.global_step)
         self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
          # The pruning object defines the pruning mechanism; running mask_update_op prunes the model.
          # Pruning takes place at each training epoch, with the objective of reaching the sparsity end HP.
         self.init_variables(
             tf.global_variables())  # initialize variables in graph
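The _build_model call above is elided in this excerpt. Purely as an illustration (the helper name, layer shapes, and initializers are hypothetical), a pruned layer typically routes its weight matrix through pruning.apply_mask so that get_masked_weights(), get_weight_sparsity(), and the mask update op can find it:

    # Sketch only: assumes `import tensorflow as tf` and
    # `from tensorflow.contrib.model_pruning.python import pruning` as in the examples above.
    def _dense_with_pruning(self, inputs, out_dim, name):
        in_dim = inputs.get_shape().as_list()[-1]
        with tf.variable_scope(name):
            w = tf.get_variable("weights", [in_dim, out_dim],
                                initializer=tf.glorot_uniform_initializer())
            b = tf.get_variable("bias", [out_dim],
                                initializer=tf.zeros_initializer())
            masked_w = pruning.apply_mask(w)  # registers w, its mask, and its threshold
            return tf.matmul(inputs, masked_w) + b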
Example #8
  def _prune_model(self, session):
    pruning_hparams = pruning.get_pruning_hparams().parse(self.pruning_spec)
    p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity)
    self.mask_update_op = p.conditional_mask_update_op()

    variables.global_variables_initializer().run()
    for _ in range(20):
      session.run(self.mask_update_op)
      session.run(self.increment_global_step)
Example #9
 def setUp(self):
     super(PruningHParamsTest, self).setUp()
     # Add global step variable to the graph
     self.global_step = training_util.get_or_create_global_step()
     # Add sparsity
     self.sparsity = variables.Variable(0.5, name="sparsity")
     # Parse hparams
     self.pruning_hparams = pruning.get_pruning_hparams().parse(
         self.TEST_HPARAMS)
Example #10
 def setUp(self):
   super(PruningHParamsTest, self).setUp()
   # Add global step variable to the graph
   self.global_step = training_util.get_or_create_global_step()
   # Add sparsity
   self.sparsity = variables.Variable(0.5, name="sparsity")
   # Parse hparams
   self.pruning_hparams = pruning.get_pruning_hparams().parse(
       self.TEST_HPARAMS)
Example #11
    def _prune_model(self, session):
        pruning_hparams = pruning.get_pruning_hparams().parse(
            self.pruning_spec)
        p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity)
        self.mask_update_op = p.conditional_mask_update_op()

        variables.global_variables_initializer().run()
        for _ in range(20):
            session.run(self.mask_update_op)
            session.run(self.increment_global_step)
Example #12
    def __init__(self, model, data_handle, hyperparams):
        self.model = model
        self.data_handle = data_handle
        self.hyperparams = hyperparams

        # get defined tensor
        self.X = self.model.X
        self.Y = self.model.Y
        self.result = self.model.result
        self.train  = self.model.Utils.is_train
        self.update = self.model.Utils.tensor_updated
        self.learning_rate = tf.placeholder(tf.float32)
        self.global_step = tf.Variable(0, dtype = tf.int32, trainable = False)
        self.weights_decay = self.hyperparams['weights_decay']
        self.global_step_update = tf.assign_add(self.global_step, tf.constant(2, dtype = tf.int32))

        # optimizer
        self.cross_entropy     = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = self.Y, logits = self.result))
        self.l2_loss           = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
        self.loss              = self.l2_loss * self.weights_decay + self.cross_entropy
        # train_step        = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(loss)
        self.train_step        = tf.train.MomentumOptimizer(self.learning_rate, 0.9, use_nesterov = True).minimize(self.loss)
        self.top1              = tf.equal(tf.argmax(self.result, 1), tf.argmax(self.Y, 1))
        self.top1_acc          = tf.reduce_mean(tf.cast(self.top1, "float"))
        self.top5              = tf.nn.in_top_k(predictions = self.result, targets = tf.argmax(self.Y, 1), k = 5) 
        self.top5_acc          = tf.reduce_mean(tf.cast(self.top5, "float"))


        # prune
        if self.hyperparams['enable_prune']:
            pruning_hparams = pruning.get_pruning_hparams()
            pruning_hparams.begin_pruning_step = self.hyperparams['begin_pruning_step']
            pruning_hparams.end_pruning_step   = self.hyperparams['end_pruning_step']
            pruning_hparams.pruning_frequency  = self.hyperparams['pruning_frequency']
            pruning_hparams.target_sparsity    = self.hyperparams['target_sparsity']
            p = pruning.Pruning(pruning_hparams, global_step = self.global_step)
            self.prune_op = p.conditional_mask_update_op()

        # log
        log_prefix = "log" + "_quant_{}".format(self.hyperparams['quant_bits']) + "_prune_{}".format(str(self.hyperparams["enable_prune"])) + "/"
        if not os.path.exists(log_prefix):
            os.mkdir(log_prefix)
        self.fd = open(log_prefix + self.hyperparams['model_name'], "a")
        print("model_name = {}, quant_bits = {}, enable_prune = {}".format(self.hyperparams['model_name'], self.hyperparams['quant_bits'], self.hyperparams['enable_prune']), file = self.fd)
        print(time.asctime(time.localtime(time.time())) + "   train started", file = self.fd)


        # init_variable
        # config = tf.ConfigProto()
        # config.gpu_options.allow_growth = True
        # config.gpu_options.per_process_gpu_memory_fraction = 0.6
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
Example #13
 def __init__(self,
              actor_input_dim,
              actor_output_dim,
              model_path,
              redundancy=None,
              last_measure=10e4,
              tau=0.01):
     super(StudentActor, self).__init__(model_path=model_path)
     self.actor_input_dim = (None, actor_input_dim)
     self.actor_output_dim = (None, actor_output_dim)
     self.tau = tau
     self.redundancy = redundancy
     self.last_measure = last_measure
     with self.graph.as_default():
         self.actor_global_step = tf.Variable(0, trainable=False)
         self._build_placeholders()
         self.actor_logits = self._build_actor()
         # self.gumbel_dist = self._build_gumbel(self.actor_logits)
         self.loss = self._build_loss()
         self.actor_parameters = tf.get_collection(
             tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')
         self.actor_pruned_weight_matrices = pruning.get_masked_weights()
         self.actor_train_op = self._build_actor_train_op()
         self.actor_saver = tf.train.Saver(var_list=self.actor_parameters,
                                           max_to_keep=100)
         self.init_variables(tf.global_variables())
         self.sparsity = pruning.get_weight_sparsity()
         self.hparams = pruning.get_pruning_hparams() \
             .parse('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                    ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                    'pruning_frequency={},initial_sparsity={},'
                    ' sparsity_function_exponent={}'.format('Actor',
                                                            cfg.pruning_start,
                                                            cfg.pruning_end,
                                                            cfg.target_sparsity,
                                                            cfg.sparsity_start,
                                                            cfg.sparsity_end,
                                                            cfg.pruning_freq,
                                                            cfg.initial_sparsity,
                                                            3))
          # Note that the global step plays an important part in the pruning
          # mechanism: the higher the global step, the closer the sparsity gets
          # to its end value (sparsity_end).
         self.pruning_obj = pruning.Pruning(
             self.hparams, global_step=self.actor_global_step)
         self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
          # The pruning object defines the pruning mechanism; running mask_update_op prunes the model.
          # Pruning takes place at each training epoch, with the objective of reaching the sparsity end HP.
         self.init_variables(
             tf.global_variables())  # initialize variables in graph
Example #14
    def _blockMasking(self, hparams, weights, expected_mask):

        threshold = variables.Variable(0.0, name="threshold")
        sparsity = variables.Variable(0.5, name="sparsity")
        test_spec = ",".join(hparams)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        # Set up pruning
        p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
        with self.test_session():
            variables.global_variables_initializer().run()
            _, new_mask = p._maybe_update_block_mask(weights, threshold)
            # Check if the mask is the same size as the weights
            self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
            mask_val = new_mask.eval()
            self.assertAllEqual(mask_val, expected_mask)
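A hypothetical invocation of the helper above, mirroring the 1x2 block case from Example #3 (the block_height/block_width hparam names are assumed from the standard model_pruning hparams):

    def testBlockMaskingAvg(self):
        # Sketch only: exercises _blockMasking with AVG-pooled 1x2 blocks.
        param_list = ["block_height=1", "block_width=2",
                      "threshold_decay=0", "block_pooling_function=AVG"]
        weights = constant_op.constant([[0.1, 0.1, 0.3, 0.3],
                                        [0.2, 0.2, 0.4, 0.4]], name="weights")
        # With the 0.5 sparsity set inside _blockMasking, the two lowest 1x2
        # block averages (0.1 and 0.2) are pruned, as in Example #3.
        expected_mask = [[0., 0., 1., 1.], [0., 0., 1., 1.]]
        self._blockMasking(param_list, weights, expected_mask)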
Example #15
def set_prune_params(s):
    # Get, Print, and Edit Pruning Hyperparameters
    pruning_hparams = pruning.get_pruning_hparams()
    print("Pruning Hyperparameters:", pruning_hparams)

    # Change hyperparameters to meet our needs
    pruning_hparams.begin_pruning_step = 0
    pruning_hparams.end_pruning_step = 250
    pruning_hparams.pruning_frequency = 1
    pruning_hparams.sparsity_function_end_step = 250
    pruning_hparams.target_sparsity = s

    # Create a pruning object using the pruning specification; an explicitly
    # passed-in sparsity tensor seems to take priority over the hparam
    p = pruning.Pruning(pruning_hparams, global_step=global_step)
    prune_op = p.conditional_mask_update_op()
    return prune_op
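A minimal usage sketch for the helper above (it relies on the module-level global_step; the train_op and step count below are placeholders):

    # Sketch only: run the regular optimization step, then let the conditional
    # op decide whether the current global step is inside the pruning window.
    prune_op = set_prune_params(0.8)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(300):
            sess.run(train_op)   # placeholder; assumed to also increment global_step
            sess.run(prune_op)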
Example #16
  def _blockMasking(self, hparams, weights, expected_mask):

    threshold = variables.Variable(0.0, name="threshold")
    sparsity = variables.Variable(0.51, name="sparsity")
    test_spec = ",".join(hparams)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    # Set up pruning
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    with self.test_session():
      variables.global_variables_initializer().run()
      _, new_mask = p._maybe_update_block_mask(weights, threshold)
      # Check if the mask is the same size as the weights
      self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
      mask_val = new_mask.eval()
      self.assertAllEqual(mask_val, expected_mask)
Example #17
def pruning_params(global_step, begin_step=0, end_step=-1, pruning_freq=10,
                   sparsity_function=2000, target_sparsity=.50, sparsity_exponent=1.0):
    """
    Creates the pruning op
    :param global_step: the global step, needed for pruning
    :param begin_step: the global step at which to begin pruning
    :param end_step: the global step at which to end pruning
    :param pruning_freq: the frequency of global step for when to prune
    :param sparsity_function: the global step used as the end point for the gradual sparsity function
    :param target_sparsity: the target sparsity
    :param sparsity_exponent: the exponent for the sparsity function
    :return: Pruning op
    """
    pruning_hparams = pruning.get_pruning_hparams()
    pruning_hparams.begin_pruning_step = begin_step
    pruning_hparams.end_pruning_step = end_step
    pruning_hparams.pruning_frequency = pruning_freq
    pruning_hparams.sparsity_function_end_step = sparsity_function
    pruning_hparams.target_sparsity = target_sparsity
    pruning_hparams.sparsity_function_exponent = sparsity_exponent
    p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=target_sparsity)
    p_op = p.conditional_mask_update_op()
    p.add_pruning_summaries()
    return p_op
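The returned op can also be folded directly into the training op, as Example #18 does; a short sketch under the same assumptions (train_op and global_step defined elsewhere):

    # Sketch only: group the mask update with the optimizer step so a single
    # session.run performs both; train_op and global_step are hypothetical names.
    p_op = pruning_params(global_step, begin_step=0, end_step=-1,
                          target_sparsity=0.75)
    train_and_prune_op = tf.group(train_op, p_op)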
Example #18
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, prune_config_flag):
  """Creates an optimizer training op."""
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

  # Implements linear decay of the learning rate.
  learning_rate = tf.train.polynomial_decay(
      learning_rate,
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=1.0,
      cycle=False)

  # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
  # learning rate will be `global_step/num_warmup_steps * init_lr`.
  if num_warmup_steps:
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = init_lr * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = (
        (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

  # It is recommended that you use this optimizer for fine tuning, since this
  # is how the model was trained (note that the Adam m/v variables are NOT
  # loaded from init_checkpoint.)
  optimizer = AdamWeightDecayOptimizer(
      learning_rate=learning_rate,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

  if use_tpu:
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

  # memory_saving_gradients.DEBUG_LOGGING = True
  tvars = tf.trainable_variables()
  if os.getenv('DISABLE_GRAD_CHECKPOINT'):
    grads = tf.gradients(loss, tvars)
  else:
    grads = memory_saving_gradients.gradients(loss, tvars, checkpoints='memory')

  # This is how the model was pre-trained.
  (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

  train_op = optimizer.apply_gradients(
      zip(grads, tvars), global_step=global_step)

  # Pruning mask update ops
  if prune_config_flag:
    tf.logging.info(f'Pruning with configs {prune_config_flag}')
    prune_config =  get_pruning_hparams().parse(prune_config_flag)
    prune = Pruning(prune_config, global_step=global_step)
    mask_update_op = prune.conditional_mask_update_op()
    prune.add_pruning_summaries()
  else:
    tf.logging.info('No pruning config provided, skipping pruning')
    mask_update_op = tf.no_op()

  # Normally the global step update is done inside of `apply_gradients`.
  # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
  # a different optimizer, you should probably take this line out.
  new_global_step = global_step + 1
  train_op = tf.group(train_op, mask_update_op, [global_step.assign(new_global_step)])
  return train_op
Example #19
def train():
    is_training = True
    # data pipeline
    imgs, true_boxes = gen_data_batch(re.sub(r'examples/', '', cfg.data_path),
                                      cfg.batch_size * cfg.train.num_gpus)
    imgs_split = tf.split(imgs, cfg.train.num_gpus)
    true_boxes_split = tf.split(true_boxes, cfg.train.num_gpus)

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0.),
                                  trainable=False)
    lr = tf.train.piecewise_constant(global_step, cfg.train.lr_steps,
                                     cfg.train.learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)

    # Calculate the gradients for each model tower.
    tower_grads = []
    summaries_buf = []
    summaries = set()
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(cfg.train.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (cfg.train.tower, i)) as scope:
                    model = PDetNet(imgs_split[i], true_boxes_split[i],
                                    is_training)
                    loss = model.compute_loss()
                    tf.get_variable_scope().reuse_variables()
                    grads_and_vars = optimizer.compute_gradients(loss)
                    #
                    gradients_norm = summaries_gradients_norm(grads_and_vars)
                    gradients_hist = summaries_gradients_hist(grads_and_vars)
                    #summaries_buf.append(gradients_norm)
                    summaries_buf.append(gradients_hist)
                    ##sum_set = set()
                    ##sum_set.add(tf.summary.scalar("loss", loss))
                    ##summaries_buf.append(sum_set)
                    summaries_buf.append({tf.summary.scalar("loss", loss)})
                    #
                    tower_grads.append(grads_and_vars)
                    if i == 0:
                        current_loss = loss
                        update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                        vars_det = tf.get_collection(
                            tf.GraphKeys.TRAINABLE_VARIABLES, scope="PDetNet")
    grads = average_gradients(tower_grads)
    with tf.control_dependencies(update_op):
        #train_op = optimizer.minimize(loss, global_step=global_step, var_list=vars_det)
        apply_gradient_op = optimizer.apply_gradients(grads,
                                                      global_step=global_step)
        train_op = tf.group(apply_gradient_op, *update_op)

    # GPU config
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    ## pruning added by lzlu
    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(
        cfg.prune.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_summaries = pruning_obj.add_pruning_summaries()

    summaries |= pruning_summaries
    for summ in summaries_buf:
        summaries |= summ

    summaries.add(tf.summary.scalar('lr', lr))

    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    if cfg.summary.summary_allowed:
        summary_writer = tf.summary.FileWriter(
            logdir=cfg.summary.logs_path,
            graph=sess.graph,
            flush_secs=cfg.summary.summary_secs)

    # Create a saver
    saver = tf.train.Saver()
    ckpt_dir = re.sub(r'examples/', '', cfg.ckpt_path_608)

    if cfg.train.fine_tune == 0:
        # init
        sess.run(tf.global_variables_initializer())
    else:
        saver.restore(sess, cfg.train.rstd_path)

    # running
    for i in range(0, cfg.train.max_batches):
        _, loss_, gstep, sval, _ = sess.run(
            [train_op, current_loss, global_step, summary_op, mask_update_op])
        if (i % 100 == 0):
            print(i, ': ', loss_)
        if i % 1000 == 0 and i < 10000:
            saver.save(sess,
                       ckpt_dir + str(i) + '_plate.ckpt',
                       global_step=global_step,
                       write_meta_graph=False)
        if i % 10000 == 0:
            saver.save(sess,
                       ckpt_dir + str(i) + '_plate.ckpt',
                       global_step=global_step,
                       write_meta_graph=False)
        if cfg.summary.summary_allowed and gstep % cfg.summary.summ_steps == 0:
            summary_writer.add_summary(sval, global_step=gstep)
Example #20
def train_with_pruning():
    tf.compat.v1.reset_default_graph()

    # Inference
    network = Network(NUM_CLASSES)
    inputs = tf.compat.v1.placeholder(tf.float32, [None, INPUT_SIZE, INPUT_SIZE, INPUT_CHANNEL], 'inputs')
    logits = network.pruning_inference(inputs)

    # loss & accuracy
    labels = tf.compat.v1.placeholder(tf.int64, [None, ], 'labels')
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    prediction = tf.argmax(tf.nn.softmax(logits), axis=1)
    acc = tf.reduce_mean(tf.cast(tf.equal(prediction, labels), dtype=tf.float32))

    # Create pruning operator
    global_step = tf.train.get_or_create_global_step()
    pruning_hparams = pruning.get_pruning_hparams()
    pruning_hparams.sparsity_function_end_step = 1000
    p = pruning.Pruning(pruning_hparams, global_step=global_step)
    mask_update_op = p.conditional_mask_update_op()
    p.add_pruning_summaries()

    # optimizer
    optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=LEARNING_RATE, momentum=0.9)
    train_op = optimizer.minimize(loss, global_step)

    # loading data
    train_next = load_tfrecords('train')
    test_next = load_tfrecords('test')

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        # summaries
        logs_dir = './logs/with_pruning'
        if not os.path.exists(logs_dir):
            os.makedirs(logs_dir)
        tf.compat.v1.summary.scalar('monitor/loss', loss)
        tf.compat.v1.summary.scalar('monitor/acc', acc)
        merged_summary_op = tf.compat.v1.summary.merge_all()
        train_summary_writer = tf.compat.v1.summary.FileWriter(os.path.join(logs_dir, 'train'), graph=sess.graph)
        test_summary_writer = tf.compat.v1.summary.FileWriter(os.path.join(logs_dir, 'test'), graph=sess.graph)

        best_acc = 0
        saver = tf.compat.v1.train.Saver()
        for epoch in range(NUM_EPOCHS):
            # training
            num_steps = TRAIN_SIZE // BATCH_SIZE
            train_acc = 0
            train_loss = 0
            for step in range(num_steps):
                x, y = sess.run(train_next)
                _, summary, train_acc_batch, train_loss_batch = sess.run([train_op, merged_summary_op, acc, loss],
                                                                         feed_dict={inputs: x, labels: y})
                sess.run(mask_update_op)
                train_acc += train_acc_batch
                train_loss += train_loss_batch
                sys.stdout.write("\r epoch %d, step %d, training accuracy %g, training loss %g" %
                                 (epoch + 1, step + 1, train_acc_batch, train_loss_batch))
                sys.stdout.flush()
                train_summary_writer.add_summary(summary, global_step=epoch * num_steps + step)
                train_summary_writer.flush()
            print("\n epoch %d, training accuracy %g, training loss %g" %
                  (epoch + 1, train_acc / num_steps, train_loss / num_steps))

            # testing
            num_steps = TEST_SIZE // BATCH_SIZE
            test_acc = 0
            test_loss = 0
            for step in range(num_steps):
                x, y = sess.run(test_next)
                summary, test_acc_batch, test_loss_batch = sess.run([merged_summary_op, acc, loss],
                                                                    feed_dict={inputs: x, labels: y})
                test_acc += test_acc_batch
                test_loss += test_loss_batch
                test_summary_writer.add_summary(summary, global_step=(epoch * num_steps + step) * (TRAIN_SIZE // TEST_SIZE))
                test_summary_writer.flush()
            print(" epoch %d, testing accuracy %g, testing loss %g" %
                  (epoch + 1, test_acc / num_steps, test_loss / num_steps))

            if test_acc / num_steps > best_acc:
                best_acc = test_acc / num_steps
                saver.save(sess, './ckpt_with_pruning/model')

        print(" Best Testing Accuracy %g" % best_acc)
Example #21
File: train.py  Project: WaugZ/my_dgcnn
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(GPU_INDEX)):
            pointclouds_pl = MODEL.placeholder_input(BATCH_SIZE, NUM_POINT)
            labels_pl = MODEL.placeholder_label(BATCH_SIZE)
            if not FLAGS.quantize_delay:
                is_training = tf.placeholder(tf.bool, shape=(), name="is_training")
            else:
                is_training = True

            # Note the global_step=batch parameter to minimize.
            # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains.
            batch = tf.Variable(0)
            # bn_decay = BN_INIT_DECAY
            bn_decay = get_bn_decay(batch)
            tf.summary.scalar('bn_decay', bn_decay)

            # Get model
            pred, end_points = MODEL.get_network(pointclouds_pl, is_training,
                                                 bn_decay=bn_decay,
                                                 dynamic=DYNAMIC,
                                                 STN=STN,
                                                 scale=SCALE,
                                                 concat_fea=CONCAT)

            # Parse pruning hyperparameters
            pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

            # Create a pruning object using the pruning specification
            p = pruning.Pruning(pruning_hparams, global_step=batch)

            # Add conditional mask update op. Executing this op will update all
            # the masks in the graph if the current global step is in the range
            # [begin_pruning_step, end_pruning_step] as specified by the pruning spec
            mask_update_op = p.conditional_mask_update_op()

            # Add summaries to keep track of the sparsity in different layers during training
            p.add_pruning_summaries()

            if FLAGS.quantize_delay and FLAGS.quantize_delay > 0:
                quant_scopes = ["DGCNN/get_edge_feature", "DGCNN/get_edge_feature_1", "DGCNN/get_edge_feature_2",
                                "DGCNN/get_edge_feature_3", "DGCNN/get_edge_feature_4", "DGCNN/agg",
                                "DGCNN/transform_net", "DGCNN/Transform", "DGCNN/dgcnn1", "DGCNN/dgcnn2",
                                "DGCNN/dgcnn3", "DGCNN/dgcnn4",
                                "PointNet"]
                tf.contrib.quantize.create_training_graph(
                    quant_delay=FLAGS.quantize_delay)
                for scope in quant_scopes:
                    my_quantization.experimental_create_training_graph(quant_delay=FLAGS.quantize_delay,
                                                                       scope=scope)

            # Get loss
            loss = MODEL.get_loss(pred, labels_pl, end_points)
            regularization_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            all_losses = []
            all_losses.append(loss)
            all_losses.append(tf.add_n(regularization_losses))
            total_loss = tf.add_n(all_losses)

            # tf.summary.scalar('loss', loss)
            tf.summary.scalar('loss', total_loss)

            correct = tf.equal(tf.argmax(pred, 1), tf.cast(labels_pl, tf.int64))
            accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
            tf.summary.scalar('accuracy', accuracy)

            # if update_ops:
            #     print("BN parameters: ", update_ops)
            #     updates = tf.group(*update_ops)
            #     train_step = control_flow_ops.with_dependencies([updates], batch)

            # Get training operator
            learning_rate = get_learning_rate(batch)
            tf.summary.scalar('learning_rate', learning_rate)
            if OPTIMIZER == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
            elif OPTIMIZER == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([tf.group(*update_ops)]):
                train_op = optimizer.minimize(total_loss, global_step=batch)
                # train_op = slim.learning.create_train_op(total_loss, optimizer)

            # Add ops to save and restore all the variables.
            saver = tf.train.Saver(max_to_keep=51)

        # Create a session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)

        # Add summary writers
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
                                             sess.graph)
        test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))

        # Init variables
        init = tf.global_variables_initializer()
        # To fix the bug introduced in TF 0.12.1 as in
        # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1
        sess.run(init)
        # sess.run(init, {is_training_pl: True})
        if FLAGS.quantize_delay and FLAGS.quantize_delay > 0:
            ops = {'pointclouds_pl': pointclouds_pl,
                   'labels_pl': labels_pl,
                   # 'is_training_pl': is_training,
                   'pred': pred,
                   'loss': loss,
                   'train_op': train_op,
                   'merged': merged,
                   'step': batch,
                   # 'mask_update_op': mask_update_op
                   }
        else:
            ops = {'pointclouds_pl': pointclouds_pl,
                   'labels_pl': labels_pl,
                   'is_training_pl': is_training,
                   'pred': pred,
                   'loss': loss,
                   'train_op': train_op,
                   'merged': merged,
                   'step': batch,
                   # 'mask_update_op': mask_update_op
                   }

        ever_best = 0
        if CHECKPOINT:
            saver.restore(sess, CHECKPOINT)
        for epoch in range(MAX_EPOCH):
            log_string(('**** EPOCH %03d ****' % (epoch))
                       + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + '****')
            sys.stdout.flush()

            ma = train_one_epoch(sess, ops, train_writer)
            if not FLAGS.quantize_delay:
                ma = eval_one_epoch(sess, ops, test_writer)

                # Save the variables to disk.

                if ma > ever_best:
                    save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
                    log_string("Model saved in file: %s" % save_path)
                    ever_best = ma
                log_string("Current model mean accuracy: {}".format(ma))
                log_string("Best model mean accuracy: {}".format(ever_best))
            else:
                if epoch % 5 == 0:
                    if CHECKPOINT:
                        save_path = saver.save(sess, os.path.join(LOG_DIR, "model-r-{}.ckpt".format(str(epoch))))
                    else:
                        save_path = saver.save(sess, os.path.join(LOG_DIR, "model-{}.ckpt".format(str(epoch))))
                    log_string("Model saved in file: %s" % save_path)
Example #22
    def __init__(self,
                 hparams=None,
                 mode='',
                 seed=None,
                 init_weight=0.01,
                 dtype=tf.float32):
        self.encoder_input_data = tf.placeholder(tf.int32, [None, None],
                                                 name='encoder_input_data')
        self.decoder_output_data = tf.placeholder(tf.int32, [None, None],
                                                  name='decoder_output_data')

        self.tgt_vocab_size = hparams.tgt_vocab_size

        len_temp = tf.sign(tf.add(tf.abs(tf.sign(self.encoder_input_data)), 1))
        self.seq_length_encoder_input_data = tf.cast(
            tf.reduce_sum(len_temp, -1), tf.int32)
        self.batch_size = tf.size(self.seq_length_encoder_input_data)

        self.num_layers = hparams.num_layers
        self.decoder_layer_num_more = hparams.decoder_layer_num_more

        self.unit_type = hparams.unit_type
        self.num_units = hparams.num_units
        self.dropout = hparams.dropout
        self.forget_bias = hparams.forget_bias

        self.attention_mode = hparams.attention_mode

        self.time_major = hparams.time_major
        self.residual = hparams.residual
        self.train_type = hparams.train_type
        self.mode = mode
        #l2 loss relate
        self.use_l2_loss = hparams.use_l2_loss
        self.l2_rate = hparams.l2_rate

        self.embedding_size = hparams.embedding_size
        self.dtype = dtype
        tf.get_variable_scope().set_initializer(
            tf.contrib.keras.initializers.glorot_normal(seed=None))
        self.global_step = tf.train.get_or_create_global_step()
        #pruing paramter
        if _hm.NEED_PRUNING:
            pruning_hparams = pruning.get_pruning_hparams().parse(
                _hm.PRUNING_PARAMS)
            pruning_obj = pruning.Pruning(pruning_hparams,
                                          global_step=self.global_step)
        #embeding variable
        with tf.variable_scope('embedding_var') as scope:
            shape = [hparams.src_vocab_size, hparams.embedding_size]
            self.embedding_encoder = tf.Variable(tf.random_uniform(
                shape, -0.01, 0.01),
                                                 dtype=tf.float32,
                                                 name="embedding")
            self.embedding_decoder = self.embedding_encoder

        self.crf_transmit = tf.get_variable(
            "crf_transmit", [self.tgt_vocab_size, self.tgt_vocab_size],
            initializer=tf.random_normal_initializer(0., 512**-0.5))
        res = self._build_graph()

        if (self.mode == _hm.MODE_TRAIN):
            self.loss = res[1]
            self.update, self.learning_rate = _mb.optimizer(
                hparams, self.loss, self.global_step)
            if _hm.NEED_PRUNING:
                self.mask_update_op = pruning_obj.conditional_mask_update_op()
                pruning_obj.add_pruning_summaries()
        #infer here
        else:
            logits = res[0]
            viterbi_sequence,_ =tf.contrib.crf.crf_decode(logits,\
                                                    self.crf_transmit,\
                                                    self.seq_length_encoder_input_data)
            self.neroutput = tf.identity(viterbi_sequence, name="NER_output")

        self.saver = tf.train.Saver(tf.global_variables(),
                                    max_to_keep=hparams.saver_max_time)
        self.merged_summary = tf.summary.merge_all()
Example #23
def train_function(pruning_method, loss, output_dir, use_tpu):
    """Training script for resnet model.

  Args:
   pruning_method: string indicating pruning method used to compress model.
   loss: tensor float32 of the cross entropy + regularization losses.
   output_dir: string tensor indicating the directory to save summaries.
   use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """

    global_step = tf.train.get_global_step()

    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    learning_rate = lr_schedule(current_epoch)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=FLAGS.momentum,
                                           use_nesterov=True)

    if use_tpu:
        # use CrossShardOptimizer when using TPU.
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    # UPDATE_OPS needs to be added as a dependency due to batch norm
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops), tf.name_scope('train'):
        train_op = optimizer.minimize(loss, global_step)

    if not use_tpu:
        if FLAGS.num_workers > 0:
            optimizer = tf.train.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=FLAGS.num_workers,
                total_num_replicas=FLAGS.num_workers)
            optimizer.make_session_run_hook(True)

    metrics = {
        'global_step': tf.train.get_or_create_global_step(),
        'loss': loss,
        'learning_rate': learning_rate,
        'current_epoch': current_epoch
    }

    if pruning_method == 'threshold':
        # construct the necessary hparams string from the FLAGS
        hparams_string = ('begin_pruning_step={0},'
                          'sparsity_function_begin_step={0},'
                          'end_pruning_step={1},'
                          'sparsity_function_end_step={1},'
                          'target_sparsity={2},'
                          'pruning_frequency={3},'
                          'threshold_decay=0,'
                          'use_tpu={4}'.format(
                              FLAGS.sparsity_begin_step,
                              FLAGS.sparsity_end_step,
                              FLAGS.end_sparsity,
                              FLAGS.pruning_frequency,
                              FLAGS.use_tpu,
                          ))

        # Parse pruning hyperparameters
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

        # The first layer has so few parameters that we don't need to prune it, and
        # pruning it at higher sparsity levels has very negative effects.
        if FLAGS.prune_first_layer and FLAGS.first_layer_sparsity >= 0.:
            pruning_hparams.set_hparam(
                'weight_sparsity_map',
                ['resnet_model/initial_conv:%f' % FLAGS.first_layer_sparsity])
        if FLAGS.prune_last_layer and FLAGS.last_layer_sparsity >= 0:
            pruning_hparams.set_hparam(
                'weight_sparsity_map',
                ['resnet_model/final_dense:%f' % FLAGS.last_layer_sparsity])

        # Create a pruning object using the pruning hyperparameters
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # We override the train op to also update the mask.
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()

        masks = pruning.get_masks()
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'scratch':
        masks = pruning.get_masks()
        # make sure the masks have the sparsity we expect and that it doesn't change
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'variational_dropout':
        masks = utils.add_vd_pruning_summaries(
            threshold=FLAGS.log_alpha_threshold)
        metrics.update(masks)
    elif pruning_method == 'l0_regularization':
        summaries = utils.add_l0_summaries()
        metrics.update(summaries)
    elif pruning_method == 'baseline':
        pass
    else:
        raise ValueError('Unsupported pruning method', FLAGS.pruning_method)

    host_call = (functools.partial(utils.host_call_fn,
                                   output_dir), utils.format_tensors(metrics))

    return host_call, train_op
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        if FLAGS.quantize_delay >= 0:
            tf.contrib.quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        #  and returns a train_tensor and summary_op
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        if FLAGS.pruning:
            global mask_update_op
            if FLAGS.pruning_hparams is not None:
                pruning_hparams = pruning.get_pruning_hparams().parse(
                    FLAGS.pruning_hparams)
                pruning_obj = pruning.Pruning(pruning_hparams)
            else:
                pruning_obj = pruning.Pruning()
            pruning_obj.print_hparams()
            mask_update_op = pruning_obj.conditional_mask_update_op()

            slim.learning.train(
                train_tensor,
                logdir=FLAGS.train_dir,
                train_step_fn=train_step_with_pruning_fn,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                init_fn=_get_init_fn(),
                summary_op=summary_op,
                number_of_steps=FLAGS.max_number_of_steps,
                log_every_n_steps=FLAGS.log_every_n_steps,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                sync_optimizer=optimizer if FLAGS.sync_replicas else None)

        else:
            slim.learning.train(
                train_tensor,
                logdir=FLAGS.train_dir,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                init_fn=_get_init_fn(),
                summary_op=summary_op,
                number_of_steps=FLAGS.max_number_of_steps,
                log_every_n_steps=FLAGS.log_every_n_steps,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                sync_optimizer=optimizer if FLAGS.sync_replicas else None)
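
For reference, train_step_with_pruning_fn passed to slim.learning.train above is defined elsewhere in this script; a typical definition (a sketch, not necessarily the original) wraps the standard slim train step and then runs the module-level mask_update_op created in the pruning branch:

def train_step_with_pruning_fn(sess, train_op, global_step, train_step_kwargs):
    # Run the regular slim training step, then apply the conditional mask update.
    total_loss, should_stop = slim.learning.train_step(
        sess, train_op, global_step, train_step_kwargs)
    sess.run(mask_update_op)
    return total_loss, should_stop
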
    def model_bulid(self, height, width, channel, classes):
        x = tf.placeholder(dtype=tf.float32,
                           shape=[None, height, width, channel])
        y = tf.placeholder(dtype=tf.float32, shape=[None, classes])

        # conv 1: for an Nx465x128x1 input, (conv 5x5, 32 filters) then pool /2
        conv1_1 = tf.nn.relu(
            self.conv_layer(x,
                            ksize=[5, 5, channel, 32],
                            stride=[1, 1, 1, 1],
                            padding="SAME",
                            name="conv1_1"))  # Nx465x128x1 ==>   Nx465x128x32
        pool1_1 = self.pool_layer(conv1_1,
                                  ksize=[1, 2, 2, 1],
                                  stride=[1, 2, 2, 1],
                                  name="pool1_1")  # N*232x64x32

        # conv 2: (conv 5x5, 64 filters) then pool /2
        conv2_1 = tf.nn.relu(
            self.conv_layer(pool1_1,
                            ksize=[5, 5, 32, 64],
                            stride=[1, 1, 1, 1],
                            padding="SAME",
                            name="conv2_1"))
        pool2_1 = self.pool_layer(conv2_1,
                                  ksize=[1, 2, 2, 1],
                                  stride=[1, 2, 2, 1],
                                  name="pool2_1")  # Nx116x32x128

        # Flatten
        ft = self.flatten(pool2_1)

        # Dense layers (fc 200 => fc 100 => fc classes), built with pruning-aware masked layers
        fc_layer1 = layers.masked_fully_connected(ft, 200)
        fc_layer2 = layers.masked_fully_connected(fc_layer1, 100)
        # Final logits layer: width matches `classes`, no activation before the
        # softmax cross-entropy below.
        prediction = layers.masked_fully_connected(fc_layer2,
                                                   classes,
                                                   activation_fn=None)

        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction,
                                                       labels=y))
        #  original Dense layer
        # fc1 = self.fc_layer(ft,fc_dims=100,name="fc1")
        # finaloutput = self.finlaout_layer(fc1,fc_dims=10,name="final")

        #  pruning op
        global_step = tf.train.get_or_create_global_step()
        reset_global_step_op = tf.assign(global_step, 0)
        # Get, Print, and Edit Pruning Hyperparameters
        pruning_hparams = pruning.get_pruning_hparams()
        print("Pruning Hyper parameters:", pruning_hparams)
        # Change hyperparameters to meet our needs
        pruning_hparams.begin_pruning_step = 0
        pruning_hparams.end_pruning_step = 250
        pruning_hparams.pruning_frequency = 1
        pruning_hparams.sparsity_function_end_step = 250
        pruning_hparams.target_sparsity = .9
        # Create a pruning object from the pruning specification. Note that an
        # explicitly passed sparsity tensor takes priority over the hparam schedule.
        p = pruning.Pruning(pruning_hparams, global_step=global_step)
        prune_op = p.conditional_mask_update_op()

        # optimize
        LEARNING_RATE_BASE = 0.001
        LEARNING_RATE_DECAY = 0.9
        LEARNING_RATE_STEP = 300
        # Note: this counter drives the learning-rate decay; it is separate from
        # the pruning global_step and is never incremented in this snippet.
        global_steps = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                                   global_steps,
                                                   LEARNING_RATE_STEP,
                                                   LEARNING_RATE_DECAY,
                                                   staircase=True)
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            optimize = tf.train.AdamOptimizer(
                learning_rate=learning_rate).minimize(loss, global_step)

        # prediction
        prediction_label = prediction
        correct_prediction = tf.equal(tf.argmax(prediction_label, 1),
                                      tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction,
                                          dtype=tf.float32))
        correct_times_in_batch = tf.reduce_sum(
            tf.cast(correct_prediction, dtype=tf.int32))

        return dict(x=x,
                    y=y,
                    optimize=optimize,
                    correct_prediction=prediction_label,
                    correct_times_in_batch=correct_times_in_batch,
                    cost=loss,
                    accuracy=accuracy,
                    prune_op=prune_op)
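
One way to drive the graph built by model_bulid above is a feed-dict loop that alternates the optimizer step with the pruning op. The sketch below assumes `net` is an instance of the surrounding class and uses random arrays as stand-ins for real data; it reuses the script's TensorFlow import.

import numpy as np

ops = net.model_bulid(height=28, width=28, channel=1, classes=10)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(300):
        xb = np.random.rand(32, 28, 28, 1).astype(np.float32)
        yb = np.eye(10)[np.random.randint(0, 10, size=32)].astype(np.float32)
        # One optimizer step, then let the pruning spec update the masks.
        _, loss_val = sess.run([ops['optimize'], ops['cost']],
                               feed_dict={ops['x']: xb, ops['y']: yb})
        sess.run(ops['prune_op'])
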
Example #26
    def build_graph(self, hparams, scope=None):
        """Subclass must implement this method.

        Creates a sequence-to-sequence model with dynamic RNN decoder API.
        Args:
          hparams: Hyperparameter configurations.
          scope: VariableScope for the created subgraph; default "dynamic_seq2seq".

        Returns:
          A tuple of the form (logits, loss_tuple, final_context_state, sample_id),
          where:
            logits: float32 Tensor [batch_size x num_decoder_symbols].
            loss: loss = the total loss / batch_size.
            final_context_state: the final state of decoder RNN.
            sample_id: sampling indices.

        Raises:
          ValueError: if encoder_type differs from mono and bi, or
            attention_option is not (luong | scaled_luong |
            bahdanau | normed_bahdanau).
        """
        utils.print_out("# Creating %s graph ..." % self.mode)

        # Projection
        if not self.extract_encoder_layers:
            with tf.variable_scope(scope or "build_network"):
                with tf.variable_scope("decoder/output_projection"):
                    if hparams.projection_type == 'sparse':
                        self.output_layer = core_layers.MaskedFullyConnected(
                            hparams.tgt_vocab_size,
                            use_bias=False,
                            name="output_projection")
                    elif hparams.projection_type == 'dense':
                        self.output_layer = tf.layers.Dense(
                            hparams.tgt_vocab_size,
                            use_bias=False,
                            name="output_projection")
                    else:
                        raise ValueError("Unknown projection type %s!" %
                                         hparams.projection_type)

        with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype):
            # Encoder
            if hparams.language_model:  # no encoder for language modeling
                utils.print_out("  language modeling: no encoder")
                self.encoder_outputs = None
                encoder_state = None
            else:
                self.encoder_outputs, encoder_state = self._build_encoder(
                    hparams)

            # Skip decoder if extracting only encoder layers
            if self.extract_encoder_layers:
                return

            # Decoder
            logits, decoder_cell_outputs, sample_id, final_context_state = (
                self._build_decoder(self.encoder_outputs, encoder_state,
                                    hparams))

            # Loss
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                with tf.device(
                        model_helper.get_device_str(
                            self.num_encoder_layers - 1, self.num_gpus)):
                    loss = self._compute_loss(logits, decoder_cell_outputs)
            else:
                loss = tf.constant(0.0)

            # model pruning
            if hparams.pruning_hparams is not None:
                pruning_hparams = pruning.get_pruning_hparams().parse(
                    hparams.pruning_hparams)
                self.p = pruning.Pruning(pruning_hparams,
                                         global_step=self.global_step)
                self.mask_update_op = self.p.conditional_mask_update_op()
                masks = get_masks()
                thresholds = get_thresholds()
                masks_s = []
                for index, mask in enumerate(masks):
                    masks_s.append(
                        tf.summary.scalar(mask.name + '/sparsity',
                                          tf.nn.zero_fraction(mask)))
                    masks_s.append(
                        tf.summary.scalar(
                            thresholds[index].op.name + '/threshold',
                            thresholds[index]))
                    masks_s.append(
                        tf.summary.histogram(mask.name + '/mask_tensor', mask))
                self.pruning_summary = tf.summary.merge([
                    tf.summary.scalar('sparsity', self.p._sparsity),
                    tf.summary.scalar('last_mask_update_step',
                                      self.p._last_update_step)
                ] + masks_s)
            else:
                self.mask_update_op = tf.no_op()
                self.pruning_summary = tf.no_op()

            return logits, loss, final_context_state, sample_id
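
The mask_update_op and pruning_summary attributes created above are meant to be fetched by the training loop, which is outside this excerpt. A minimal sketch follows; `sess`, `model`, `summary_writer` and `step` are assumptions about that loop, not names from this code.

# Sketch of the (not shown) training loop's pruning fetches.
_, pruning_summary = sess.run([model.mask_update_op, model.pruning_summary])
if pruning_summary is not None:   # tf.no_op() yields None when pruning is disabled
    summary_writer.add_summary(pruning_summary, global_step=step)
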
    weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
    image_set = input_data.read_data_sets('~/tensor/AgeGenderDeepLearning-master/Folds/test-folds/gender_test_fold_is_3_DefaultRun', one_hot=True)

    #image = tf.placeholder(tf.float32, [None, 784])
    #label = tf.placeholder(tf.float32, [None, 10])

    layer1 = layers.masked_fully_connected(images, 512)
    layer2 = layers.masked_fully_connected(layer1, 512)
    logits = tf.nn.dropout(layer2, pkeep, name='drop1')

    batches = int(len(image_set.train.images) / batch_size)
    #loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=nlabels, logits=logits))

    with tf.variable_scope("Prune_Layer", "Prune_Layer", [images]) as scope:

        pruning_hparams = pruning.get_pruning_hparams()
        print("Pruning Hyperparameters:", pruning_hparams)

        # Change hyperparameters to meet our needs
        pruning_hparams.begin_pruning_step = 0
        pruning_hparams.end_pruning_step = 250
        pruning_hparams.pruning_frequency = 1
        pruning_hparams.sparsity_function_end_step = 250
        pruning_hparams.target_sparsity = .9
        global_step = tf.train.get_or_create_global_step()

        #train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)
        reset_global_step_op = tf.assign(global_step, 0)
        
        p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
        prune_op = p.conditional_mask_update_op()
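
The snippet above builds prune_op but stops before the session loop. A sketch of the usual follow-up, assuming the commented-out loss/train_op are restored, runs the mask update after every optimizer step and tracks per-layer sparsity:

weight_sparsity = pruning.get_weight_sparsity()   # one scalar per masked layer
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(300):                       # step count is illustrative
        # sess.run(train_op, feed_dict={...})     # train_op is commented out above
        sess.run(prune_op)
        if step % 50 == 0:
            print('step %d, per-layer sparsity: %s'
                  % (step, sess.run(weight_sparsity)))
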
Example #28
def train_alexnet(
        dataset_name='imagenet',
        prune=False,
        prune_params='',
        learning_rate=conf.learning_rate,
        num_epochs=conf.num_epochs,
        batch_size=conf.batch_size,
        learning_rate_decay_factor=conf.learning_rate_decay_factor,
        num_epochs_per_decay=conf.num_epochs_per_decay,
        dropout_rate=conf.dropout_rate,
        log_step=conf.log_step,
        checkpoint_step=conf.checkpoint_step,
        summary_path=conf.root_path + 'alexnet' + conf.summary_path,
        checkpoint_path=conf.root_path + 'alexnet' + conf.checkpoint_path,
        highest_accuracy_path=conf.root_path + 'alexnet' +
    conf.highest_accuracy_path,
        default_image_size=227,  #224 in the paper
):
    """prune_params: Comma separated list of pruning-related hyperparameters
       ex:'begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000'
    """
    if dataset_name == 'imagenet':
        num_class = conf.imagenet['num_class']
        train_set_size = conf.imagenet['train_set_size']
        validation_set_size = conf.imagenet['validation_set_size']
        label_offset = conf.imagenet['label_offset']
        label_path = conf.imagenet['label_path']
        dataset_path = conf.imagenet['dataset_path']

        x = tf.placeholder(
            tf.float32,
            [batch_size, default_image_size, default_image_size, 3])
        y = tf.placeholder(tf.float32, [batch_size, num_class - label_offset])
        keep_prob = tf.placeholder(tf.float32)  # placeholder for the dropout keep probability
        # prepare to train the model
        model = AlexNet.AlexNet(x,
                                keep_prob,
                                num_class - label_offset, [],
                                prune=prune)
        # Link variable to model output
        score = model.fc8

        # List of trainable variables of the layers we want to train
        var_list = [v for v in tf.trainable_variables()]

        # Op for calculating the loss
        with tf.name_scope("cross_ent"):
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=score,
                                                        labels=y))

        global_step = tf.Variable(0, trainable=False)
        with tf.name_scope("train"):
            # Get gradients of all trainable variables
            decay_steps = int(train_set_size / batch_size *
                              num_epochs_per_decay)
            learning_rate = tf.train.exponential_decay(
                learning_rate,
                global_step,
                decay_steps,
                learning_rate_decay_factor,
                staircase=True)
            # Create optimizer and apply gradient descent to the trainable variables
            train_op = tf.train.GradientDescentOptimizer(
                learning_rate).minimize(loss, global_step)

        # Evaluation op: Accuracy of the model
        with tf.name_scope("accuracy"):
            correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        if prune:
            # Parse pruning hyperparameters
            prune_params = pruning.get_pruning_hparams().parse(prune_params)
            # Create a pruning object using the pruning specification
            p = pruning.Pruning(prune_params, global_step=global_step)
            # Add conditional mask update op. Executing this op will update all
            # the masks in the graph if the current global step is in the range
            # [begin_pruning_step, end_pruning_step] as specified by the pruning spec
            mask_update_op = p.conditional_mask_update_op()
            # Add summaries to keep track of the sparsity in different layers during training
            p.add_pruning_summaries()

        # Add the variables we train to the summary
        for var in var_list:
            tf.summary.histogram(var.name, var)
        # Add the loss to summary
        tf.summary.scalar('cross_entropy', loss)
        # Add the accuracy to the summary
        tf.summary.scalar('accuracy', accuracy)
        # Merge all summaries together
        merged_summary = tf.summary.merge_all()
        # Initialize the FileWriter
        writer = tf.summary.FileWriter(summary_path)

        # prepare the data
        img_train, label_train, labels_text_train = read_tfrecord(
            'train', dataset_path, default_image_size=default_image_size)
        img_validation, label_validation, labels_text_validation = read_tfrecord(
            'validation', dataset_path, default_image_size=default_image_size)
        coord = tf.train.Coordinator()

        # Initialize a saver for storing model checkpoints
        saver = tf.train.Saver()

        with tf.Session() as sess:

            # Initialize all variables
            sess.run(tf.global_variables_initializer())

            # Add the model graph to TensorBoard
            writer.add_graph(sess.graph)

            # Load the pretrained weights into the non-trainable layer
            model.load_initial_weights(sess)

            #start the input pipeline queue
            threads = tf.train.start_queue_runners(sess, coord=coord)

            # load the weights from a checkpoint if one exists
            model_saved = tf.train.get_checkpoint_state(checkpoint_path)
            if model_saved and model_saved.model_checkpoint_path:
                saver.restore(sess, model_saved.model_checkpoint_path)
                print('load model from ' + model_saved.model_checkpoint_path)

            print("{} Start training...".format(datetime.now()))
            print("{} Open Tensorboard at --logdir {}".format(
                datetime.now(), summary_path))

            # Loop over number of epochs
            for epoch in range(num_epochs):
                print("{} Epoch number: {}".format(datetime.now(), epoch + 1))

                highest_accuracy = 0  # highest accuracy so far
                if os.path.exists(highest_accuracy_path):
                    f = open(highest_accuracy_path, 'r')
                    highest_accuracy = float(f.read())
                    f.close()
                    print('highest accuracy from previous training is %f' %
                          highest_accuracy)

                train_batches_per_epoch = int(
                    np.floor(train_set_size / batch_size))
                for step in range(train_batches_per_epoch):
                    # train the model
                    img, l, l_text = sess.run(
                        [img_train, label_train, labels_text_train])
                    _, sc, gl_step, lr = sess.run(
                        [train_op, score, global_step, learning_rate],
                        feed_dict={
                            x: img,
                            y: l,
                            keep_prob: dropout_rate
                        })
                    if prune:
                        # Update the masks by running the mask_update_op
                        sess.run(mask_update_op)

                    # Generate summary with the current batch of data and write to file
                    if step % log_step == 0:
                        s, aq = sess.run([merged_summary, accuracy],
                                         feed_dict={
                                             x: img,
                                             y: l,
                                             keep_prob: 1.
                                         })
                        writer.add_summary(
                            s, epoch * train_batches_per_epoch + step)
                        print(
                            "global_step:" + str(gl_step) + ';learning_rate:' +
                            str(lr) + ';accuracy:', aq)

                    #validate the model and write checkpoint if the accuracy is higher
                    if step % checkpoint_step == 0 and step != 0:
                        val_batches_per_epoch = int(
                            np.floor(validation_set_size / batch_size))
                        print("{} Start validation".format(datetime.now()))
                        test_acc = 0.
                        test_count = 0
                        for _ in range(val_batches_per_epoch):
                            # validate the model
                            img, l, l_text = sess.run([
                                img_validation, label_validation,
                                labels_text_validation
                            ])
                            acc = sess.run(accuracy,
                                           feed_dict={
                                               x: img,
                                               y: l,
                                               keep_prob: 1.
                                           })
                            test_acc += acc
                            test_count += 1
                        test_acc /= test_count
                        print("{} Validation Accuracy = {:.4f}".format(
                            datetime.now(), test_acc))
                        # save the model if it is better than the previous best model
                        if test_acc > highest_accuracy:
                            print("{} Saving checkpoint of model...".format(
                                datetime.now()))
                            highest_accuracy = test_acc
                            # save checkpoint of the model
                            checkpoint_name = os.path.join(
                                checkpoint_path, 'model_epoch' + '.ckpt')
                            # save_path = saver.save(sess, checkpoint_name, global_step=global_step)
                            f = open(highest_accuracy_path, 'w')
                            f.write(str(highest_accuracy))
                            f.close()
                            print("{} Model checkpoint saved at {}".format(
                                datetime.now(), checkpoint_name))
            coord.request_stop()
            coord.join(threads)
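
An invocation that exercises the pruning path of train_alexnet might look like the following; the hyperparameter values mirror the example string in the docstring and are illustrative rather than tuned recommendations.

train_alexnet(
    dataset_name='imagenet',
    prune=True,
    prune_params=('begin_pruning_step=10000,end_pruning_step=100000,'
                  'target_sparsity=0.9,sparsity_function_begin_step=10000,'
                  'sparsity_function_end_step=100000'))
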
Example #29
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_obj.add_pruning_summaries()

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          num_examples_per_step = 128
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
        # Update the masks
        mon_sess.run(mask_update_op)
iter = 0
finalPreds = np.empty((0, NUM_CLASSES))

### Removes old Tensorboard event files ###
allEventFiles = os.listdir('./logs/')
for file in allEventFiles:
    os.remove('./logs/' + file)

################ PRUNING #####################
PARAM_LIST = [
    "name=FFN_Pruning_Test", "pruning_frequency=10", "target_sparsity=0.5"
]
TEST_HPARAMS = ",".join(PARAM_LIST)

# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(TEST_HPARAMS)
#pruning_hparams = model_pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

# Create a pruning object using the pruning specification
p = pruning.Pruning(pruning_hparams, global_step=global_step)

# Add conditional mask update op. Executing this op will update all
# the masks in the graph if the current global step is in the range
# [begin_pruning_step, end_pruning_step] as specified by the pruning spec
mask_update_op = p.conditional_mask_update_op()

# Add summaries to keep track of the sparsity in different layers during training
p.add_pruning_summaries()

### Data statistics ###
tic = time.time()
def build_model():
    """Builds graph for model to train with rewrites for quantization.
  """
    g = tf.Graph()
    with g.as_default(), tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks)):
        inputs, labels = hcl_input(is_training=True)
        #with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)):
        logits, _ = mobilenet_v1_prune.mobilenet_v1(
            inputs,
            is_training=True,
            depth_multiplier=FLAGS.depth_multiplier,
            num_classes=FLAGS.num_classes)

        tf.losses.softmax_cross_entropy(labels, logits)

        # Call rewriter to produce graph with fake quant ops and folded batch norms
        # quant_delay delays start of quantization till quant_delay steps, allowing
        # for better model accuracy.
        if FLAGS.quantize:
            tf.contrib.quantize.create_training_graph(
                quant_delay=get_quant_delay())

        total_loss = tf.losses.get_total_loss(name='total_loss')
        # Configure the learning rate using an exponential decay.
        num_epochs_per_decay = 2.5
        hcl_size = 4650035  #3523535
        decay_steps = int(hcl_size / FLAGS.batch_size * num_epochs_per_decay)
        global_step = tf.train.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(
            get_learning_rate(),
            global_step,  # tf.train.get_or_create_global_step(),
            decay_steps,
            _LEARNING_RATE_DECAY_FACTOR,
            staircase=True)
        opt = tf.train.GradientDescentOptimizer(learning_rate)

        # Get, Print, and Edit Pruning Hyperparameters
        pruning_hparams = pruning.get_pruning_hparams()
        #print("Pruning Hyperparameters:", pruning_hparams)

        # Change hyperparameters to meet our needs
        pruning_hparams.begin_pruning_step = 200000
        #pruning_hparams.end_pruning_step = 250
        #pruning_hparams.pruning_frequency = 1
        #pruning_hparams.sparsity_function_end_step = 250
        pruning_hparams.target_sparsity = .5
        print("Pruning Hyperparameters:", pruning_hparams)

        # Create a pruning object from the pruning specification. Note that an
        # explicitly passed sparsity tensor takes priority over the hparam schedule.
        p = pruning.Pruning(pruning_hparams,
                            global_step=global_step,
                            sparsity=.5)
        prune_op = p.conditional_mask_update_op()

        train_tensor = slim.learning.create_train_op(total_loss, optimizer=opt)

    slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses')
    slim.summaries.add_scalar_summary(learning_rate, 'learning_rate',
                                      'training')
    return g, [train_tensor, prune_op]
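
A minimal consumer of build_model would run both returned ops each step. The session handling below is a sketch under the assumption that the hcl_input pipeline is queue-based; neither detail is shown in this excerpt.

g, (train_tensor, prune_op) = build_model()
with g.as_default(), tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    for _ in range(1000):                        # step count is illustrative
        loss_value, _ = sess.run([train_tensor, prune_op])
    coord.request_stop()
    coord.join(threads)
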
Example #33
def train_fn(training_method, global_step, total_loss, train_dir, accuracy,
             top_5_accuracy):
  """Training script for resnet model.

  Args:
   training_method: specifies the method used to sparsify networks.
   global_step: the current step of training/eval.
   total_loss: tensor float32 of the cross entropy + regularization losses.
   train_dir: string specifying the directory where summaries are saved.
   accuracy: tensor float32 batch classification accuracy.
   top_5_accuracy: tensor float32 batch classification accuracy (top-5 classes).

  Returns:
    hooks: summary tensors to be computed at each training step.
    eval_metrics: set to None during training.
    train_op: the optimization term.
  """
  # The learning rate roughly drops at every 30k steps.
  boundaries = [30000, 60000, 90000]
  if FLAGS.training_steps_multiplier != 1.0:
    multiplier = FLAGS.training_steps_multiplier
    boundaries = [int(x * multiplier) for x in boundaries]
    tf.logging.info(
        'Learning Rate boundaries are updated with multiplier:%.2f', multiplier)

  learning_rate = tf.train.piecewise_constant(
      global_step,
      boundaries,
      values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)],
      name='lr_schedule')

  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=FLAGS.momentum, use_nesterov=True)

  if training_method == 'set':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseSETOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'static':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseStaticOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'momentum':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseMomentumOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)
  elif training_method == 'rigl':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseRigLOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)
  elif training_method == 'snip':
    optimizer = sparse_optimizers.SparseSnipOptimizer(
        optimizer, mask_init_method=FLAGS.mask_init_method,
        default_sparsity=FLAGS.end_sparsity, use_tpu=False)
  elif training_method in ('scratch', 'baseline', 'prune'):
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)
  # Create the training op
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss, global_step)

  if training_method == 'prune':
    # construct the necessary hparams string from the FLAGS
    hparams_string = ('begin_pruning_step={0},'
                      'sparsity_function_begin_step={0},'
                      'end_pruning_step={1},'
                      'sparsity_function_end_step={1},'
                      'target_sparsity={2},'
                      'pruning_frequency={3},'
                      'threshold_decay=0,'
                      'use_tpu={4}'.format(
                          FLAGS.sparsity_begin_step,
                          FLAGS.sparsity_end_step,
                          FLAGS.end_sparsity,
                          FLAGS.pruning_frequency,
                          False,
                      ))
    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    tf.logging.info('starting mask update op')

    # We override the train op to also update the mask.
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()

  masks = pruning.get_masks()
  mask_metrics = utils.mask_summaries(masks)
  for name, tensor in mask_metrics.items():
    tf.summary.scalar(name, tensor)

  tf.summary.scalar('learning_rate', learning_rate)
  tf.summary.scalar('accuracy', accuracy)
  tf.summary.scalar('total_loss', total_loss)
  tf.summary.scalar('top_5_accuracy', top_5_accuracy)
  # Log drop_fraction when using dynamic sparse training.
  if training_method in ('set', 'momentum', 'rigl', 'static'):
    tf.summary.scalar('drop_fraction', optimizer.drop_fraction)

  summary_op = tf.summary.merge_all()
  summary_hook = tf.train.SummarySaverHook(
      save_secs=300, output_dir=train_dir, summary_op=summary_op)
  hooks = [summary_hook]
  eval_metrics = None

  return hooks, eval_metrics, train_op
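
For reference, with sparsity_begin_step=10000, sparsity_end_step=100000, end_sparsity=0.9 and pruning_frequency=1000 (illustrative flag values, not taken from this excerpt), the hparams string assembled in the 'prune' branch above parses as follows:

hparams_string = ('begin_pruning_step=10000,sparsity_function_begin_step=10000,'
                  'end_pruning_step=100000,sparsity_function_end_step=100000,'
                  'target_sparsity=0.9,pruning_frequency=1000,'
                  'threshold_decay=0,use_tpu=False')
pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
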
Example #34
def main(unused_args):
    tf.set_random_seed(FLAGS.seed)
    tf.get_variable_scope().set_use_resource(True)
    np.random.seed(FLAGS.seed)

    # Load the MNIST data and set up an iterator.
    mnist_data = input_data.read_data_sets(FLAGS.mnist,
                                           one_hot=False,
                                           validation_size=0)
    train_images = mnist_data.train.images
    test_images = mnist_data.test.images
    if FLAGS.input_mask_path:
        reader = tf.train.load_checkpoint(FLAGS.input_mask_path)
        input_mask = reader.get_tensor('layer1/mask')
        indices = np.sum(input_mask, axis=1) != 0
        train_images = train_images[:, indices]
        test_images = test_images[:, indices]
    dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, mnist_data.train.labels.astype(np.int32)))
    num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size
    dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])
    batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
    iterator = batched_dataset.make_one_shot_iterator()

    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_images, mnist_data.test.labels.astype(np.int32)))
    num_test_images = mnist_data.test.images.shape[0]
    test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)
    test_iterator = test_dataset.make_one_shot_iterator()

    # Set up loss function.
    use_model_pruning = FLAGS.training_method != 'baseline'

    if FLAGS.network_type == 'fc':
        cross_entropy_train, _ = mnist_network_fc(
            iterator.get_next(), model_pruning=use_model_pruning)
        cross_entropy_test, accuracy_test = mnist_network_fc(
            test_iterator.get_next(),
            reuse=True,
            model_pruning=use_model_pruning)
    else:
        raise RuntimeError(FLAGS.network_type + ' is an unknown network type.')

    # Remove the duplicated entries: the current implementation adds the
    # variables to these collections twice; this hack should be cleaned up.
    # TODO: test the following with the convnet or any other network.
    if use_model_pruning:
        for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):
            # del tf.get_collection_ref(k)[2]
            # del tf.get_collection_ref(k)[2]
            collection = tf.get_collection_ref(k)
            del collection[len(collection) // 2:]
            print(tf.get_collection_ref(k))

    # Set up optimizer and update ops.
    global_step = tf.train.get_or_create_global_step()
    batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size

    if FLAGS.optimizer != 'adam':
        if not use_model_pruning:
            boundaries = [
                int(round(s * batch_per_epoch)) for s in [60, 70, 80]
            ]
        else:
            boundaries = [
                int(round(s * batch_per_epoch))
                for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]
            ]
        learning_rate = tf.train.piecewise_constant(
            global_step,
            boundaries,
            values=[
                FLAGS.learning_rate / (3.**i)
                for i in range(len(boundaries) + 1)
            ])
    else:
        learning_rate = FLAGS.learning_rate

    if FLAGS.optimizer == 'adam':
        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    elif FLAGS.optimizer == 'momentum':
        opt = tf.train.MomentumOptimizer(learning_rate,
                                         FLAGS.momentum,
                                         use_nesterov=FLAGS.use_nesterov)
    elif FLAGS.optimizer == 'sgd':
        opt = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise RuntimeError(FLAGS.optimizer + ' is an unknown optimizer type')
    custom_sparsities = {
        'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,
        'layer3': FLAGS.end_sparsity * 0
    }

    if FLAGS.training_method == 'set':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseSETOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'static':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseStaticOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'momentum':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseMomentumOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            use_tpu=False)
    elif FLAGS.training_method == 'rigl':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseRigLOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale,
            use_tpu=False)
    elif FLAGS.training_method == 'snip':
        opt = sparse_optimizers.SparseSnipOptimizer(
            opt,
            mask_init_method=FLAGS.mask_init_method,
            default_sparsity=FLAGS.end_sparsity,
            custom_sparsity_map=custom_sparsities,
            use_tpu=False)
    elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' %
                         FLAGS.training_method)

    train_op = opt.minimize(cross_entropy_train, global_step=global_step)

    if FLAGS.training_method == 'prune':
        hparams_string = (
            'begin_pruning_step={0},sparsity_function_begin_step={0},'
            'end_pruning_step={1},sparsity_function_end_step={1},'
            'target_sparsity={2},pruning_frequency={3},'
            'threshold_decay={4}'.format(FLAGS.prune_begin_step,
                                         FLAGS.prune_end_step,
                                         FLAGS.end_sparsity,
                                         FLAGS.pruning_frequency,
                                         FLAGS.threshold_decay))
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
        pruning_hparams.set_hparam(
            'weight_sparsity_map',
            ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()])
        print(pruning_hparams)
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()
    weight_sparsity_levels = pruning.get_weight_sparsity()
    global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())
    tf.summary.scalar('test_accuracy', accuracy_test)
    tf.summary.scalar('global_sparsity', global_sparsity)
    for k, v in zip(pruning.get_masks(), weight_sparsity_levels):
        tf.summary.scalar('sparsity/%s' % k.name, v)
    if FLAGS.training_method in ('prune', 'snip', 'baseline'):
        mask_init_op = tf.no_op()
        tf.logging.info('No mask is set, starting dense.')
    else:
        all_masks = pruning.get_masks()
        mask_init_op = sparse_utils.get_mask_init_fn(all_masks,
                                                     FLAGS.mask_init_method,
                                                     FLAGS.end_sparsity,
                                                     custom_sparsities)

    if FLAGS.save_model:
        saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    hyper_params_string = '_'.join([
        FLAGS.network_type,
        str(FLAGS.batch_size),
        str(FLAGS.learning_rate),
        str(FLAGS.momentum), FLAGS.optimizer,
        str(FLAGS.l2_scale), FLAGS.training_method,
        str(FLAGS.prune_begin_step),
        str(FLAGS.prune_end_step),
        str(FLAGS.end_sparsity),
        str(FLAGS.pruning_frequency),
        str(FLAGS.seed)
    ])
    tf.io.gfile.makedirs(FLAGS.save_path)
    filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')
    merged_summary_op = tf.summary.merge_all()

    # Run session.
    if not use_model_pruning:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')
            sess.run([init_op])
            tic = time.time()
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                for i in range(FLAGS.num_epochs * num_batches):
                    sess.run([train_op])

                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %.4f, %.4f, %.4f' % (
                            i // num_batches, epoch_time, loss, accuracy)
                        print(log_str)
                        print(log_str, file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
    else:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            log_str = ','.join([
                'Epoch', 'Iteration', 'Test loss', 'Test accuracy',
                'G_Sparsity', 'Sparsity Layer 0', 'Sparsity Layer 1'
            ])
            sess.run(init_op)
            sess.run(mask_init_op)
            tic = time.time()
            mask_records = {}
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                print(log_str)
                print(log_str, file=outputfile)
                for i in range(FLAGS.num_epochs * num_batches):
                    if (FLAGS.mask_record_frequency > 0
                            and i % FLAGS.mask_record_frequency == 0):
                        mask_vals = sess.run(pruning.get_masks())
                        # Cast into bool to save space.
                        mask_records[i] = [
                            a.astype(np.bool) for a in mask_vals
                        ]
                    sess.run([train_op])
                    weight_sparsity, global_sparsity_val = sess.run(
                        [weight_sparsity_levels, global_sparsity])
                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (
                            i // num_batches, i, loss, accuracy,
                            global_sparsity_val, weight_sparsity[0],
                            weight_sparsity[1])
                        print(log_str)
                        print(log_str, file=outputfile)
                        mask_vals = sess.run(pruning.get_masks())
                        if FLAGS.network_type == 'fc':
                            sparsities, sizes = get_compressed_fc(mask_vals)
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes))
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes),
                                  file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
            if mask_records:
                np.save(os.path.join(FLAGS.save_path, 'mask_records'),
                        mask_records)
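
As a worked example of the weight_sparsity_map format used above: with end_sparsity=0.9 and sparsity_scale=0.5 (illustrative values, not from this excerpt), the per-layer override list becomes the following.

# Illustrative values only (end_sparsity=0.9, sparsity_scale=0.5 are assumptions).
custom_sparsities = {'layer2': 0.9 * 0.5, 'layer3': 0.9 * 0}
overrides = ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()]
# overrides == ['layer2:0.45', 'layer3:0.0']
pruning_hparams.set_hparam('weight_sparsity_map', overrides)
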
Example #35
def build_graph(reader,
                model,
                train_data_pattern,
                label_loss_fn=losses.CrossEntropyLoss(),
                batch_size=1000,
                base_learning_rate=0.01,
                learning_rate_decay_examples=1000000,
                learning_rate_decay=0.95,
                optimizer_class=tf.train.AdamOptimizer,
                clip_gradient_norm=1.0,
                regularization_penalty=1,
                num_readers=1,
                num_epochs=None):
    """Creates the Tensorflow graph.

  This will only be called once in the life of
  a training model, because after the graph is created the model will be
  restored from a meta graph file rather than being recreated.

  Args:
    reader: The data file reader. It should inherit from BaseReader.
    model: The core model (e.g. logistic or neural net). It should inherit
           from BaseModel.
    train_data_pattern: glob path to the training data files.
    label_loss_fn: What kind of loss to apply to the model. It should inherit
                from BaseLoss.
    batch_size: How many examples to process at a time.
    base_learning_rate: What learning rate to initialize the optimizer with.
    learning_rate_decay_examples: Number of examples after which the learning
                rate is decayed by learning_rate_decay.
    learning_rate_decay: Learning rate decay factor.
    optimizer_class: Which optimization algorithm to use.
    clip_gradient_norm: Magnitude of the gradient to clip to.
    regularization_penalty: How much weight to give the regularization loss
                            compared to the label loss.
    num_readers: How many threads to use for I/O operations.
    num_epochs: How many passes to make over the data. 'None' means an
                unlimited number of passes.
  """

    global_step = tf.Variable(0, trainable=False, name="global_step")

    local_device_protos = device_lib.list_local_devices()
    gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
    gpus = gpus[:FLAGS.num_gpu]
    #gpus = gpus[-1:]
    num_gpus = len(gpus)

    if num_gpus > 0:
        logging.info("Using the following GPUs to train: " + str(gpus))
        num_towers = num_gpus
        device_string = '/gpu:%d'
    else:
        logging.info("No GPUs found. Training on CPU.")
        num_towers = 1
        device_string = '/cpu:%d'

    learning_rate = tf.train.exponential_decay(base_learning_rate,
                                               global_step * batch_size *
                                               num_towers,
                                               learning_rate_decay_examples,
                                               learning_rate_decay,
                                               staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = optimizer_class(learning_rate)
    unused_video_id, model_input_raw, labels_batch, num_frames = (
        get_input_data_tensors(reader,
                               train_data_pattern.split(','),
                               batch_size=batch_size * num_towers,
                               num_readers=num_readers,
                               num_epochs=num_epochs))
    tf.summary.histogram("model/input_raw", model_input_raw)
    feature_dim = len(model_input_raw.get_shape()) - 1

    model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)

    tower_inputs = tf.split(model_input, num_towers)
    tower_labels = tf.split(labels_batch, num_towers)
    tower_num_frames = tf.split(num_frames, num_towers)

    tower_gradients = []
    tower_predictions = []
    tower_label_losses = []
    tower_reg_losses = []
    for i in range(num_towers):
        with tf.device(device_string % i):
            with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)):
                with (slim.arg_scope(
                    [slim.model_variable, slim.variable],
                        device="/cpu:0" if num_gpus != 1 else "/gpu:0")):
                    logging.info('building graph with ' + device_string % i)
                    result = model.create_model(tower_inputs[i],
                                                num_frames=tower_num_frames[i],
                                                vocab_size=reader.num_classes,
                                                labels=tower_labels[i])

                    for variable in slim.get_model_variables():
                        tf.summary.histogram(variable.op.name, variable)

                    predictions = result["predictions"]
                    tower_predictions.append(predictions)

                    if "loss" in result.keys():
                        label_loss = result["loss"]
                    else:
                        label_loss = label_loss_fn.calculate_loss(
                            predictions, tower_labels[i])

                    if "regularization_loss" in result.keys():
                        reg_loss = result["regularization_loss"]
                    else:
                        reg_loss = tf.constant(0.0)

                    reg_losses = tf.losses.get_regularization_losses()
                    if reg_losses:
                        reg_loss += tf.add_n(reg_losses)

                    tower_reg_losses.append(reg_loss)

                    # Adds update_ops (e.g., moving average updates in batch normalization) as
                    # a dependency to the train_op.
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    if "update_ops" in result.keys():
                        update_ops += result["update_ops"]
                    if update_ops:
                        with tf.control_dependencies(update_ops):
                            barrier = tf.no_op(name="gradient_barrier")
                            with tf.control_dependencies([barrier]):
                                label_loss = tf.identity(label_loss)

                    tower_label_losses.append(label_loss)

                    # Incorporate the L2 weight penalties etc.
                    final_loss = regularization_penalty * reg_loss + label_loss
                    gradients = optimizer.compute_gradients(
                        final_loss, colocate_gradients_with_ops=False)
                    tower_gradients.append(gradients)
    label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
    tf.summary.scalar("label_loss", label_loss)
    if regularization_penalty != 0:
        reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
        tf.summary.scalar("reg_loss", reg_loss)
    merged_gradients = utils.combine_gradients(tower_gradients)

    if clip_gradient_norm > 0:
        with tf.name_scope('clip_grads'):
            merged_gradients = utils.clip_gradient_norms(
                merged_gradients, clip_gradient_norm)

    train_op = optimizer.apply_gradients(merged_gradients,
                                         global_step=global_step)

    pruning_hparams = pruning.get_pruning_hparams().parse(
        FLAGS.pruning_hparams)
    p = pruning.Pruning(pruning_hparams, global_step=global_step)
    mask_update_op = p.conditional_mask_update_op()
    tf.add_to_collection("global_step", global_step)
    tf.add_to_collection("loss", label_loss)
    tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
    tf.add_to_collection("input_batch_raw", model_input_raw)
    tf.add_to_collection("input_batch", model_input)
    tf.add_to_collection("num_frames", num_frames)
    tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
    tf.add_to_collection("train_op", train_op)
    tf.add_to_collection('mask_update_op', mask_update_op)
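
Because build_graph only registers these ops in graph collections, the training script is expected to fetch and run them later. A minimal sketch, with input-pipeline startup (queue runners) and checkpointing omitted for brevity:

train_op = tf.get_collection('train_op')[0]
mask_update_op = tf.get_collection('mask_update_op')[0]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    for _ in range(1000):              # number of steps is illustrative
        sess.run(train_op)
        sess.run(mask_update_op)
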