def filter_trainable_variables(trainable_scopes):
    """Keep only trainable variables which are prefixed with given scopes.

    Every variable whose op name does not start with one of the given scopes
    is removed from the graph's TRAINABLE_VARIABLES collection. This is useful
    during fine tuning, when only a subset of variables should be trained.
    Matching is a plain string-prefix test, so scope "conv1" also keeps
    "conv10/...".

    Args:
      trainable_scopes: either list of trainable scopes or string with comma
        separated list of trainable scopes. Empty entries are ignored; if no
        non-empty scope remains, the collection is left untouched.
    """
    if not trainable_scopes:
        return
    if isinstance(trainable_scopes, six.string_types):
        trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')]
    # str.startswith accepts a tuple and tests every prefix in one call.
    prefixes = tuple(scope for scope in trainable_scopes if scope)
    if not prefixes:
        return
    trainable_collection = tf.get_collection_ref(
        tf.GraphKeys.TRAINABLE_VARIABLES)
    # Slice-assign so the graph's own list object is updated in place. This is
    # a single O(n) pass instead of one list.remove() scan per dropped variable.
    trainable_collection[:] = [
        v for v in trainable_collection if v.op.name.startswith(prefixes)
    ]
def test_graph_search_match_fail(self):
    """Tests graph search with linked bias tensors.

    Two bias tensors that are not adjacent in the graph are declared as one
    linked group. No K-FAC fisher block matches that configuration, so
    `register_layers` must raise a ValueError that names the unmatched group.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        collection = lc.LayerCollection()
        for loss_key in ('out_0', 'out_1'):
            collection.register_squared_error_loss(tensors[loss_key])
        # TODO(b/69055612): remove this manual registration once layer_collection
        # implements register_fully_connected_multi.
        collection.register_fully_connected(
            tensors['w'], tensors['x'], tensors['pre_bias_0'])
        linked_biases = (tensors['b_0'], tensors['b_1'])
        collection.define_linked_parameters(linked_biases)

        with self.assertRaises(ValueError) as raised:
            gs.register_layers(
                collection,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

        message = str(raised.exception)
        self.assertIn('in linked group', message)
        self.assertIn('was not matched', message)
        self.assertIn(str(frozenset(linked_biases)), message)
def apply_mask(x, scope=''):
    """Apply mask to a given weight tensor.

    Creates (via pruning_utils) the mask and threshold variables for `x`,
    multiplies the mask into `x`, and registers the mask, threshold, masked
    weights and raw weights in their collections once per mask.

    Args:
      x: Input weight tensor.
      scope: The current variable scope. Defaults to "".

    Returns:
      Tensor representing masked_weights.
    """
    mask = pruning_utils.weight_mask_variable(x, scope)
    threshold = pruning_utils.weight_threshold_variable(x, scope)
    # Build the product in the weights' name scope so the quantization
    # library can more easily add quant ops to it.
    masked_weights = tf.multiply(mask, x, _MASKED_WEIGHT_NAME)

    # Skip registration when this mask is already collected, so repeated
    # application (notably to an RNN's weight variables) adds nothing twice.
    already_collected = mask in tf.get_collection_ref(_MASK_COLLECTION)
    if not already_collected:
        registrations = (
            (_THRESHOLD_COLLECTION, threshold),
            (_MASK_COLLECTION, mask),
            (_MASKED_WEIGHT_COLLECTION, masked_weights),
            (_WEIGHT_COLLECTION, x),
        )
        for collection_name, item in registrations:
            tf.add_to_collection(collection_name, item)
    return masked_weights
def test_tied_weights_untied_bias_registered_affine(self):
    """Test registering linked variables.

    Linking (w, b_1) must not raise: the candidate matches on (w) and
    (w, b_0) are filtered out. The expected result is one fully connected
    block for (w, b_1) plus a generic block for b_0.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        loss_keys = ('out_0', 'out_1')

        expected = lc.LayerCollection()
        for key in loss_keys:
            expected.register_squared_error_loss(tensors[key])
        expected.register_fully_connected(
            params=(tensors['w'], tensors['b_1']),
            inputs=tensors['y'],
            outputs=tensors['out_1'])
        expected.register_generic(tensors['b_0'], batch_size=32)

        searched = lc.LayerCollection()
        for key in loss_keys:
            searched.register_squared_error_loss(tensors[key])
        searched.define_linked_parameters((tensors['w'], tensors['b_1']))
        global_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        gs.register_layers(searched, global_vars, batch_size=32)

        assert_fisher_blocks_match(self, searched, expected)
def testApplyCustomizedLSTMMatrixCompression(self):
    """Compressing an LSTM matrix must register at least one pruning mask."""
    pruning_interface.apply_customized_lstm_matrix_compression(
        self.compression_obj,
        self.mock_weight_params_fn,
        MockWeightInit,
        self.mock_lstmobj,
        self.wm_pc.shape,
        tf.float32,
    )
    registered_masks = tf.get_collection_ref(pruning.MASK_COLLECTION)
    self.assertGreater(len(registered_masks), 0)
