def filter_trainable_variables(trainable_scopes):
    """Keep only trainable variables which are prefixed with given scopes.
  Args:
    trainable_scopes: either list of trainable scopes or string with comma
      separated list of trainable scopes.
  This function removes all variables which are not prefixed with given
  trainable_scopes from collection of trainable variables.
  Useful during network fine tuning, when you only need to train subset of
  variables.
  """
    if not trainable_scopes:
        return
    if isinstance(trainable_scopes, six.string_types):
        trainable_scopes = [
            scope.strip() for scope in trainable_scopes.split(',')
        ]
    trainable_scopes = {scope for scope in trainable_scopes if scope}
    if not trainable_scopes:
        return
    trainable_collection = tf.get_collection_ref(
        tf.GraphKeys.TRAINABLE_VARIABLES)
    non_trainable_vars = [
        v for v in trainable_collection
        if not any([v.op.name.startswith(s) for s in trainable_scopes])
    ]
    for v in non_trainable_vars:
        trainable_collection.remove(v)
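A minimal usage sketch (not from the original snippet; the 'backbone' and 'head' scope names and a fresh default graph are assumptions made purely for illustration):

# Hypothetical fine-tuning setup: only variables under the 'head' scope stay
# trainable after filtering.
with tf.variable_scope('backbone'):
    w_backbone = tf.get_variable('w', [10, 10])
with tf.variable_scope('head'):
    w_head = tf.get_variable('w', [10, 2])

filter_trainable_variables('head')
assert tf.trainable_variables() == [w_head]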
    def test_graph_search_match_fail(self):
        """Tests graph search with linked bias tensors.

    In this code snippet two non adjacent bias tensors are linked together.
    There is no fisher block in kfac that matches this configuration, so the
    biases should not be registered.
    """
        with tf.Graph().as_default():
            tensor_dict = _build_model()

            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(tensor_dict['out_0'])
            layer_collection.register_squared_error_loss(tensor_dict['out_1'])

            # TODO(b/69055612): remove this manual registration once layer_collection
            # implements register_fully_connected_multi.
            layer_collection.register_fully_connected(
                tensor_dict['w'], tensor_dict['x'], tensor_dict['pre_bias_0'])
            layer_collection.define_linked_parameters(
                (tensor_dict['b_0'], tensor_dict['b_1']))

            with self.assertRaises(ValueError) as cm:
                gs.register_layers(
                    layer_collection,
                    tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

            self.assertIn('in linked group', str(cm.exception))
            self.assertIn('was not matched', str(cm.exception))
            self.assertIn(
                str(frozenset([tensor_dict['b_0'], tensor_dict['b_1']])),
                str(cm.exception))
def apply_mask(x, scope=''):
    """Apply mask to a given weight tensor.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to "".

  Returns:
    Tensor representing masked_weights
  """

    mask = pruning_utils.weight_mask_variable(x, scope)
    threshold = pruning_utils.weight_threshold_variable(x, scope)
    # Add masked_weights in the weights namescope so as to make it easier
    # for the quantization library to add quant ops.
    masked_weights = tf.multiply(mask, x, _MASKED_WEIGHT_NAME)

    # Make sure the mask for a given variable is not added multiple times to the
    # collection. This is particularly important when applying the mask to an
    # RNN's weight variables.
    if mask not in tf.get_collection_ref(_MASK_COLLECTION):
        tf.add_to_collection(_THRESHOLD_COLLECTION, threshold)
        tf.add_to_collection(_MASK_COLLECTION, mask)
        tf.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights)
        tf.add_to_collection(_WEIGHT_COLLECTION, x)
    return masked_weights
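A hedged usage sketch (the layer name, shapes, and the `inputs` tensor are assumptions, not part of the source): the returned masked tensor is what the layer should consume so that pruning can zero weights through the mask.

with tf.variable_scope('fc') as sc:
    weights = tf.get_variable('weights', [784, 10])
    # The mask and threshold variables are created inside the same scope.
    masked = apply_mask(weights, scope=sc)
logits = tf.matmul(inputs, masked)  # 'inputs' is assumed to be defined elsewhere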
    def test_tied_weights_untied_bias_registered_affine(self):
        """Test registering linked variables.

    Registering (w, b_1) as linked variables should not raise an error, since
    the matches with parameters (w) and (w, b_0) will be filtered out.
    """
        with tf.Graph().as_default():
            tensor_dict = _build_model()

            layer_collection_manual = lc.LayerCollection()
            layer_collection_manual.register_squared_error_loss(
                tensor_dict['out_0'])
            layer_collection_manual.register_squared_error_loss(
                tensor_dict['out_1'])

            layer_collection_manual.register_fully_connected(
                params=(tensor_dict['w'], tensor_dict['b_1']),
                inputs=tensor_dict['y'],
                outputs=tensor_dict['out_1'])
            layer_collection_manual.register_generic(tensor_dict['b_0'],
                                                     batch_size=32)

            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(tensor_dict['out_0'])
            layer_collection.register_squared_error_loss(tensor_dict['out_1'])

            layer_collection.define_linked_parameters(
                (tensor_dict['w'], tensor_dict['b_1']))
            gs.register_layers(layer_collection,
                               tf.get_collection_ref(
                                   tf.GraphKeys.GLOBAL_VARIABLES),
                               batch_size=32)

            assert_fisher_blocks_match(self, layer_collection,
                                       layer_collection_manual)
Example #5
    def testApplyCustomizedLSTMMatrixCompression(self):
        pruning_interface.apply_customized_lstm_matrix_compression(
            self.compression_obj, self.mock_weight_params_fn, MockWeightInit,
            self.mock_lstmobj, self.wm_pc.shape, tf.float32)

        self.assertGreater(len(tf.get_collection_ref(pruning.MASK_COLLECTION)),
                           0)
    def test_tied_weights_untied_bias_registered_weights(self):
        """Tests that graph search produces right solution on toy model."""
        with tf.Graph().as_default():
            tensor_dict = _build_model()

            layer_collection_manual = lc.LayerCollection()
            layer_collection_manual.register_squared_error_loss(
                tensor_dict['out_0'])
            layer_collection_manual.register_squared_error_loss(
                tensor_dict['out_1'])

            layer_collection_manual.register_fully_connected_multi(
                tensor_dict['w'], (tensor_dict['x'], tensor_dict['y']),
                (tensor_dict['pre_bias_0'], tensor_dict['pre_bias_1']))
            layer_collection_manual.register_generic(tensor_dict['b_0'],
                                                     batch_size=1)
            layer_collection_manual.register_generic(tensor_dict['b_1'],
                                                     batch_size=1)

            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(tensor_dict['out_0'])
            layer_collection.register_squared_error_loss(tensor_dict['out_1'])

            layer_collection.define_linked_parameters((tensor_dict['w']))
            gs.register_layers(layer_collection,
                               tf.get_collection_ref(
                                   tf.GraphKeys.GLOBAL_VARIABLES),
                               batch_size=1)

            assert_fisher_blocks_match(self, layer_collection,
                                       layer_collection_manual)
Example #7
    def build(self, inputs_shape):
        # Call the build method of the parent class.
        super(MaskedLSTMCell, self).build(inputs_shape)

        self.built = False

        input_depth = inputs_shape.dims[1].value
        h_depth = self._num_units

        (self._mask, self._threshold, self._old_weight, self._old_old_weight,
         self._gradient) = _CreateLSTMPruneVariables(self, input_depth,
                                                     h_depth)
        # Add masked_weights in the weights namescope so as to make it easier
        # for the quantization library to add quant ops.
        self._masked_kernel = tf.multiply(self._mask, self._kernel,
                                          pruning.MASKED_WEIGHT_NAME)

        if self._mask not in tf.get_collection_ref(pruning.MASK_COLLECTION):
            tf.add_to_collection(pruning.MASK_COLLECTION, self._mask)
            tf.add_to_collection(pruning.MASKED_WEIGHT_COLLECTION,
                                 self._masked_kernel)
            tf.add_to_collection(pruning.THRESHOLD_COLLECTION, self._threshold)
            tf.add_to_collection(pruning.WEIGHT_COLLECTION, self._kernel)
            tf.add_to_collection(pruning.OLD_WEIGHT_COLLECTION,
                                 self._old_weight)
            tf.add_to_collection(pruning.OLD_OLD_WEIGHT_COLLECTION,
                                 self._old_old_weight)
            tf.add_to_collection(pruning.WEIGHT_GRADIENT_COLLECTION,
                                 self._gradient)

        self.built = True
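A usage sketch for the cell above (the constructor is assumed to mirror tf.nn.rnn_cell.LSTMCell and the shapes are illustrative; this is not part of the original snippet):

cell = MaskedLSTMCell(num_units=128)
inputs = tf.placeholder(tf.float32, shape=(32, 28, 50))  # [batch, time, depth]
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)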
    def test_multiple_weights(self):
        """Test that graph search provides desired registration on toy model.

    In this toy example we apply the same linear layer to two different inputs.
    This tests whether graph search can correctly group them.
    """
        with tf.Graph().as_default():
            w = tf.get_variable('W', [10, 10])
            b_0 = tf.get_variable('b_0', [
                10,
            ])
            x = tf.placeholder(tf.float32, shape=(32, 10))
            y = tf.placeholder(tf.float32, shape=(32, 10))

            out_0 = tf.matmul(x, w) + b_0
            out_1 = tf.matmul(y, w) + b_0

            layer_collection_manual = lc.LayerCollection()
            layer_collection_manual.register_fully_connected_multi(
                (w, b_0), (x, y), (out_0, out_1))

            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(out_0)
            layer_collection.register_squared_error_loss(out_1)

            gs.register_layers(
                layer_collection,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

            assert_fisher_blocks_match(self, layer_collection,
                                       layer_collection_manual)
def LogAndSummarizeMetrics(metrics, use_streaming_mean=True):
    """Logs and summarizes metrics.

  Metrics are added to the LOGGING_OUTPUTS collection.

  Args:
    metrics: A dictionary of scalar metrics.
    use_streaming_mean: If true, the metrics will be averaged using a running
      mean.

  Returns:
    If use_streaming_mean is true, then this will be the op that you need to
    regularly call to update the running mean. Otherwise, this is a no-op.
  """

    prefix = tf.get_default_graph().get_name_scope()
    if prefix:
        prefix += "/"
    logging_collection = tf.get_collection_ref(LOGGING_OUTPUTS)

    update_ops = [tf.no_op()]
    for name, value in metrics.items():
        if use_streaming_mean:
            value, update_op = tf.metrics.mean(value)
            update_ops.append(update_op)
        logging_collection.append((prefix + name, value))
        tf.summary.scalar(name, value)

    return tf.group(*update_ops)
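A sketch of the intended call pattern (assumed; 'loss', 'accuracy', and 'train_op' stand in for tensors defined elsewhere): group the returned update op with the train op so the streaming means advance on every step.

metrics_update_op = LogAndSummarizeMetrics({'loss': loss, 'accuracy': accuracy})
train_op = tf.group(train_op, metrics_update_op)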
Example #10
    def test_specify_approximation_shared_parameters(self):
        """Test specifying approximations with layers containing shared parameters.

    If linked parameters are identified along with an approximation, then
    that approximation should be used when registering those parameters.
    """
        with tf.Graph().as_default():
            tensor_dict = _build_model()

            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(tensor_dict['out_0'])
            layer_collection.register_squared_error_loss(tensor_dict['out_1'])

            layer_collection.define_linked_parameters(
                tensor_dict['w'], approximation=lc.APPROX_KRONECKER_INDEP_NAME)
            layer_collection.define_linked_parameters(
                tensor_dict['b_0'], approximation=lc.APPROX_DIAGONAL_NAME)
            layer_collection.define_linked_parameters(
                tensor_dict['b_1'], approximation=lc.APPROX_FULL_NAME)

            gs.register_layers(layer_collection,
                               tf.get_collection_ref(
                                   tf.GraphKeys.GLOBAL_VARIABLES),
                               batch_size=1)

            self.assertIsInstance(
                layer_collection.fisher_blocks[tensor_dict['w']],
                fb.FullyConnectedMultiIndepFB)
            self.assertIsInstance(
                layer_collection.fisher_blocks[tensor_dict['b_0']],
                fb.NaiveDiagonalFB)
            self.assertIsInstance(
                layer_collection.fisher_blocks[tensor_dict['b_1']], fb.FullFB)
  def body(self, features):
    hp = self.hparams
    is_distill = hp.distill_phase == "distill"

    targets = features["targets_raw"]
    targets = tf.squeeze(targets, [1, 2, 3])
    one_hot_targets = tf.one_hot(targets, hp.num_classes, dtype=tf.float32)

    # Teacher Network
    with tf.variable_scope("teacher"):
      teacher_outputs = self.teacher_model.body(features)
      tf.logging.info("teacher output shape: %s" % teacher_outputs.get_shape())
      teacher_outputs = tf.reduce_mean(teacher_outputs, axis=[1, 2])
      teacher_logits = tf.layers.dense(teacher_outputs, hp.num_classes)

      teacher_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=one_hot_targets, logits=teacher_logits)
      outputs = teacher_logits

    if is_distill:
      # Load teacher weights
      tf.train.init_from_checkpoint(hp.teacher_dir, {"teacher/": "teacher/"})
      # Do not train the teacher
      trainable_vars = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
      del trainable_vars[:]

    # Student Network
    if is_distill:
      with tf.variable_scope("student"):
        student_outputs = self.student_model.body(features)
        tf.logging.info(
            "student output shape: %s" % student_outputs.get_shape())
        student_outputs = tf.reduce_mean(student_outputs, axis=[1, 2])
        student_logits = tf.layers.dense(student_outputs, hp.num_classes)

        student_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_targets, logits=student_logits)
        teacher_targets = tf.nn.softmax(teacher_logits / hp.distill_temperature)
        student_distill_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.stop_gradient(teacher_targets),
            logits=student_logits / hp.distill_temperature)
        # scale soft target obj. to match hard target obj. scale
        student_distill_xent *= hp.distill_temperature**2

        outputs = student_logits

        # Summaries
        tf.summary.scalar("distill_xent", student_distill_xent)

    if not is_distill:
      phase_loss = teacher_task_xent
    else:
      phase_loss = hp.task_balance * student_task_xent
      phase_loss += (1 - hp.task_balance) * student_distill_xent

    losses = {"training": phase_loss}
    outputs = tf.reshape(outputs, [-1, 1, 1, 1, outputs.shape[1]])

    return outputs, losses
Example #12
 def after_run(self, run_context, run_values):
     golden_values = {
         t.name: v
         for t, v in zip(tf.get_collection_ref(COLLECTION),
                         run_values.results)
     }
     logging.info('Recorded golden values for %s', golden_values.keys())
     self._measurements.append(golden_values)
Example #13
def fix_saver(collection_lists=None):
    # Workaround to prevent serialization warning by removing objects
    if collection_lists is None:
        try:
            # Try latest api
            l = tf.get_collection_ref("summary_tags")
            l4 = tf.get_collection_ref(tf.GraphKeys.GRAPH_CONFIG)
        except Exception:
            l = tf.get_collection("summary_tags")
            l4 = tf.get_collection(tf.GraphKeys.GRAPH_CONFIG)
        l_stags = list(l)
        l4_stags = list(l4)
        del l[:]
        del l4[:]

        try:
            # Try latest api
            l1 = tf.get_collection_ref(tf.GraphKeys.DATA_PREP)
            l2 = tf.get_collection_ref(tf.GraphKeys.DATA_AUG)
        except Exception:
            l1 = tf.get_collection(tf.GraphKeys.DATA_PREP)
            l2 = tf.get_collection(tf.GraphKeys.DATA_AUG)
        l1_dtags = list(l1)
        l2_dtags = list(l2)
        del l1[:]
        del l2[:]

        try:  # Do not save exclude variables
            l3 = tf.get_collection_ref(tf.GraphKeys.EXCL_RESTORE_VARS)
        except Exception:
            l3 = tf.get_collection(tf.GraphKeys.EXCL_RESTORE_VARS)
        l3_tags = list(l3)
        del l3[:]
        return [l_stags, l1_dtags, l2_dtags, l3_tags, l4_stags]
    else:
        # 0.7+ workaround, restore values
        for t in collection_lists[0]:
            tf.add_to_collection("summary_tags", t)
        for t in collection_lists[4]:
            tf.add_to_collection(tf.GraphKeys.GRAPH_CONFIG, t)
        for t in collection_lists[1]:
            tf.add_to_collection(tf.GraphKeys.DATA_PREP, t)
        for t in collection_lists[2]:
            tf.add_to_collection(tf.GraphKeys.DATA_AUG, t)
        for t in collection_lists[3]:
            tf.add_to_collection(tf.GraphKeys.EXCL_RESTORE_VARS, t)
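A hedged sketch of how this workaround is meant to be used, inferred from the function's two modes (not taken from the source): clear the problematic collections, build the Saver, then put the stashed entries back.

stashed_collections = fix_saver()   # returns the collection contents and clears them
saver = tf.train.Saver()            # serializes the graph without those objects
fix_saver(stashed_collections)      # restores the collection contents afterwards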
Example #14
def _variable_tracking_custom_getter(getter, *args, **kwargs):
    """Custom getter that tracks variables created.

    This custom getter places any variables that `getter` creates into the
    `_all_variables` attribute of the `AbstractModule` that is on top of the
    module call stack. The module call stack is a graph-dependent stack that
    keeps track of the sonnet module call order.

    Note that this assumes that variables are added by appending to `tf.Graph`
    collections. This is a safe assumption to make because
    `tf.add_to_collection()` appends objects to collections, and `tf.Variable`
    uses `tf.add_to_collections()` to add itself to `tf.Graph` collections.

    Note that this also assumes that all variables are added to either the
    `tf.GraphKeys.GLOBAL_VARIABLES` or the `tf.GraphKeys.LOCAL_VARIABLES`
    collection.

    Args:
      getter: The true getter or another custom getter.
      *args: See positional arguments for `tf.get_variable()`.
      **kwargs: See keyword arguments for `tf.get_variable()`.

    Returns:
      See docstring for `tf.get_variable()`.
    """
    # Get the module that is calling `tf.get_variable()`
    module_stack = _MODULE_STACKS[tf.get_default_graph()]
    module = module_stack[-1]

    # Get lists of local and global variables. We use `tf.get_collection_ref()`
    # instead of `tf.get_collection()` to avoid copying the collections.
    local_variables = tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)
    global_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)

    num_local_vars_before = len(local_variables)
    num_global_vars_before = len(global_variables)

    out = getter(*args, **kwargs)

    # Add any local or global variables that have been created to `module`
    # pylint: disable=protected-access
    module._all_variables.update(local_variables[num_local_vars_before:])
    module._all_variables.update(global_variables[num_global_vars_before:])
    # pylint: enable=protected-access

    return out
Example #15
 def after_run(self, run_context, run_values):
     # Strip the 'golden_' prefix before saving the data.
     golden_values = {
         t.name.split(PREFIX)[1]: v
         for t, v in zip(tf.get_collection_ref(COLLECTION),
                         run_values.results)
     }
     logging.info('Recorded golden values for %s', golden_values.keys())
     self._measurements.append(golden_values)
def get_variables(checkpoint_prefix):
  """Returns trainable variables that are also present in the checkpoint."""
  get_name = lambda x: x.name
  stripper = lambda x: x.split(':')[0]  # drop the ':0' output suffix
  rem_tuple = lambda x: x[0]

  complete_variables = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)

  variables_names = list(map(stripper, map(get_name, complete_variables)))
  checkpoint_variables = list(map(rem_tuple,
                                  tf.train.list_variables(checkpoint_prefix)))

  crossed_variables = list(set(variables_names).intersection(checkpoint_variables))
  indices = [variables_names.index(name) for name in crossed_variables]
  return [complete_variables[i] for i in indices]
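A possible usage sketch (the checkpoint path and the existing session `sess` are assumptions): restore only the trainable variables whose names also appear in the checkpoint.

restorable_vars = get_variables('/tmp/model.ckpt')
saver = tf.train.Saver(restorable_vars)
saver.restore(sess, '/tmp/model.ckpt')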
Example #17
    def read(self, filepath, partial=False):
        if partial:
            # tf.train.list_variables returns (name, shape) tuples, so match
            # checkpoint entries against variable names rather than against
            # the variable objects themselves.
            checkpoint_names = set(
                name for name, _ in tf.train.list_variables(filepath))
            vars_to_restore = [
                v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
                if v.op.name in checkpoint_names
            ]
            logging.warning(
                "Restoring graph partially. Only the following vars will be "
                "restored: %s", vars_to_restore)

            partial_saver = tf.train.Saver(vars_to_restore)
            partial_saver.restore(self.session, filepath)
            logging.info(
                "Model checkpoint partially restored from file: %s." % filepath)
        else:
            self.saver.restore(self.session, filepath)
            logging.info("Model checkpoint restored from file: %s." % filepath)
Example #18
    def test_tied_weights_untied_bias(self):
        """Tests that ambiguity in graph raises an error.

    Graph search will find several possible registrations containing w including
    (w, b_1) & (w, b_2). Without any instructions in form of linked tensors or
    manual registration it defaults to registering an error and suggesting that
    the user register (w) as a linked tensor.
    """
        with tf.Graph().as_default():
            tensor_dict = _build_model()

            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(tensor_dict['out_0'])
            layer_collection.register_squared_error_loss(tensor_dict['out_1'])

            with self.assertRaises(gs.AmbiguousRegistrationError):
                gs.register_layers(
                    layer_collection,
                    tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
Example #19
    def test_tied_weights_untied_bias_registered_bias(self):
        """Tests that ambiguity in graph raises value error.

    Graph search will find several possible registrations for tensors.
    In this registering b_1 as a linked variable will result in an error
    because there will remain an ambiguity on the other branch of the graph.
    """
        with tf.Graph().as_default():
            tensor_dict = _build_model()

            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(tensor_dict['out_0'])
            layer_collection.register_squared_error_loss(tensor_dict['out_1'])

            layer_collection.define_linked_parameters((tensor_dict['b_1']))

            with self.assertRaises(gs.AmbiguousRegistrationError):
                gs.register_layers(
                    layer_collection,
                    tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
Example #20
def apply_mask_and_return(x, scope='', prune_option='weight'):
  """Apply mask to a given weight tensor.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to "".
    prune_option: pruning option. Defaults to 'weight'. option =
      'first_order_gradient' means using |weight| * |first order gradient| for
      pruning. option = 'second_order_gradient' means using |weight| * |second
      order gradient| for pruning.

  Returns:
    masked_weights: a TensorFlow tensor representing masked weights.
    mask: a TensorFlow tensor representing the pruning mask.
  """

  mask = pruning_utils.weight_mask_variable(x, scope)
  threshold = pruning_utils.weight_threshold_variable(x, scope)
  # Add masked_weights in the weights namescope so as to make it easier
  # for the quantization library to add quant ops.
  masked_weights = tf.multiply(mask, x, MASKED_WEIGHT_NAME)

  if prune_option in ('first_order_gradient', 'second_order_gradient'):
    # absolute value of gradients for gradient based pruning
    gradient = pruning_utils.weight_gradient_variable(x, scope)
    old_weight = pruning_utils.old_weight_variable(x, scope)
    old_old_weight = pruning_utils.old_old_weight_variable(x, scope)

  # Make sure the mask for a given variable is not added multiple times to the
  # collection. This is particularly important when applying the mask to an
  # RNN's weight variables.
  if mask not in tf.get_collection_ref(MASK_COLLECTION):
    tf.add_to_collection(THRESHOLD_COLLECTION, threshold)
    tf.add_to_collection(MASK_COLLECTION, mask)
    tf.add_to_collection(MASKED_WEIGHT_COLLECTION, masked_weights)
    tf.add_to_collection(WEIGHT_COLLECTION, x)
    if prune_option in ('first_order_gradient', 'second_order_gradient'):
      tf.add_to_collection(WEIGHT_GRADIENT_COLLECTION, gradient)
      tf.add_to_collection(OLD_WEIGHT_COLLECTION, old_weight)
      tf.add_to_collection(OLD_OLD_WEIGHT_COLLECTION, old_old_weight)
  return [masked_weights, mask]
Example #21
    def test_multitower_multi_loss_function(self):
        """Test multitower setup with multiple loss functions.

    The automatic graph scanner should handle multiple loss functions per tower,
    as long as they're registered in a consistent order.
    """
        with tf.Graph().as_default():
            w_1 = tf.get_variable('w_1', shape=[10, 10])
            b_1 = tf.get_variable('b_1', shape=[10])
            w_2 = tf.get_variable('w_2', shape=[10, 10])
            b_2 = tf.get_variable('b_2', shape=[10])
            layer_collection = lc.LayerCollection()
            layer_collection_manual = lc.LayerCollection()
            for tower_num in range(5):
                x = tf.placeholder(tf.float32, shape=(32, 10))
                logits_1 = tf.matmul(x, w_1) + b_1
                logits_2 = tf.matmul(x, w_2) + b_2
                reuse = tower_num > 0
                with tf.variable_scope('tower%d' % tower_num, reuse=reuse):
                    for l in [layer_collection, layer_collection_manual]:
                        l.register_categorical_predictive_distribution(
                            logits_1, name='loss_1')
                        l.register_categorical_predictive_distribution(
                            logits_2, name='loss_2')
                    layer_collection_manual.register_fully_connected(
                        (w_1, b_1), x, logits_1)
                    layer_collection_manual.register_fully_connected(
                        (w_2, b_2), x, logits_2)

            gs.register_layers(
                layer_collection,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

            assert_fisher_blocks_match(self, layer_collection,
                                       layer_collection_manual)
Example #22
def get_train_estimator_spec(gan_model_fns, loss_fns, gan_loss_kwargs,
                             optimizers, joint_train, is_on_tpu,
                             gan_train_steps, add_summaries, run_config):
    """Estimator spec for train case."""
    # Construct optimizers if arguments are callable. This has to be done inside
    # the model_fn, since constructable optimizers might create tf.Variables that
    # need to be added to the current tf.Graph.
    optimizers = _maybe_construct_optimizers(optimizers)
    if is_on_tpu:
        optimizers = _maybe_make_cross_shard_optimizers(optimizers)

    tpu_train_op, scalar_loss = _get_train_op(gan_model_fns, loss_fns,
                                              gan_loss_kwargs, optimizers,
                                              joint_train, gan_train_steps,
                                              add_summaries)

    gs_1 = tf.reshape(tf1.train.get_global_step(), [1])
    losses = tf1.get_collection_ref(tf1.GraphKeys.LOSSES)
    loss_names = [l.name for l in losses]
    losses = [tf.reshape(l, [1]) for l in losses]

    def host_call_fn(step, *losses):
        step = step[0]
        with tf.summary.create_file_writer(run_config.model_dir,
                                           max_queue=run_config.tpu_config.
                                           iterations_per_loop).as_default():
            with tf.summary.record_if(True):
                for n, l in zip(loss_names, losses):
                    tf.summary.scalar(n, tf.reduce_mean(l), step=step)
                return tf1.summary.all_v2_summary_ops()

    host_call = (host_call_fn, [gs_1] + losses)

    return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=scalar_loss,
        train_op=tpu_train_op,
        host_call=host_call)
Example #23
    def test_subset_weights_manual_registration(self):
        """Test that graph search provides desired registration on toy model.

    In this toy example we apply the same matmul op to two different inputs
    followed by adding a bias to one of the inputs. This tests whether graph
    search can correctly group them.
    """
        with tf.Graph().as_default():
            w = tf.get_variable('W', [10, 10])
            b_0 = tf.get_variable('b_0', [
                10,
            ])
            x = tf.placeholder(tf.float32, shape=(32, 10))
            y = tf.placeholder(tf.float32, shape=(32, 10))

            out_n1 = tf.matmul(x, w)
            out_0 = out_n1 + b_0
            out_1 = tf.matmul(y, w)

            layer_collection_manual = lc.LayerCollection()
            layer_collection_manual.register_fully_connected_multi(
                w, (x, y), (out_n1, out_1))
            layer_collection_manual.register_generic(b_0, batch_size=1)

            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(out_0)
            layer_collection.register_squared_error_loss(out_1)

            layer_collection.define_linked_parameters(w)
            gs.register_layers(layer_collection,
                               tf.get_collection_ref(
                                   tf.GraphKeys.GLOBAL_VARIABLES),
                               batch_size=1)

            assert_fisher_blocks_match(self, layer_collection,
                                       layer_collection_manual)
Example #24
    def mixed_usage_test(self):
        """Tests that graph search raises error on mixed types usage for tensors.

    Tensors can be reused in various locations in the tensorflow graph. This
    occurs regularly in the case of recurrent models or models with parallel
    graphs. However the tensors must be used for the same operation in each
    location or graph search should raise an error.
    """
        with tf.Graph().as_default():
            w = tf.get_variable('W', [10, 10])
            x = tf.placeholder(tf.float32, shape=(32, 10))
            y = tf.placeholder(tf.float32, shape=(32, 10, 10))

            out_0 = tf.matmul(x, w)  # pylint: disable=unused-variable
            out_1 = y + w  # pylint: disable=unused-variable

            layer_collection = lc.LayerCollection()

            with self.assertRaises(ValueError) as cm:
                gs.register_layers(
                    layer_collection,
                    tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

            self.assertIn('mixed record types', str(cm.exception))
def main(unused_args):
    tf.set_random_seed(FLAGS.seed)
    tf.get_variable_scope().set_use_resource(True)
    np.random.seed(FLAGS.seed)

    # Load the MNIST data and set up an iterator.
    mnist_data = input_data.read_data_sets(FLAGS.mnist,
                                           one_hot=False,
                                           validation_size=0)
    train_images = mnist_data.train.images
    test_images = mnist_data.test.images
    if FLAGS.input_mask_path:
        reader = tf.train.load_checkpoint(FLAGS.input_mask_path)
        input_mask = reader.get_tensor('layer1/mask')
        indices = np.sum(input_mask, axis=1) != 0
        train_images = train_images[:, indices]
        test_images = test_images[:, indices]
    dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, mnist_data.train.labels.astype(np.int32)))
    num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size
    dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])
    batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
    iterator = batched_dataset.make_one_shot_iterator()

    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_images, mnist_data.test.labels.astype(np.int32)))
    num_test_images = mnist_data.test.images.shape[0]
    test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)
    test_iterator = test_dataset.make_one_shot_iterator()

    # Set up loss function.
    use_model_pruning = FLAGS.training_method != 'baseline'

    if FLAGS.network_type == 'fc':
        cross_entropy_train, _ = mnist_network_fc(
            iterator.get_next(), model_pruning=use_model_pruning)
        cross_entropy_test, accuracy_test = mnist_network_fc(
            test_iterator.get_next(),
            reuse=True,
            model_pruning=use_model_pruning)
    else:
        raise RuntimeError(FLAGS.network_type + ' is an unknown network type.')

    # Remove extra added ones. Current implementation adds the variables twice
    # to the collection. Improve this hacky thing.
    # TODO test the following with the convnet or any other network.
    if use_model_pruning:
        for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):
            # del tf.get_collection_ref(k)[2]
            # del tf.get_collection_ref(k)[2]
            collection = tf.get_collection_ref(k)
            del collection[len(collection) // 2:]
            print(tf.get_collection_ref(k))

    # Set up optimizer and update ops.
    global_step = tf.train.get_or_create_global_step()
    batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size

    if FLAGS.optimizer != 'adam':
        if not use_model_pruning:
            boundaries = [
                int(round(s * batch_per_epoch)) for s in [60, 70, 80]
            ]
        else:
            boundaries = [
                int(round(s * batch_per_epoch))
                for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]
            ]
        learning_rate = tf.train.piecewise_constant(
            global_step,
            boundaries,
            values=[
                FLAGS.learning_rate / (3.**i)
                for i in range(len(boundaries) + 1)
            ])
    else:
        learning_rate = FLAGS.learning_rate

    if FLAGS.optimizer == 'adam':
        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    elif FLAGS.optimizer == 'momentum':
        opt = tf.train.MomentumOptimizer(learning_rate,
                                         FLAGS.momentum,
                                         use_nesterov=FLAGS.use_nesterov)
    elif FLAGS.optimizer == 'sgd':
        opt = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise RuntimeError(FLAGS.optimizer + ' is unknown optimizer type')
    custom_sparsities = {
        'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,
        'layer3': FLAGS.end_sparsity * 0
    }

    if FLAGS.training_method == 'set':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseSETOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'static':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseStaticOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'momentum':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseMomentumOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            use_tpu=False)
    elif FLAGS.training_method == 'rigl':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseRigLOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale,
            use_tpu=False)
    elif FLAGS.training_method == 'snip':
        opt = sparse_optimizers.SparseSnipOptimizer(
            opt,
            mask_init_method=FLAGS.mask_init_method,
            default_sparsity=FLAGS.end_sparsity,
            custom_sparsity_map=custom_sparsities,
            use_tpu=False)
    elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' %
                         FLAGS.training_method)

    train_op = opt.minimize(cross_entropy_train, global_step=global_step)

    if FLAGS.training_method == 'prune':
        hparams_string = (
            'begin_pruning_step={0},sparsity_function_begin_step={0},'
            'end_pruning_step={1},sparsity_function_end_step={1},'
            'target_sparsity={2},pruning_frequency={3},'
            'threshold_decay={4}'.format(FLAGS.prune_begin_step,
                                         FLAGS.prune_end_step,
                                         FLAGS.end_sparsity,
                                         FLAGS.pruning_frequency,
                                         FLAGS.threshold_decay))
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
        pruning_hparams.set_hparam(
            'weight_sparsity_map',
            ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()])
        print(pruning_hparams)
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()
    weight_sparsity_levels = pruning.get_weight_sparsity()
    global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())
    tf.summary.scalar('test_accuracy', accuracy_test)
    tf.summary.scalar('global_sparsity', global_sparsity)
    for k, v in zip(pruning.get_masks(), weight_sparsity_levels):
        tf.summary.scalar('sparsity/%s' % k.name, v)
    if FLAGS.training_method in ('prune', 'snip', 'baseline'):
        mask_init_op = tf.no_op()
        tf.logging.info('No mask is set, starting dense.')
    else:
        all_masks = pruning.get_masks()
        mask_init_op = sparse_utils.get_mask_init_fn(all_masks,
                                                     FLAGS.mask_init_method,
                                                     FLAGS.end_sparsity,
                                                     custom_sparsities)

    if FLAGS.save_model:
        saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    hyper_params_string = '_'.join([
        FLAGS.network_type,
        str(FLAGS.batch_size),
        str(FLAGS.learning_rate),
        str(FLAGS.momentum), FLAGS.optimizer,
        str(FLAGS.l2_scale), FLAGS.training_method,
        str(FLAGS.prune_begin_step),
        str(FLAGS.prune_end_step),
        str(FLAGS.end_sparsity),
        str(FLAGS.pruning_frequency),
        str(FLAGS.seed)
    ])
    tf.io.gfile.makedirs(FLAGS.save_path)
    filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')
    merged_summary_op = tf.summary.merge_all()

    # Run session.
    if not use_model_pruning:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')
            sess.run([init_op])
            tic = time.time()
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                for i in range(FLAGS.num_epochs * num_batches):
                    sess.run([train_op])

                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %.4f, %.4f, %.4f' % (
                            i // num_batches, epoch_time, loss, accuracy)
                        print(log_str)
                        print(log_str, file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
    else:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            log_str = ','.join([
                'Epoch', 'Iteration', 'Test loss', 'Test accuracy',
                'G_Sparsity', 'Sparsity Layer 0', 'Sparsity Layer 1'
            ])
            sess.run(init_op)
            sess.run(mask_init_op)
            tic = time.time()
            mask_records = {}
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                print(log_str)
                print(log_str, file=outputfile)
                for i in range(FLAGS.num_epochs * num_batches):
                    if (FLAGS.mask_record_frequency > 0
                            and i % FLAGS.mask_record_frequency == 0):
                        mask_vals = sess.run(pruning.get_masks())
                        # Cast into bool to save space.
                        mask_records[i] = [
                            a.astype(np.bool) for a in mask_vals
                        ]
                    sess.run([train_op])
                    weight_sparsity, global_sparsity_val = sess.run(
                        [weight_sparsity_levels, global_sparsity])
                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (
                            i // num_batches, i, loss, accuracy,
                            global_sparsity_val, weight_sparsity[0],
                            weight_sparsity[1])
                        print(log_str)
                        print(log_str, file=outputfile)
                        mask_vals = sess.run(pruning.get_masks())
                        if FLAGS.network_type == 'fc':
                            sparsities, sizes = get_compressed_fc(mask_vals)
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes))
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes),
                                  file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
            if mask_records:
                np.save(os.path.join(FLAGS.save_path, 'mask_records'),
                        mask_records)
Example #26
def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    transform_grads_fn=None,
                    gate_gradients=tf.train.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    check_numerics=True):
  """Creates an `Operation` that evaluates the gradients and returns the loss.

  Args:
    total_loss: A `Tensor` representing the total loss.
    optimizer: A tf.Optimizer to use for computing the gradients.
    global_step: A `Tensor` representing the global step variable. If left as
      `_USE_GLOBAL_STEP`, then tf.train.global_step() is used.
    update_ops: An optional list of updates to execute. If `update_ops` is
      `None`, then the update ops are set to the contents of the
      `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
      it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`, a
      warning will be displayed.
    variables_to_train: an optional list of variables to train. If None, it will
      default to all tf.compat.v1.trainable_variables().
    transform_grads_fn: A function which takes a single argument, a list of
      gradient to variable pairs (tuples), performs any requested gradient
      updates, such as gradient clipping or multipliers, and returns the updated
      list.
    gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: Whether or not to try colocating the gradients
      with the ops that generated them.
    check_numerics: Whether or not we apply check_numerics.

  Returns:
    A `Tensor` that when evaluated, computes the gradients and returns the total
      loss value.
  """
  if global_step is _USE_GLOBAL_STEP:  # pylint: disable=g-int-id-comparison
    # global_step can be None when passed into the optimizer in case we do not
    # want apply_gradients to factor that in. This is different from default
    # behaviour where we use the standard global step.
    global_step = tf.train.get_or_create_global_step()

  # Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
  global_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
  if update_ops is None:
    update_ops = global_update_ops
  else:
    update_ops = set(update_ops)
  if not global_update_ops.issubset(update_ops):
    tf.logging.warning('update_ops in create_train_op does not contain all the '
                       'update_ops in GraphKeys.UPDATE_OPS')

  # Make sure update_ops are computed before total_loss.
  if update_ops:
    with tf.control_dependencies(update_ops):
      barrier = tf.no_op(name='update_barrier')
    with tf.control_dependencies([barrier]):
      total_loss = tf.identity(total_loss)

  if variables_to_train is None:
    # Default to tf.compat.v1.trainable_variables()
    variables_to_train = tf.trainable_variables()
  else:
    # Make sure that variables_to_train are in
    # tf.compat.v1.trainable_variables()
    for v in variables_to_train:
      assert v.trainable or v in tf.trainable_variables()

  assert variables_to_train

  # Create the gradients. Note that apply_gradients adds the gradient
  # computation to the current graph.
  grads = optimizer.compute_gradients(
      total_loss,
      variables_to_train,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)

  if transform_grads_fn:
    grads = transform_grads_fn(grads)

  # Create gradient updates.
  grad_updates = optimizer.apply_gradients(grads, global_step=global_step)

  with tf.name_scope('train_op'):
    # Make sure total_loss is valid.
    if check_numerics:
      total_loss = tf.check_numerics(total_loss,
                                     'LossTensor is inf or nan')

    # Ensure the train_tensor computes grad_updates.
    with tf.control_dependencies([grad_updates]):
      train_op = tf.identity(total_loss)

  # Add the operation used for training to the 'train_op' collection
  train_ops = tf.get_collection_ref(tf.GraphKeys.TRAIN_OP)
  if train_op not in train_ops:
    train_ops.append(train_op)

  return train_op
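A minimal usage sketch (assumes a TF1-style graph in which `total_loss` is already defined as a scalar tensor; not from the original source):

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
train_op = create_train_op(total_loss, optimizer)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  loss_value = sess.run(train_op)  # applies the gradients and returns the loss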
Example #27
    def test_rnn_multi(self):
        """Test automatic registration on a static RNN.

    The model tested here is designed for MNIST classification. To classify
    images using a recurrent neural network, we consider every image row as a
    sequence of pixels. Because MNIST image shape is 28*28px, we will then
    handle 28 sequences of 28 steps for every sample.
    """
        with tf.Graph().as_default():
            dtype = tf.float32
            n_input = 28  # MNIST data input (img shape: 28*28)
            n_timesteps = 28  # timesteps
            n_hidden = 128  # hidden layer num of features
            n_classes = 10  # MNIST total classes (0-9 digits)

            x = tf.placeholder(dtype, [None, n_timesteps, n_input])
            y = tf.placeholder(tf.int32, [None])
            x_unstack = tf.unstack(x, n_timesteps, 1)

            w_input = tf.get_variable('w_input',
                                      shape=[n_input, n_hidden],
                                      dtype=dtype)
            b_input = tf.get_variable('b_input', shape=[n_hidden], dtype=dtype)

            w_recurrent = tf.get_variable('w_recurrent',
                                          shape=[n_hidden, n_hidden],
                                          dtype=dtype)
            b_recurrent = tf.get_variable('b_recurrent',
                                          shape=[n_hidden],
                                          dtype=dtype)

            w_output = tf.get_variable('w_output',
                                       shape=[n_hidden, n_classes],
                                       dtype=dtype)
            b_output = tf.get_variable('b_output',
                                       shape=[n_classes],
                                       dtype=dtype)

            layer_collection_manual = lc.LayerCollection()
            layer_collection_auto = lc.LayerCollection()

            a = tf.zeros(tf.convert_to_tensor(
                [tf.shape(x_unstack[0])[0], n_hidden]),
                         dtype=dtype)

            # Here 'a' are the activations, 's' the pre-activations.
            a_list = [a]
            s_input_list = []
            s_recurrent_list = []
            s_list = []
            s_out_list = []
            cost = 0.0

            for i in range(len(x_unstack)):
                input_ = x_unstack[i]

                s_in = tf.matmul(input_, w_input) + b_input
                s_rec = tf.matmul(a, w_recurrent) + b_recurrent
                s = s_in + s_rec

                s_input_list.append(s_in)
                s_recurrent_list.append(s_rec)
                s_list.append(s)

                a = tf.tanh(s)
                a_list.append(a)

                s_out = tf.matmul(a, w_output) + b_output
                s_out_list.append(s_out)

                if i == len(x_unstack) - 1:
                    labels = y
                else:
                    labels = tf.zeros([tf.shape(y)[0]], dtype=tf.int32)

                cost += tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=s_out, labels=labels))

                layer_collection_manual.register_categorical_predictive_distribution(
                    s_out)
                layer_collection_auto.register_categorical_predictive_distribution(
                    s_out)

            layer_collection_manual.register_fully_connected_multi(
                (w_input, b_input), x_unstack, s_input_list)
            layer_collection_manual.register_fully_connected_multi(
                (w_recurrent, b_recurrent), a_list[:-1], s_recurrent_list)
            layer_collection_manual.register_fully_connected_multi(
                (w_output, b_output), a_list[1:], s_out_list)

            gs.register_layers(
                layer_collection_auto,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

            assert_fisher_blocks_match(self, layer_collection_manual,
                                       layer_collection_auto)
Example #28
    def test_specify_approximation(self):
        """Test specifying approximations.

    If linked parameters are identified along with an approximation, then
    that approximation should be used when registering those parameters.
    """
        with tf.Graph().as_default():
            w_0 = tf.get_variable('w_0', [10, 10])
            w_1 = tf.get_variable('w_1', [10, 10])

            b_0 = tf.get_variable('b_0', [10])
            b_1 = tf.get_variable('b_1', [10])

            x_0 = tf.placeholder(tf.float32, shape=(32, 10))
            x_1 = tf.placeholder(tf.float32, shape=(32, 10))

            pre_bias_0 = tf.matmul(x_0, w_0)
            pre_bias_1 = tf.matmul(x_1, w_1)

            out_0 = pre_bias_0 + b_0  # pylint: disable=unused-variable
            out_1 = pre_bias_1 + b_1  # pylint: disable=unused-variable

            # Group variables as affine layers.
            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(out_0)
            layer_collection.register_squared_error_loss(out_1)

            layer_collection.define_linked_parameters(
                (w_0, b_0), approximation=lc.APPROX_KRONECKER_NAME)
            layer_collection.define_linked_parameters(
                (w_1, b_1), approximation=lc.APPROX_DIAGONAL_NAME)
            gs.register_layers(layer_collection,
                               tf.get_collection_ref(
                                   tf.GraphKeys.GLOBAL_VARIABLES),
                               batch_size=32)
            self.assertIsInstance(layer_collection.fisher_blocks[(w_0, b_0)],
                                  fb.FullyConnectedKFACBasicFB)
            self.assertIsInstance(layer_collection.fisher_blocks[(w_1, b_1)],
                                  fb.FullyConnectedDiagonalFB)

            # Group variables as linear layers and generic parameters.
            layer_collection = lc.LayerCollection()
            layer_collection.register_squared_error_loss(out_0)
            layer_collection.register_squared_error_loss(out_1)

            layer_collection.define_linked_parameters(
                w_0, approximation=lc.APPROX_DIAGONAL_NAME)
            layer_collection.define_linked_parameters(
                b_0, approximation=lc.APPROX_DIAGONAL_NAME)
            layer_collection.define_linked_parameters(
                w_1, approximation=lc.APPROX_KRONECKER_NAME)
            layer_collection.define_linked_parameters(
                b_1, approximation=lc.APPROX_FULL_NAME)
            gs.register_layers(layer_collection,
                               tf.get_collection_ref(
                                   tf.GraphKeys.GLOBAL_VARIABLES),
                               batch_size=32)
            self.assertIsInstance(layer_collection.fisher_blocks[w_0],
                                  fb.FullyConnectedDiagonalFB)
            self.assertIsInstance(layer_collection.fisher_blocks[b_0],
                                  fb.NaiveDiagonalFB)
            self.assertIsInstance(layer_collection.fisher_blocks[w_1],
                                  fb.FullyConnectedKFACBasicFB)
            self.assertIsInstance(layer_collection.fisher_blocks[b_1],
                                  fb.FullFB)
Example #29
 def before_run(self, run_context):
     return tf.train.SessionRunArgs(
         fetches=tf.get_collection_ref(COLLECTION))
def LogAndSaveHParams():
    """Logs and saves the operative parameters to the graph."""
    hparams_str = gin.operative_config_str()
    logging.info("Config:\n%s", hparams_str)
    tf.get_collection_ref("operative_hparams").append(hparams_str)