def testMaskedLSTMCell(self):
  expected_num_masks = 1
  expected_num_rows = 2 * self.dim
  expected_num_cols = 4 * self.dim
  with self.test_session():
    inputs = variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))
    c = variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))
    h = variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))
    state = tf_rnn_cells.LSTMStateTuple(c, h)
    lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
    lstm_cell(inputs, state)
    self.assertEqual(len(pruning.get_masks()), expected_num_masks)
    self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
    self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
    self.assertEqual(len(pruning.get_weights()), expected_num_masks)
    for mask in pruning.get_masks():
      self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
    for weight in pruning.get_weights():
      self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))
def __init__(self, input_size, output_size, model_path: str, momentum=0.9, reg_str=0.0005, scope='ConvNet',
             pruning_start=int(10e4), pruning_end=int(10e5), pruning_freq=int(10), sparsity_start=0,
             sparsity_end=int(10e5), target_sparsity=0.0, dropout=0.5, initial_sparsity=0, wd=0.0):
    super(ConvNet, self).__init__(input_size=input_size, output_size=output_size, model_path=model_path)
    self.scope = scope
    self.momentum = momentum
    self.reg_str = reg_str
    self.dropout = dropout
    self.logger = get_logger(scope)
    self.wd = wd
    self.logger.info("creating graph...")
    with self.graph.as_default():
        self.global_step = tf.Variable(0, trainable=False)
        self._build_placeholders()
        self.logits = self._build_model()
        self.weights_matrices = pruning.get_masked_weights()
        self.sparsity = pruning.get_weight_sparsity()
        self.loss = self._loss()
        self.train_op = self._optimizer()
        self._create_metrics()
        self.saver = tf.train.Saver(var_list=tf.global_variables())
        self.hparams = pruning.get_pruning_hparams() \
            .parse('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                   ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                   'pruning_frequency={},initial_sparsity={},'
                   ' sparsity_function_exponent={}'.format(scope, pruning_start, pruning_end,
                                                           target_sparsity, sparsity_start,
                                                           sparsity_end, pruning_freq,
                                                           initial_sparsity, 3))
        # The global step drives the pruning schedule: the further it advances past the
        # sparsity-function begin step, the closer the enforced sparsity gets to target_sparsity.
        self.pruning_obj = pruning.Pruning(self.hparams, global_step=self.global_step)
        # The Pruning object defines the pruning mechanism; running mask_update_op during
        # training is what actually prunes the model. The masks are updated at each training
        # epoch with the objective of reaching the end-sparsity hyperparameter.
        self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
        # initialize all variables in the graph
        self.init_variables(tf.global_variables())
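# A minimal training-loop sketch (an illustration, not part of the source): it assumes a
# ConvNet instance `model` built as above, an open tf.Session `sess` on model.graph, a
# hypothetical `next_batch()` helper, and hypothetical `model.x` / `model.y` placeholder
# names. The usual tf.contrib.model_pruning pattern is to run the conditional mask-update op
# alongside every training step; the op only rewrites the masks when the schedule
# (begin/end_pruning_step, pruning_frequency) calls for it.
def train_with_pruning(model, sess, num_steps, next_batch):
    for _ in range(num_steps):
        x, y = next_batch()
        # one gradient step; assumed to also advance model.global_step,
        # which is what moves the sparsity schedule forward
        sess.run(model.train_op, feed_dict={model.x: x, model.y: y})
        # recompute the weight masks if the current global step calls for it
        sess.run(model.mask_update_op)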
def __init__(self, actor_input_dim, actor_output_dim, model_path, redundancy=None, last_measure=10e4, tau=0.01):
    super(StudentActor, self).__init__(model_path=model_path)
    self.actor_input_dim = (None, actor_input_dim)
    self.actor_output_dim = (None, actor_output_dim)
    self.tau = tau
    self.redundancy = redundancy
    self.last_measure = last_measure
    with self.graph.as_default():
        self.actor_global_step = tf.Variable(0, trainable=False)
        self._build_placeholders()
        self.actor_logits = self._build_actor()
        # self.gumbel_dist = self._build_gumbel(self.actor_logits)
        self.loss = self._build_loss()
        self.actor_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')
        self.actor_pruned_weight_matrices = pruning.get_masked_weights()
        self.actor_train_op = self._build_actor_train_op()
        self.actor_saver = tf.train.Saver(var_list=self.actor_parameters, max_to_keep=100)
        self.init_variables(tf.global_variables())
        self.sparsity = pruning.get_weight_sparsity()
        self.hparams = pruning.get_pruning_hparams() \
            .parse('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                   ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                   'pruning_frequency={},initial_sparsity={},'
                   ' sparsity_function_exponent={}'.format('Actor', cfg.pruning_start, cfg.pruning_end,
                                                           cfg.target_sparsity, cfg.sparsity_start,
                                                           cfg.sparsity_end, cfg.pruning_freq,
                                                           cfg.initial_sparsity, 3))
        # The global step drives the pruning schedule: the further it advances, the closer the
        # enforced sparsity gets to the target sparsity.
        self.pruning_obj = pruning.Pruning(self.hparams, global_step=self.actor_global_step)
        # The Pruning object defines the pruning mechanism; running mask_update_op during
        # training is what actually prunes the actor, updating the masks at each training
        # epoch with the objective of reaching the end-sparsity hyperparameter.
        self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
        # initialize variables in the graph (run again so that variables created after the
        # first call, e.g. by the pruning object, are initialized too)
        self.init_variables(tf.global_variables())
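# A small monitoring sketch (illustrative, with assumed names): given a StudentActor `actor`
# built as above and a tf.Session `sess` over its graph, pruning.get_weight_sparsity() yields
# one zero-fraction tensor per masked layer, and actor.actor_pruned_weight_matrices holds the
# already-masked (mask * weight) matrices, so the achieved sparsity can be checked directly
# after the mask update has run.
import numpy as np

def log_actor_sparsity(actor, sess):
    per_layer = sess.run(actor.sparsity)  # fraction of zeroed weights in each masked layer
    pruned = sess.run(actor.actor_pruned_weight_matrices)
    overall = float(np.mean([np.mean(w == 0.0) for w in pruned]))
    print('per-layer sparsity:', per_layer, 'overall sparsity:', overall)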
def get_masked_weights(self):
    """Return the masked (mask * weight) tensors of every pruned layer in the graph."""
    return pruning.get_masked_weights()
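# For context (a sketch under assumptions, not from the source): in tf.contrib.model_pruning
# each masked layer keeps the original weight, a binary mask, and their element-wise product,
# and get_masked_weights() returns those product tensors, i.e. the weights as the pruned
# network actually uses them. A quick consistency check, assuming `net` is an instance of the
# class above and `sess` is a session over its graph:
import numpy as np

def check_masks_applied(net, sess):
    weights, masks, masked = sess.run(
        [pruning.get_weights(), pruning.get_masks(), net.get_masked_weights()])
    for w, m, mw in zip(weights, masks, masked):
        # each masked weight should equal the raw weight with its mask applied element-wise
        assert np.allclose(mw, w * m)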