def test_tied_weights_untied_bias_registered_weights(self):
    """Tests that graph search produces right solution on toy model.

    Linking only the shared weight w should give one multi-input fully
    connected block for w and one generic block per bias.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        loss_keys = ('out_0', 'out_1')

        expected = lc.LayerCollection()
        for key in loss_keys:
            expected.register_squared_error_loss(tensors[key])
        expected.register_fully_connected_multi(
            tensors['w'],
            (tensors['x'], tensors['y']),
            (tensors['pre_bias_0'], tensors['pre_bias_1']))
        for bias_key in ('b_0', 'b_1'):
            expected.register_generic(tensors[bias_key], batch_size=1)

        searched = lc.LayerCollection()
        for key in loss_keys:
            searched.register_squared_error_loss(tensors[key])
        # A single tensor is linked (parentheses alone do not make a tuple).
        searched.define_linked_parameters(tensors['w'])
        global_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        gs.register_layers(searched, global_vars, batch_size=1)

        assert_fisher_blocks_match(self, searched, expected)
def build(self, inputs_shape):
    """Build the parent LSTM cell, then attach pruning variables to it.

    After the parent build, the pruning variables are created. The masked
    kernel (mask * kernel) is formed, and all pruning tensors are registered
    in the pruning collections unless this mask is already registered.

    Args:
      inputs_shape: Static shape of the cell input; dimension 1 is read as
        the input depth.
    """
    # Let the parent class create its own weights first.
    super(MaskedLSTMCell, self).build(inputs_shape)
    # Mark the cell unbuilt until the pruning variables below exist.
    self.built = False

    input_depth = inputs_shape.dims[1].value
    prune_variables = _CreateLSTMPruneVariables(
        self, input_depth, self._num_units)
    (self._mask, self._threshold, self._old_weight, self._old_old_weight,
     self._gradient) = prune_variables

    # Form masked_weights in the weights' name scope so the quantization
    # library can more easily add quant ops.
    self._masked_kernel = tf.multiply(
        self._mask, self._kernel, pruning.MASKED_WEIGHT_NAME)

    # Register every pruning tensor only once per mask.
    if self._mask not in tf.get_collection_ref(pruning.MASK_COLLECTION):
        registrations = (
            (pruning.MASK_COLLECTION, self._mask),
            (pruning.MASKED_WEIGHT_COLLECTION, self._masked_kernel),
            (pruning.THRESHOLD_COLLECTION, self._threshold),
            (pruning.WEIGHT_COLLECTION, self._kernel),
            (pruning.OLD_WEIGHT_COLLECTION, self._old_weight),
            (pruning.OLD_OLD_WEIGHT_COLLECTION, self._old_old_weight),
            (pruning.WEIGHT_GRADIENT_COLLECTION, self._gradient),
        )
        for collection_key, tensor in registrations:
            tf.add_to_collection(collection_key, tensor)

    self.built = True
def test_multiple_weights(self):
    """Test that graph search provides desired registration on toy model.

    The same affine layer (W, b_0) is applied to two different inputs; graph
    search should group both uses into a single multi-input block.
    """
    with tf.Graph().as_default():
        weights = tf.get_variable('W', [10, 10])
        bias = tf.get_variable('b_0', [10])
        inputs = [tf.placeholder(tf.float32, shape=(32, 10)) for _ in range(2)]
        outputs = [tf.matmul(inp, weights) + bias for inp in inputs]

        expected = lc.LayerCollection()
        expected.register_fully_connected_multi(
            (weights, bias), tuple(inputs), tuple(outputs))

        searched = lc.LayerCollection()
        for out in outputs:
            searched.register_squared_error_loss(out)

        gs.register_layers(
            searched, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

        assert_fisher_blocks_match(self, searched, expected)
def LogAndSummarizeMetrics(metrics, use_streaming_mean=True):
    """Logs and summarizes metrics.

    Each metric is appended as a `(scoped_name, value)` pair to the
    LOGGING_OUTPUTS collection and written as a scalar summary. The scoped
    name is the current name scope plus "/" plus the metric key.

    Args:
      metrics: A dictionary of scalar metrics.
      use_streaming_mean: If true, the metrics will be averaged using a
        running mean.

    Returns:
      A grouped op. When use_streaming_mean is true it runs the running-mean
      updates and must be called regularly; otherwise it only wraps a no-op.
    """
    scope_name = tf.get_default_graph().get_name_scope()
    prefix = scope_name + "/" if scope_name else scope_name
    logged = tf.get_collection_ref(LOGGING_OUTPUTS)

    pending_updates = [tf.no_op()]
    for metric_name, metric_value in metrics.items():
        if use_streaming_mean:
            metric_value, mean_update = tf.metrics.mean(metric_value)
            pending_updates.append(mean_update)
        logged.append((prefix + metric_name, metric_value))
        tf.summary.scalar(metric_name, metric_value)
    return tf.group(*pending_updates)
def test_specify_approximation_shared_parameters(self):
    """Test specifying approximations with layers containing shared parameters.

    When linked parameters are declared together with an approximation, graph
    search must register them using exactly that approximation.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        collection = lc.LayerCollection()
        for loss_key in ('out_0', 'out_1'):
            collection.register_squared_error_loss(tensors[loss_key])

        # (tensor key, requested approximation, expected fisher block type)
        requests = (
            ('w', lc.APPROX_KRONECKER_INDEP_NAME, fb.FullyConnectedMultiIndepFB),
            ('b_0', lc.APPROX_DIAGONAL_NAME, fb.NaiveDiagonalFB),
            ('b_1', lc.APPROX_FULL_NAME, fb.FullFB),
        )
        for key, approx, _ in requests:
            collection.define_linked_parameters(
                tensors[key], approximation=approx)

        gs.register_layers(
            collection,
            tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
            batch_size=1)

        for key, _, block_type in requests:
            self.assertIsInstance(
                collection.fisher_blocks[tensors[key]], block_type)
def body(self, features):
    """Builds the teacher network and, in the distill phase, the student.

    The teacher is always built under the "teacher" variable scope. When
    `hp.distill_phase == "distill"`, teacher weights are initialized from the
    checkpoint at `hp.teacher_dir`. The TRAINABLE_VARIABLES collection is then
    emptied so the teacher is not trained. Finally a student is built under
    the "student" scope. Its loss mixes the hard-label cross entropy and the
    temperature-softened teacher-target cross entropy, weighted by
    `hp.task_balance`.

    Args:
      features: Dict of input tensors; must contain "targets_raw".

    Returns:
      A tuple `(outputs, losses)`. `outputs` holds the teacher logits, or the
      student logits in the distill phase, reshaped to
      `[-1, 1, 1, 1, outputs.shape[1]]`. `losses` is `{"training": phase_loss}`.
    """
    hp = self.hparams
    is_distill = hp.distill_phase == "distill"

    targets = features["targets_raw"]
    # NOTE(review): presumably targets_raw is [batch, 1, 1, 1] class ids, so
    # this squeeze yields [batch] — confirm against the problem definition.
    targets = tf.squeeze(targets, [1, 2, 3])
    one_hot_targets = tf.one_hot(targets, hp.num_classes, dtype=tf.float32)

    # Teacher Network
    with tf.variable_scope("teacher"):
        teacher_outputs = self.teacher_model.body(features)
        tf.logging.info("teacher output shape: %s" % teacher_outputs.get_shape())
        # Average over axes 1 and 2 before the classification head.
        teacher_outputs = tf.reduce_mean(teacher_outputs, axis=[1, 2])
        teacher_logits = tf.layers.dense(teacher_outputs, hp.num_classes)

        teacher_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_targets, logits=teacher_logits)
        outputs = teacher_logits

    if is_distill:
        # Load teacher weights
        tf.train.init_from_checkpoint(hp.teacher_dir, {"teacher/": "teacher/"})
        # Do not train the teacher: get_collection_ref returns the live list,
        # so emptying it here leaves only variables created afterwards (the
        # student's) trainable.
        trainable_vars = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
        del trainable_vars[:]

    # Student Network
    if is_distill:
        with tf.variable_scope("student"):
            student_outputs = self.student_model.body(features)
            tf.logging.info(
                "student output shape: %s" % student_outputs.get_shape())
            student_outputs = tf.reduce_mean(student_outputs, axis=[1, 2])
            student_logits = tf.layers.dense(student_outputs, hp.num_classes)

            student_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=one_hot_targets, logits=student_logits)
            # Soft targets from the teacher, softened by the temperature; no
            # gradient flows back into the teacher through them.
            teacher_targets = tf.nn.softmax(teacher_logits / hp.distill_temperature)
            student_distill_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=tf.stop_gradient(teacher_targets),
                logits=student_logits / hp.distill_temperature)
            # scale soft target obj. to match hard target obj. scale
            student_distill_xent *= hp.distill_temperature**2

            outputs = student_logits

            # Summaries
            # NOTE(review): softmax_cross_entropy_with_logits_v2 yields one value
            # per example, so this is not a scalar tensor; confirm that
            # tf.summary.scalar accepts it (a tf.reduce_mean may be intended).
            tf.summary.scalar("distill_xent", student_distill_xent)

    if not is_distill:
        phase_loss = teacher_task_xent
    else:
        # task_balance weights the hard-label loss against the distill loss.
        phase_loss = hp.task_balance * student_task_xent
        phase_loss += (1 - hp.task_balance) * student_distill_xent

    losses = {"training": phase_loss}
    outputs = tf.reshape(outputs, [-1, 1, 1, 1, outputs.shape[1]])

    return outputs, losses
