Example #1
0
    def testMaskedLSTMCell(self):
        """Check the pruning collections populated by a MaskedLSTMCell.

        Calling the cell once on a (batch_size, dim) input with an LSTM
        state must register exactly one mask, masked weight, threshold and
        weight. Every mask and weight must have shape (2 * dim, 4 * dim).
        """
        expected_num_masks = 1
        # Presumably input and hidden state are concatenated (dim + dim rows)
        # and the four gates are packed column-wise (4 * dim) — the test only
        # pins the resulting shape.
        expected_num_rows = 2 * self.dim
        expected_num_cols = 4 * self.dim
        # test_session() is deprecated in tf.test.TestCase in favour of
        # cached_session().
        with self.cached_session():
            inputs = variables.Variable(
                random_ops.random_normal([self.batch_size, self.dim]))
            c = variables.Variable(
                random_ops.random_normal([self.batch_size, self.dim]))
            h = variables.Variable(
                random_ops.random_normal([self.batch_size, self.dim]))
            state = tf_rnn_cells.LSTMStateTuple(c, h)
            lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
            # Calling the cell builds it; the pruning collections are
            # inspected only after this call.
            lstm_cell(inputs, state)
            self.assertEqual(len(pruning.get_masks()), expected_num_masks)
            self.assertEqual(len(pruning.get_masked_weights()),
                             expected_num_masks)
            self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
            self.assertEqual(len(pruning.get_weights()), expected_num_masks)

            for mask in pruning.get_masks():
                self.assertEqual(mask.shape,
                                 (expected_num_rows, expected_num_cols))
            for weight in pruning.get_weights():
                self.assertEqual(weight.shape,
                                 (expected_num_rows, expected_num_cols))
Example #2
0
 def __init__(self,
              input_size,
              output_size,
              model_path: str,
              momentum=0.9,
              reg_str=0.0005,
              scope='ConvNet',
              pruning_start=int(10e4),
              pruning_end=int(10e5),
              pruning_freq=int(10),
              sparsity_start=0,
              sparsity_end=int(10e5),
              target_sparsity=0.0,
              dropout=0.5,
              initial_sparsity=0,
              wd=0.0,
              sparsity_function_exponent=3):
     """Build the ConvNet graph together with its pruning machinery.

     Args:
         input_size: forwarded to the base-class constructor.
         output_size: forwarded to the base-class constructor.
         model_path: forwarded to the base-class constructor.
         momentum: stored as ``self.momentum`` (presumably read by
             ``_optimizer``; confirm).
         reg_str: stored as ``self.reg_str``.
         scope: stored as ``self.scope``; also the logger name and the
             ``name`` pruning hparam.
         pruning_start: ``begin_pruning_step`` pruning hparam.
         pruning_end: ``end_pruning_step`` pruning hparam.
         pruning_freq: ``pruning_frequency`` pruning hparam.
         sparsity_start: ``sparsity_function_begin_step`` pruning hparam.
         sparsity_end: ``sparsity_function_end_step`` pruning hparam.
         target_sparsity: ``target_sparsity`` pruning hparam.
         dropout: stored as ``self.dropout``.
         initial_sparsity: ``initial_sparsity`` pruning hparam.
         wd: stored as ``self.wd``.
         sparsity_function_exponent: ``sparsity_function_exponent`` pruning
             hparam; defaults to 3, the value previously hard-coded.
     """
     super(ConvNet, self).__init__(input_size=input_size,
                                   output_size=output_size,
                                   model_path=model_path)
     self.scope = scope
     self.momentum = momentum
     self.reg_str = reg_str
     self.dropout = dropout
     self.logger = get_logger(scope)
     self.wd = wd
     self.logger.info("creating graph...")
     with self.graph.as_default():
         self.global_step = tf.Variable(0, trainable=False)
         self._build_placeholders()
         self.logits = self._build_model()
         self.weights_matrices = pruning.get_masked_weights()
         self.sparsity = pruning.get_weight_sparsity()
         self.loss = self._loss()
         self.train_op = self._optimizer()
         self._create_metrics()
         # NOTE(review): this saver is created before pruning.Pruning below,
         # so any variables Pruning itself creates are not in its var_list —
         # confirm that is intended.
         self.saver = tf.train.Saver(var_list=tf.global_variables())
         # Explicit key/value pairs for the model_pruning hparams. The
         # parser accepts a comma-separated "key=value" list.
         pruning_spec = ','.join(
             '{}={}'.format(key, value) for key, value in (
                 ('name', scope),
                 ('begin_pruning_step', pruning_start),
                 ('end_pruning_step', pruning_end),
                 ('target_sparsity', target_sparsity),
                 ('sparsity_function_begin_step', sparsity_start),
                 ('sparsity_function_end_step', sparsity_end),
                 ('pruning_frequency', pruning_freq),
                 ('initial_sparsity', initial_sparsity),
                 ('sparsity_function_exponent', sparsity_function_exponent),
             ))
         self.hparams = pruning.get_pruning_hparams().parse(pruning_spec)
         # The pruning schedule is driven by self.global_step. Per the
         # model_pruning hparams, sparsity ramps from initial_sparsity to
         # target_sparsity between sparsity_function_begin_step and
         # sparsity_function_end_step.
         self.pruning_obj = pruning.Pruning(self.hparams,
                                            global_step=self.global_step)
         # Masks change only when this op is run. Callers are presumably
         # expected to run it alongside train_op — confirm.
         self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
         # Initialise every global variable, including any created above
         # by Pruning.
         self.init_variables(tf.global_variables())