def after_run(self, run_context, run_values):
    """Record the values fetched for COLLECTION, keyed by tensor name.

    The i-th fetched result is paired with the i-th tensor in COLLECTION.
    The resulting dict is appended to `self._measurements`.
    """
    golden_values = {}
    for tensor, fetched in zip(tf.get_collection_ref(COLLECTION),
                               run_values.results):
        golden_values[tensor.name] = fetched
    logging.info('Recorded golden values for %s', golden_values.keys())
    self._measurements.append(golden_values)
def fix_saver(collection_lists=None):
    """Temporarily empties non-serializable collections around a Saver call.

    Called with no argument, it snapshots and clears the "summary_tags",
    GRAPH_CONFIG, DATA_PREP, DATA_AUG and EXCL_RESTORE_VARS collections. It
    returns the snapshots as `[summary_tags, DATA_PREP, DATA_AUG,
    EXCL_RESTORE_VARS, GRAPH_CONFIG]`. Called again with that list, it
    re-adds every saved item to its collection and returns None.

    NOTE(review): DATA_PREP, DATA_AUG and EXCL_RESTORE_VARS are not standard
    tf.GraphKeys members; presumably they are defined elsewhere — confirm.

    Args:
      collection_lists: None to snapshot-and-clear, or the list previously
        returned by this function to restore.

    Returns:
      The list of five snapshots when clearing; None when restoring.
    """
    # Workaround to prevent serialization warning by removing objects
    if collection_lists is None:
        try:
            # Try latest api
            l = tf.get_collection_ref("summary_tags")
            l4 = tf.get_collection_ref(tf.GraphKeys.GRAPH_CONFIG)
        except Exception:
            # Fallback for older TF without get_collection_ref.
            # NOTE(review): get_collection is documented to return a copy, so
            # the `del` below would not clear the real collection on this path.
            l = tf.get_collection("summary_tags")
            l4 = tf.get_collection(tf.GraphKeys.GRAPH_CONFIG)
        # Snapshot before clearing so the caller can restore later.
        l_stags = list(l)
        l4_stags = list(l4)
        del l[:]
        del l4[:]

        try:
            # Try latest api
            l1 = tf.get_collection_ref(tf.GraphKeys.DATA_PREP)
            l2 = tf.get_collection_ref(tf.GraphKeys.DATA_AUG)
        except Exception:
            l1 = tf.get_collection(tf.GraphKeys.DATA_PREP)
            l2 = tf.get_collection(tf.GraphKeys.DATA_AUG)
        l1_dtags = list(l1)
        l2_dtags = list(l2)
        del l1[:]
        del l2[:]

        try:
            # Do not save exclude variables
            l3 = tf.get_collection_ref(tf.GraphKeys.EXCL_RESTORE_VARS)
        except Exception:
            l3 = tf.get_collection(tf.GraphKeys.EXCL_RESTORE_VARS)
        l3_tags = list(l3)
        del l3[:]
        # Order matters: the restore branch below indexes this list by position.
        return [l_stags, l1_dtags, l2_dtags, l3_tags, l4_stags]
    else:
        # 0.7+ workaround, restore values
        for t in collection_lists[0]:
            tf.add_to_collection("summary_tags", t)
        for t in collection_lists[4]:
            tf.add_to_collection(tf.GraphKeys.GRAPH_CONFIG, t)
        for t in collection_lists[1]:
            tf.add_to_collection(tf.GraphKeys.DATA_PREP, t)
        for t in collection_lists[2]:
            tf.add_to_collection(tf.GraphKeys.DATA_AUG, t)
        for t in collection_lists[3]:
            tf.add_to_collection(tf.GraphKeys.EXCL_RESTORE_VARS, t)
def _variable_tracking_custom_getter(getter, *args, **kwargs):
    """Custom getter that tracks variables created.

    Any variables that `getter` creates are added to the `_all_variables`
    attribute of the `AbstractModule` on top of the module call stack. That
    stack is graph-dependent and records the sonnet module call order.

    This assumes new variables are appended to `tf.Graph` collections. That
    holds because `tf.add_to_collection()` appends, and `tf.Variable` uses
    `tf.add_to_collections()` to register itself. It also assumes every
    variable is added to either the `tf.GraphKeys.GLOBAL_VARIABLES` or the
    `tf.GraphKeys.LOCAL_VARIABLES` collection.

    Args:
      getter: The true getter or another custom getter.
      *args: See positional arguments for `tf.get_variable()`.
      **kwargs: See keyword arguments for `tf.get_variable()`.

    Returns:
      See docstring for `tf.get_variable()`.
    """
    # The module calling `tf.get_variable()` is the top of this graph's stack.
    calling_module = _MODULE_STACKS[tf.get_default_graph()][-1]

    # get_collection_ref returns the live lists (no copy), so entries appended
    # by `getter` show up in them afterwards.
    watched = (
        tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES),
        tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
    )
    sizes_before = [len(collection) for collection in watched]

    out = getter(*args, **kwargs)

    # Everything past the recorded length was created by `getter`.
    # pylint: disable=protected-access
    for collection, start in zip(watched, sizes_before):
        calling_module._all_variables.update(collection[start:])
    # pylint: enable=protected-access
    return out
def after_run(self, run_context, run_values):
    """Record fetched golden values, keyed by tensor name with PREFIX removed.

    The i-th fetched result is paired with the i-th tensor in COLLECTION. Its
    key is the part of the tensor name after the first occurrence of PREFIX.
    The resulting dict is appended to `self._measurements`. A tensor name
    that lacks PREFIX raises IndexError, as before.
    """
    golden_values = {}
    for tensor, fetched in zip(tf.get_collection_ref(COLLECTION),
                               run_values.results):
        # Strip the 'golden_' prefix before saving the data. Split only once,
        # so a later occurrence of PREFIX inside the name is kept rather than
        # truncating the key.
        golden_values[tensor.name.split(PREFIX, 1)[1]] = fetched
    logging.info('Recorded golden values for %s', golden_values.keys())
    self._measurements.append(golden_values)
def get_variables(checkpoint_prefix): get_name = lambda x : x.name stripper = lambda x : x.strip(':0') rem_tuble= lambda x : x[0] complete_variables = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) variables_names = list(map(stripper,list(map(get_name,complete_variables)))) checkpoint_variables = list(map(rem_tuble,tf.train.list_variables(checkpoint_prefix))) crossed_variables = list(set(variables_names).intersection(set(checkpoint_variables))) indices = [variables_names.index(name) for name in crossed_variables] return [complete_variables[i] for i in indices]
def read(self, filepath, partial=False):
    """Restore model variables from a checkpoint file.

    Args:
      filepath: Path of the checkpoint to restore.
      partial: If True, restore only the global variables whose op names also
        appear in the checkpoint, using a temporary Saver. Otherwise restore
        everything through `self.saver`.
    """
    if partial:
        # tf.train.list_variables returns (name, shape) tuples, so match on
        # names. Intersecting Variable objects with those tuples never matches
        # and would leave nothing to restore.
        checkpoint_names = {
            name for name, _ in tf.train.list_variables(filepath)}
        vars_to_restore = [
            v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
            if v.op.name in checkpoint_names
        ]
        logging.warning(
            "Restoring graph partially. Only the following vars will be restored: "
            + str(vars_to_restore))
        partial_saver = tf.train.Saver(vars_to_restore)
        partial_saver.restore(self.session, filepath)
        logging.info(
            "Model checkpoint partially restored from file: %s.", filepath)
    else:
        self.saver.restore(self.session, filepath)
        logging.info("Model checkpoint restored from file: %s.", filepath)
def test_tied_weights_untied_bias(self):
    """Tests that ambiguity in graph raises an error.

    With no linked tensors and no manual registration, graph search finds
    several candidate registrations containing w, e.g. (w, b_0) and (w, b_1).
    It must raise AmbiguousRegistrationError instead of picking one.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        collection = lc.LayerCollection()
        for loss_key in ('out_0', 'out_1'):
            collection.register_squared_error_loss(tensors[loss_key])

        global_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        with self.assertRaises(gs.AmbiguousRegistrationError):
            gs.register_layers(collection, global_vars)
def test_tied_weights_untied_bias_registered_bias(self):
    """Tests that ambiguity in graph raises value error.

    Linking only b_1 resolves one branch of the graph, but the other branch
    remains ambiguous, so graph search must raise AmbiguousRegistrationError.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        collection = lc.LayerCollection()
        for loss_key in ('out_0', 'out_1'):
            collection.register_squared_error_loss(tensors[loss_key])
        # A single tensor is linked (parentheses alone do not make a tuple).
        collection.define_linked_parameters(tensors['b_1'])

        global_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        with self.assertRaises(gs.AmbiguousRegistrationError):
            gs.register_layers(collection, global_vars)
def apply_mask_and_return(x, scope='', prune_option='weight'):
    """Apply mask to a given weight tensor.

    Args:
      x: Input weight tensor.
      scope: The current variable scope. Defaults to "".
      prune_option: pruning option. Defaults to 'weight'.
        'first_order_gradient' prunes by |weight| * |first order gradient|.
        'second_order_gradient' prunes by |weight| * |second order gradient|.
        Either gradient option also creates gradient and old-weight variables.

    Returns:
      masked_weights: a TensorFlow tensor representing masked weights.
      mask: a TensorFlow tensor representing the pruning mask.
    """
    mask = pruning_utils.weight_mask_variable(x, scope)
    threshold = pruning_utils.weight_threshold_variable(x, scope)
    # Build masked_weights in the weights' name scope so the quantization
    # library can more easily add quant ops.
    masked_weights = tf.multiply(mask, x, MASKED_WEIGHT_NAME)

    gradient_based = prune_option in ('first_order_gradient',
                                      'second_order_gradient')
    if gradient_based:
        # Extra state needed for gradient-based pruning scores.
        gradient = pruning_utils.weight_gradient_variable(x, scope)
        old_weight = pruning_utils.old_weight_variable(x, scope)
        old_old_weight = pruning_utils.old_old_weight_variable(x, scope)

    # Register each mask only once; this matters when the mask is applied to
    # an RNN's weight variables.
    if mask in tf.get_collection_ref(MASK_COLLECTION):
        return [masked_weights, mask]

    tf.add_to_collection(THRESHOLD_COLLECTION, threshold)
    tf.add_to_collection(MASK_COLLECTION, mask)
    tf.add_to_collection(MASKED_WEIGHT_COLLECTION, masked_weights)
    tf.add_to_collection(WEIGHT_COLLECTION, x)
    if gradient_based:
        tf.add_to_collection(WEIGHT_GRADIENT_COLLECTION, gradient)
        tf.add_to_collection(OLD_WEIGHT_COLLECTION, old_weight)
        tf.add_to_collection(OLD_OLD_WEIGHT_COLLECTION, old_old_weight)
    return [masked_weights, mask]