Example #3
0
 def __init__(self,
              actor_input_dim,
              actor_output_dim,
              model_path,
              redundancy=None,
              last_measure=10e4,
              tau=0.01,
              sparsity_function_exponent=3):
     """Build the student actor graph and attach pruning to it.

     All pruning hyper-parameters except the exponent are read from a
     ``cfg`` object defined outside this method.

     Args:
         actor_input_dim: stored as the shape tuple
             ``(None, actor_input_dim)``.
         actor_output_dim: stored as the shape tuple
             ``(None, actor_output_dim)``.
         model_path: forwarded to the base-class constructor.
         redundancy: stored as ``self.redundancy``.
         last_measure: stored as ``self.last_measure``.
         tau: stored as ``self.tau``.
         sparsity_function_exponent: ``sparsity_function_exponent`` pruning
             hparam; defaults to 3, the value previously hard-coded.
     """
     super(StudentActor, self).__init__(model_path=model_path)
     self.actor_input_dim = (None, actor_input_dim)
     self.actor_output_dim = (None, actor_output_dim)
     self.tau = tau
     self.redundancy = redundancy
     self.last_measure = last_measure
     with self.graph.as_default():
         self.actor_global_step = tf.Variable(0, trainable=False)
         self._build_placeholders()
         self.actor_logits = self._build_actor()
         self.loss = self._build_loss()
         self.actor_parameters = tf.get_collection(
             tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')
         self.actor_pruned_weight_matrices = pruning.get_masked_weights()
         self.actor_train_op = self._build_actor_train_op()
         # NOTE(review): only trainable variables under the 'actor' scope
         # are saved. If the pruning masks/thresholds are non-trainable, they
         # are not checkpointed — confirm that is intended.
         self.actor_saver = tf.train.Saver(var_list=self.actor_parameters,
                                           max_to_keep=100)
         self.sparsity = pruning.get_weight_sparsity()
         # Explicit key/value pairs for the model_pruning hparams. The
         # parser accepts a comma-separated "key=value" list.
         pruning_spec = ','.join(
             '{}={}'.format(key, value) for key, value in (
                 ('name', 'Actor'),
                 ('begin_pruning_step', cfg.pruning_start),
                 ('end_pruning_step', cfg.pruning_end),
                 ('target_sparsity', cfg.target_sparsity),
                 ('sparsity_function_begin_step', cfg.sparsity_start),
                 ('sparsity_function_end_step', cfg.sparsity_end),
                 ('pruning_frequency', cfg.pruning_freq),
                 ('initial_sparsity', cfg.initial_sparsity),
                 ('sparsity_function_exponent', sparsity_function_exponent),
             ))
         self.hparams = pruning.get_pruning_hparams().parse(pruning_spec)
         # The pruning schedule is driven by self.actor_global_step. Per the
         # model_pruning hparams, sparsity ramps from initial_sparsity to
         # target_sparsity between sparsity_function_begin_step and
         # sparsity_function_end_step.
         self.pruning_obj = pruning.Pruning(
             self.hparams, global_step=self.actor_global_step)
         # Masks change only when this op is run. Callers are presumably
         # expected to run it alongside actor_train_op — confirm.
         self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
         # Initialise once, after the whole graph (including anything Pruning
         # creates) has been built. Nothing above needs initialised values.
         self.init_variables(tf.global_variables())
Example #4
0
  def testMaskedLSTMCell(self):
    """A MaskedLSTMCell registers one mask set with (2*dim, 4*dim) kernels."""
    num_masks = 1
    kernel_shape = (2 * self.dim, 4 * self.dim)
    with self.cached_session():

      def _random_variable():
        # Each call creates one (batch_size, dim) normally-initialised Variable.
        return variables.Variable(
            random_ops.random_normal([self.batch_size, self.dim]))

      inputs = _random_variable()
      c = _random_variable()
      h = _random_variable()
      state = tf_rnn_cells.LSTMStateTuple(c, h)
      cell = rnn_cells.MaskedLSTMCell(self.dim)
      cell(inputs, state)

      # Each pruning collection should hold exactly one entry for this cell.
      for getter in (pruning.get_masks, pruning.get_masked_weights,
                     pruning.get_thresholds, pruning.get_weights):
        self.assertEqual(len(getter()), num_masks)

      # Masks and raw weights must both match the LSTM kernel shape.
      for getter in (pruning.get_masks, pruning.get_weights):
        for tensor in getter():
          self.assertEqual(tensor.shape, kernel_shape)
Example #5
0
 def get_masked_weights(self):
     """Return whatever ``pruning.get_masked_weights()`` returns.

     Thin delegate to the pruning module; ``self`` is not used.
     """
     masked_weights = pruning.get_masked_weights()
     return masked_weights