def test_multitower_multi_loss_function(self):
    """Test multitower setup with multiple loss functions.

    The automatic graph scanner should handle several loss functions per
    tower, as long as every tower registers them in the same order.
    """
    with tf.Graph().as_default():
        w_1 = tf.get_variable('w_1', shape=[10, 10])
        b_1 = tf.get_variable('b_1', shape=[10])
        w_2 = tf.get_variable('w_2', shape=[10, 10])
        b_2 = tf.get_variable('b_2', shape=[10])
        searched = lc.LayerCollection()
        expected = lc.LayerCollection()

        for tower_num in range(5):
            x = tf.placeholder(tf.float32, shape=(32, 10))
            logits_1 = tf.matmul(x, w_1) + b_1
            logits_2 = tf.matmul(x, w_2) + b_2
            # Only the first tower opens its scope without reuse.
            with tf.variable_scope('tower%d' % tower_num, reuse=tower_num > 0):
                for collection in (searched, expected):
                    collection.register_categorical_predictive_distribution(
                        logits_1, name='loss_1')
                    collection.register_categorical_predictive_distribution(
                        logits_2, name='loss_2')
                expected.register_fully_connected((w_1, b_1), x, logits_1)
                expected.register_fully_connected((w_2, b_2), x, logits_2)

        gs.register_layers(
            searched, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

        assert_fisher_blocks_match(self, searched, expected)
def get_train_estimator_spec(gan_model_fns, loss_fns, gan_loss_kwargs,
                             optimizers, joint_train, is_on_tpu,
                             gan_train_steps, add_summaries, run_config):
    """Estimator spec for train case.

    Builds the train op and scalar loss. It also builds a TPU host_call that
    writes one v2 scalar summary per tensor in the LOSSES collection, using
    the global step.
    """
    # Optimizers given as callables are constructed here, inside the model_fn,
    # because they may create tf.Variables that belong in the current tf.Graph.
    optimizers = _maybe_construct_optimizers(optimizers)
    if is_on_tpu:
        optimizers = _maybe_make_cross_shard_optimizers(optimizers)

    tpu_train_op, scalar_loss = _get_train_op(
        gan_model_fns, loss_fns, gan_loss_kwargs, optimizers, joint_train,
        gan_train_steps, add_summaries)

    # Everything passed to host_call is reshaped to shape [1].
    # NOTE(review): presumably host_call inputs need a leading dimension.
    step_vec = tf.reshape(tf1.train.get_global_step(), [1])
    loss_tensors = tf1.get_collection_ref(tf1.GraphKeys.LOSSES)
    loss_names = [loss.name for loss in loss_tensors]
    loss_vecs = [tf.reshape(loss, [1]) for loss in loss_tensors]

    def host_call_fn(step, *loss_values):
        step = step[0]
        writer = tf.summary.create_file_writer(
            run_config.model_dir,
            max_queue=run_config.tpu_config.iterations_per_loop)
        with writer.as_default():
            with tf.summary.record_if(True):
                for loss_name, loss_value in zip(loss_names, loss_values):
                    tf.summary.scalar(
                        loss_name, tf.reduce_mean(loss_value), step=step)
            return tf1.summary.all_v2_summary_ops()

    return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=scalar_loss,
        train_op=tpu_train_op,
        host_call=(host_call_fn, [step_vec] + loss_vecs))
def test_subset_weights_manual_registration(self):
    """Test that graph search provides desired registration on toy model.

    The same matmul with W is applied to two different inputs, and a bias is
    added to only one of the outputs. With W declared as linked, graph search
    should group both matmuls and give b_0 its own generic block.
    """
    with tf.Graph().as_default():
        weights = tf.get_variable('W', [10, 10])
        bias = tf.get_variable('b_0', [10])
        x = tf.placeholder(tf.float32, shape=(32, 10))
        y = tf.placeholder(tf.float32, shape=(32, 10))
        pre_bias = tf.matmul(x, weights)
        biased_out = pre_bias + bias
        plain_out = tf.matmul(y, weights)

        expected = lc.LayerCollection()
        expected.register_fully_connected_multi(
            weights, (x, y), (pre_bias, plain_out))
        expected.register_generic(bias, batch_size=1)

        searched = lc.LayerCollection()
        for out in (biased_out, plain_out):
            searched.register_squared_error_loss(out)
        searched.define_linked_parameters(weights)
        gs.register_layers(
            searched,
            tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
            batch_size=1)

        assert_fisher_blocks_match(self, searched, expected)
def mixed_usage_test(self):
    """Tests that graph search raises error on mixed types usage for tensors.

    A tensor may appear in several places of a graph (recurrent models,
    parallel graphs), but each use must be the same kind of operation. Here W
    feeds both a matmul and an elementwise add. Graph search should therefore
    raise a ValueError mentioning 'mixed record types'.

    NOTE(review): this method name does not start with 'test', so the default
    unittest loader will not collect it — confirm whether that is intended.
    """
    with tf.Graph().as_default():
        weights = tf.get_variable('W', [10, 10])
        x = tf.placeholder(tf.float32, shape=(32, 10))
        y = tf.placeholder(tf.float32, shape=(32, 10, 10))
        matmul_use = tf.matmul(x, weights)  # pylint: disable=unused-variable
        add_use = y + weights  # pylint: disable=unused-variable

        collection = lc.LayerCollection()
        with self.assertRaises(ValueError) as raised:
            gs.register_layers(
                collection,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
        self.assertIn('mixed record types', str(raised.exception))
def main(unused_args):
    """Train and evaluate an (optionally sparse) MNIST network.

    Builds the input pipelines and the network. It then builds an optimizer,
    wrapped by a sparse-training optimizer chosen by FLAGS.training_method,
    plus optional magnitude-pruning ops and summaries. Training then runs for
    FLAGS.num_epochs; test loss/accuracy (and sparsity, when pruning) are
    logged once per epoch to stdout and to a text file under FLAGS.save_path.

    Args:
      unused_args: Command-line arguments (ignored; configuration comes from
        FLAGS).

    Raises:
      RuntimeError: If FLAGS.network_type or FLAGS.optimizer is unknown.
      ValueError: If FLAGS.training_method is unsupported.
    """
    tf.set_random_seed(FLAGS.seed)
    tf.get_variable_scope().set_use_resource(True)
    np.random.seed(FLAGS.seed)

    # Load the MNIST data and set up an iterator.
    mnist_data = input_data.read_data_sets(
        FLAGS.mnist, one_hot=False, validation_size=0)
    train_images = mnist_data.train.images
    test_images = mnist_data.test.images
    if FLAGS.input_mask_path:
        # Keep only the input columns whose row in 'layer1/mask' has a nonzero
        # sum, i.e. inputs that still feed at least one first-layer unit.
        reader = tf.train.load_checkpoint(FLAGS.input_mask_path)
        input_mask = reader.get_tensor('layer1/mask')
        indices = np.sum(input_mask, axis=1) != 0
        train_images = train_images[:, indices]
        test_images = test_images[:, indices]
    dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, mnist_data.train.labels.astype(np.int32)))
    num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size
    dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])
    batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
    iterator = batched_dataset.make_one_shot_iterator()

    # The whole test set is evaluated as a single batch.
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_images, mnist_data.test.labels.astype(np.int32)))
    num_test_images = mnist_data.test.images.shape[0]
    test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)
    test_iterator = test_dataset.make_one_shot_iterator()

    # Set up loss function.
    use_model_pruning = FLAGS.training_method != 'baseline'

    if FLAGS.network_type == 'fc':
        cross_entropy_train, _ = mnist_network_fc(
            iterator.get_next(), model_pruning=use_model_pruning)
        cross_entropy_test, accuracy_test = mnist_network_fc(
            test_iterator.get_next(), reuse=True, model_pruning=use_model_pruning)
    else:
        # The flag is `network_type`; `FLAGS.network` does not exist and would
        # raise AttributeError instead of this intended error.
        raise RuntimeError(FLAGS.network_type + ' is an unknown network type.')

    # Remove extra added ones. Current implementation adds the variables twice
    # to the collection. Improve this hacky thing.
    # TODO test the following with the convnet or any other network.
    if use_model_pruning:
        for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):
            # del tf.get_collection_ref(k)[2]
            # del tf.get_collection_ref(k)[2]
            collection = tf.get_collection_ref(k)
            del collection[len(collection) // 2:]
            print(tf.get_collection_ref(k))

    # Set up optimizer and update ops.
    global_step = tf.train.get_or_create_global_step()
    batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size

    if FLAGS.optimizer != 'adam':
        # Piecewise-constant schedule: divide the rate by 3 at each boundary.
        if not use_model_pruning:
            boundaries = [
                int(round(s * batch_per_epoch)) for s in [60, 70, 80]
            ]
        else:
            boundaries = [
                int(round(s * batch_per_epoch))
                for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]
            ]
        learning_rate = tf.train.piecewise_constant(
            global_step,
            boundaries,
            values=[
                FLAGS.learning_rate / (3.**i)
                for i in range(len(boundaries) + 1)
            ])
    else:
        learning_rate = FLAGS.learning_rate

    if FLAGS.optimizer == 'adam':
        opt = tf.train.AdamOptimizer(learning_rate)
    elif FLAGS.optimizer == 'momentum':
        opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum,
                                         use_nesterov=FLAGS.use_nesterov)
    elif FLAGS.optimizer == 'sgd':
        opt = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise RuntimeError(FLAGS.optimizer + ' is unknown optimizer type')

    # Per-layer target sparsities; layer3 gets sparsity 0 (end_sparsity * 0).
    custom_sparsities = {
        'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,
        'layer3': FLAGS.end_sparsity * 0
    }

    if FLAGS.training_method == 'set':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseSETOptimizer(
            opt, begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'static':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseStaticOptimizer(
            opt, begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'momentum':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseMomentumOptimizer(
            opt, begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)
    elif FLAGS.training_method == 'rigl':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseRigLOptimizer(
            opt, begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)
    elif FLAGS.training_method == 'snip':
        opt = sparse_optimizers.SparseSnipOptimizer(
            opt, mask_init_method=FLAGS.mask_init_method,
            default_sparsity=FLAGS.end_sparsity,
            custom_sparsity_map=custom_sparsities, use_tpu=False)
    elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)

    train_op = opt.minimize(cross_entropy_train, global_step=global_step)

    if FLAGS.training_method == 'prune':
        hparams_string = (
            'begin_pruning_step={0},sparsity_function_begin_step={0},'
            'end_pruning_step={1},sparsity_function_end_step={1},'
            'target_sparsity={2},pruning_frequency={3},'
            'threshold_decay={4}'.format(
                FLAGS.prune_begin_step, FLAGS.prune_end_step,
                FLAGS.end_sparsity, FLAGS.pruning_frequency,
                FLAGS.threshold_decay))
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
        pruning_hparams.set_hparam(
            'weight_sparsity_map',
            ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()])
        print(pruning_hparams)
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
        # Run the mask update only after the gradient step.
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()
    weight_sparsity_levels = pruning.get_weight_sparsity()
    global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())
    tf.summary.scalar('test_accuracy', accuracy_test)
    tf.summary.scalar('global_sparsity', global_sparsity)
    for k, v in zip(pruning.get_masks(), weight_sparsity_levels):
        tf.summary.scalar('sparsity/%s' % k.name, v)
    if FLAGS.training_method in ('prune', 'snip', 'baseline'):
        mask_init_op = tf.no_op()
        tf.logging.info('No mask is set, starting dense.')
    else:
        all_masks = pruning.get_masks()
        mask_init_op = sparse_utils.get_mask_init_fn(
            all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity,
            custom_sparsities)

    if FLAGS.save_model:
        saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    hyper_params_string = '_'.join([
        FLAGS.network_type, str(FLAGS.batch_size), str(FLAGS.learning_rate),
        str(FLAGS.momentum), FLAGS.optimizer, str(FLAGS.l2_scale),
        FLAGS.training_method, str(FLAGS.prune_begin_step),
        str(FLAGS.prune_end_step), str(FLAGS.end_sparsity),
        str(FLAGS.pruning_frequency), str(FLAGS.seed)
    ])
    tf.io.gfile.makedirs(FLAGS.save_path)
    filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')
    merged_summary_op = tf.summary.merge_all()

    # Run session.
    if not use_model_pruning:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')
            sess.run([init_op])
            tic = time.time()
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                for i in range(FLAGS.num_epochs * num_batches):
                    sess.run([train_op])

                    # True on the last batch of every epoch.
                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test, merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %.4f, %.4f, %.4f' % (
                            i // num_batches, epoch_time, loss, accuracy)
                        print(log_str)
                        print(log_str, file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
    else:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            log_str = ','.join([
                'Epoch', 'Iteration', 'Test loss', 'Test accuracy',
                'G_Sparsity', 'Sparsity Layer 0', 'Sparsity Layer 1'
            ])
            sess.run(init_op)
            sess.run(mask_init_op)
            tic = time.time()
            mask_records = {}
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                print(log_str)
                print(log_str, file=outputfile)
                for i in range(FLAGS.num_epochs * num_batches):
                    if (FLAGS.mask_record_frequency > 0 and
                            i % FLAGS.mask_record_frequency == 0):
                        mask_vals = sess.run(pruning.get_masks())
                        # Cast into bool to save space. The builtin `bool` is
                        # used because the `np.bool` alias was removed in
                        # NumPy 1.24.
                        mask_records[i] = [a.astype(bool) for a in mask_vals]
                    sess.run([train_op])
                    weight_sparsity, global_sparsity_val = sess.run(
                        [weight_sparsity_levels, global_sparsity])
                    # True on the last batch of every epoch.
                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test, merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (
                            i // num_batches, i, loss, accuracy,
                            global_sparsity_val, weight_sparsity[0],
                            weight_sparsity[1])
                        print(log_str)
                        print(log_str, file=outputfile)
                        mask_vals = sess.run(pruning.get_masks())
                        if FLAGS.network_type == 'fc':
                            sparsities, sizes = get_compressed_fc(mask_vals)
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes))
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes), file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
            if mask_records:
                np.save(os.path.join(FLAGS.save_path, 'mask_records'),
                        mask_records)
def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    transform_grads_fn=None,
                    gate_gradients=tf.train.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    check_numerics=True):
  """Creates an `Operation` that evaluates the gradients and returns the loss.

  Args:
    total_loss: A `Tensor` representing the total loss.
    optimizer: A tf.Optimizer to use for computing the gradients.
    global_step: A `Tensor` representing the global step variable. If left as
      `_USE_GLOBAL_STEP`, then tf.train.get_or_create_global_step() is used.
      Any other value (including `None`) is passed unchanged to
      `optimizer.apply_gradients`.
    update_ops: An optional list of updates to execute. If `update_ops` is
      `None`, then the update ops are set to the contents of the
      `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
      it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
      a warning will be displayed.
    variables_to_train: an optional list of variables to train. If None, it
      will default to all tf.compat.v1.trainable_variables().
    transform_grads_fn: A function which takes a single argument, a list of
      gradient to variable pairs (tuples), performs any requested gradient
      updates, such as gradient clipping or multipliers, and returns the
      updated list.
    gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: Whether or not to try colocating the gradients
      with the ops that generated them.
    check_numerics: Whether or not we apply check_numerics.

  Returns:
    A `Tensor` that when evaluated, computes the gradients and returns the
    total loss value.

  Raises:
    AssertionError: if an entry of `variables_to_train` is neither marked
      trainable nor in the trainable-variables collection, or if there are no
      variables to train. These checks are `assert`s and are therefore
      skipped when Python runs with `-O`.
  """
  if global_step is _USE_GLOBAL_STEP:  # pylint: disable=g-int-id-comparison
    # global_step can be None when passed into the optimizer in case we do not
    # want apply_gradients to factor that in. This is different from default
    # behaviour where we use the standard global step.
    global_step = tf.train.get_or_create_global_step()

  # Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
  global_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
  if update_ops is None:
    update_ops = global_update_ops
  else:
    update_ops = set(update_ops)
  if not global_update_ops.issubset(update_ops):
    tf.logging.warning('update_ops in create_train_op does not contain all the '
                       'update_ops in GraphKeys.UPDATE_OPS')

  # Make sure update_ops are computed before total_loss.
  if update_ops:
    with tf.control_dependencies(update_ops):
      barrier = tf.no_op(name='update_barrier')
    with tf.control_dependencies([barrier]):
      total_loss = tf.identity(total_loss)

  # Fetch the trainable-variables collection once instead of once per
  # variable inside the validation loop below.
  all_trainable_vars = tf.trainable_variables()
  if variables_to_train is None:
    # Default to tf.compat.v1.trainable_variables()
    variables_to_train = all_trainable_vars
  else:
    # Make sure that variables_to_train are in
    # tf.compat.v1.trainable_variables()
    for v in variables_to_train:
      assert v.trainable or v in all_trainable_vars, (
          'Variable %s is neither trainable nor in the trainable variables '
          'collection.' % v.name)

  assert variables_to_train, 'create_train_op: there are no variables to train.'

  # Create the gradients. Note that compute_gradients adds the gradient
  # computation to the current graph.
  grads = optimizer.compute_gradients(
      total_loss,
      variables_to_train,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)

  if transform_grads_fn:
    grads = transform_grads_fn(grads)

  # Create gradient updates.
  grad_updates = optimizer.apply_gradients(grads, global_step=global_step)

  with tf.name_scope('train_op'):
    # Make sure total_loss is valid.
    if check_numerics:
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan')

    # Ensure the train_tensor computes grad_updates.
    with tf.control_dependencies([grad_updates]):
      train_op = tf.identity(total_loss)

  # Add the operation used for training to the 'train_op' collection
  train_ops = tf.get_collection_ref(tf.GraphKeys.TRAIN_OP)
  if train_op not in train_ops:
    train_ops.append(train_op)

  return train_op
def test_rnn_multi(self):
  """Test automatic registration on a static RNN.

  The model tested here is designed for MNIST classification. To classify
  images using a recurrent neural network, every image row is treated as one
  timestep: a 28*28px MNIST image becomes a sequence of 28 timesteps, each
  holding the 28 pixels of one row.

  Every timestep's output gets a categorical predictive distribution in both
  collections. The input, recurrent and output projections (each reused at
  every timestep) are registered by hand in `layer_collection_manual` and
  discovered by graph search in `layer_collection_auto`; the test checks that
  both collections end up with matching fisher blocks.
  """
  with tf.Graph().as_default():
    dtype = tf.float32
    n_input = 28  # MNIST data input (img shape: 28*28)
    n_timesteps = 28  # timesteps
    n_hidden = 128  # hidden layer num of features
    n_classes = 10  # MNIST total classes (0-9 digits)

    x = tf.placeholder(dtype, [None, n_timesteps, n_input])
    y = tf.placeholder(tf.int32, [None])
    x_unstack = tf.unstack(x, n_timesteps, 1)

    w_input = tf.get_variable(
        'w_input', shape=[n_input, n_hidden], dtype=dtype)
    b_input = tf.get_variable('b_input', shape=[n_hidden], dtype=dtype)

    w_recurrent = tf.get_variable(
        'w_recurrent', shape=[n_hidden, n_hidden], dtype=dtype)
    b_recurrent = tf.get_variable(
        'b_recurrent', shape=[n_hidden], dtype=dtype)

    w_output = tf.get_variable(
        'w_output', shape=[n_hidden, n_classes], dtype=dtype)
    b_output = tf.get_variable('b_output', shape=[n_classes], dtype=dtype)

    layer_collection_manual = lc.LayerCollection()
    layer_collection_auto = lc.LayerCollection()

    a = tf.zeros(
        tf.convert_to_tensor([tf.shape(x_unstack[0])[0], n_hidden]),
        dtype=dtype)

    # Here 'a' are the activations, 's' the pre-activations.
    a_list = [a]
    s_input_list = []
    s_recurrent_list = []
    s_out_list = []
    # `cost` is never read after the loop; presumably it is built so the graph
    # resembles a real training setup. It is kept because it adds ops
    # downstream of every `s_out`.
    cost = 0.0
    last_step = len(x_unstack) - 1

    for i, input_ in enumerate(x_unstack):
      s_in = tf.matmul(input_, w_input) + b_input
      s_rec = tf.matmul(a, w_recurrent) + b_recurrent
      s = s_in + s_rec

      s_input_list.append(s_in)
      s_recurrent_list.append(s_rec)

      a = tf.tanh(s)
      a_list.append(a)

      s_out = tf.matmul(a, w_output) + b_output
      s_out_list.append(s_out)

      # Only the final timestep uses the real labels; earlier steps use zeros.
      if i == last_step:
        labels = y
      else:
        labels = tf.zeros([tf.shape(y)[0]], dtype=tf.int32)

      cost += tf.reduce_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              logits=s_out, labels=labels))

      layer_collection_manual.register_categorical_predictive_distribution(
          s_out)
      layer_collection_auto.register_categorical_predictive_distribution(
          s_out)

    # a_list[:-1] holds the activations fed into each step's recurrent matmul;
    # a_list[1:] holds the activations fed into each step's output matmul.
    layer_collection_manual.register_fully_connected_multi(
        (w_input, b_input), x_unstack, s_input_list)
    layer_collection_manual.register_fully_connected_multi(
        (w_recurrent, b_recurrent), a_list[:-1], s_recurrent_list)
    layer_collection_manual.register_fully_connected_multi(
        (w_output, b_output), a_list[1:], s_out_list)

    gs.register_layers(
        layer_collection_auto,
        tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

    assert_fisher_blocks_match(self, layer_collection_manual,
                               layer_collection_auto)
def test_specify_approximation(self):
  """Explicit approximations on linked parameters must be honoured.

  Two independent affine layers are built. They are registered in two fresh
  LayerCollections: first with each layer's weight and bias linked together,
  then with every parameter linked on its own. Each link names an
  approximation, and the test checks that graph search creates the fisher
  block type matching that approximation.
  """
  with tf.Graph().as_default():
    w_0 = tf.get_variable('w_0', [10, 10])
    w_1 = tf.get_variable('w_1', [10, 10])
    b_0 = tf.get_variable('b_0', [10])
    b_1 = tf.get_variable('b_1', [10])
    x_0 = tf.placeholder(tf.float32, shape=(32, 10))
    x_1 = tf.placeholder(tf.float32, shape=(32, 10))

    # Both matmuls are created before either bias add.
    pre_bias_0 = tf.matmul(x_0, w_0)
    pre_bias_1 = tf.matmul(x_1, w_1)
    out_0 = pre_bias_0 + b_0
    out_1 = pre_bias_1 + b_1

    def make_collection():
      """Returns a new LayerCollection with a squared-error loss per output."""
      collection = lc.LayerCollection()
      collection.register_squared_error_loss(out_0)
      collection.register_squared_error_loss(out_1)
      return collection

    def run_graph_search(collection):
      """Runs graph search over all global variables with batch size 32."""
      gs.register_layers(
          collection,
          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
          batch_size=32)

    # Case 1: weight and bias of each layer linked (affine layers).
    affine = make_collection()
    affine.define_linked_parameters(
        (w_0, b_0), approximation=lc.APPROX_KRONECKER_NAME)
    affine.define_linked_parameters(
        (w_1, b_1), approximation=lc.APPROX_DIAGONAL_NAME)
    run_graph_search(affine)
    affine_blocks = affine.fisher_blocks
    self.assertIsInstance(affine_blocks[(w_0, b_0)],
                          fb.FullyConnectedKFACBasicFB)
    self.assertIsInstance(affine_blocks[(w_1, b_1)],
                          fb.FullyConnectedDiagonalFB)

    # Case 2: every parameter linked on its own (linear layers and generic
    # parameters).
    separate = make_collection()
    requested = (
        (w_0, lc.APPROX_DIAGONAL_NAME),
        (b_0, lc.APPROX_DIAGONAL_NAME),
        (w_1, lc.APPROX_KRONECKER_NAME),
        (b_1, lc.APPROX_FULL_NAME),
    )
    for param, approx_name in requested:
      separate.define_linked_parameters(param, approximation=approx_name)
    run_graph_search(separate)

    separate_blocks = separate.fisher_blocks
    expected_types = (
        (w_0, fb.FullyConnectedDiagonalFB),
        (b_0, fb.NaiveDiagonalFB),
        (w_1, fb.FullyConnectedKFACBasicFB),
        (b_1, fb.FullFB),
    )
    for param, block_type in expected_types:
      self.assertIsInstance(separate_blocks[param], block_type)
def before_run(self, run_context):
  """Requests every entry of the `COLLECTION` graph collection as fetches.

  Args:
    run_context: Not used by this hook.

  Returns:
    A `tf.train.SessionRunArgs` whose fetches are the list returned by
    `tf.get_collection_ref(COLLECTION)`.
  """
  del run_context  # Unused.
  collection_ref = tf.get_collection_ref(COLLECTION)
  return tf.train.SessionRunArgs(fetches=collection_ref)
def LogAndSaveHParams():
  """Logs the gin operative config and records it in the graph.

  The config string returned by `gin.operative_config_str()` is written to the
  info log and appended to the "operative_hparams" graph collection.
  """
  operative_config = gin.operative_config_str()
  logging.info("Config:\n%s", operative_config)
  hparams_collection = tf.get_collection_ref("operative_hparams")
  hparams_collection.append(operative_config)