Example #1
 def train_step(self, x):
     if hasattr(self, "alpha"):
         self.alpha = self.force_alpha
     with tf.GradientTape() as tape:
         rates, distortions = self.train_losses(x)
         losses = rates + self.lmbda * distortions
         loss = tf.math.reduce_mean(losses)
     variables = self.trainable_variables
     gradients = tape.gradient(loss, variables)
     self.optimizer.apply_gradients(zip(gradients, variables))
     self.loss.update_state(losses)
     self.rate.update_state(rates)
     self.distortion.update_state(distortions)
     energy = []
     size = []
     for grad in gradients:
         if grad is None:
             continue
         energy.append(tf.reduce_sum(tf.square(tf.cast(grad, tf.float64))))
         size.append(tf.cast(tf.size(grad), tf.float64))
     self.grad_rms.update_state(tf.sqrt(tf.add_n(energy) / tf.add_n(size)))
     return {
         m.name: m.result()
         for m in [self.loss, self.rate, self.distortion, self.grad_rms]
     }
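The gradient-RMS bookkeeping above boils down to summing squared gradient entries and element counts with `tf.add_n`. A minimal, self-contained sketch of that pattern (hypothetical gradient tensors, TF 2.x eager assumed):

import tensorflow as tf

# Hypothetical per-variable gradients; in practice these come from tape.gradient().
gradients = [tf.constant([[0.1, -0.2], [0.3, 0.0]]), tf.constant([0.5, -0.5]), None]

energy, size = [], []
for grad in gradients:
    if grad is None:  # variables without gradients are skipped
        continue
    energy.append(tf.reduce_sum(tf.square(tf.cast(grad, tf.float64))))
    size.append(tf.cast(tf.size(grad), tf.float64))

# RMS over all gradient entries: sqrt(total squared energy / total element count).
grad_rms = tf.sqrt(tf.add_n(energy) / tf.add_n(size))
print(grad_rms.numpy())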
Example #2
  def test_get_variable(self):
    # Test the shim when using `get_variable` (and regularizers) directly

    class WrappedDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

      def __init__(self, units, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.units = units

      def forward_pass(self, inputs, training=None):
        out = inputs
        with tf.compat.v1.variable_scope("dense_one"):
          # The weights are created with a `regularizer`,
          # so the layer should track their regularization losses
          kernel = tf.compat.v1.get_variable(
              shape=[out.shape[-1], self.units],
              regularizer=regularizers.L2(),
              initializer=tf.compat.v1.ones_initializer(),
              name="kernel")
          bias = tf.compat.v1.get_variable(
              shape=[self.units,],
              initializer=tf.compat.v1.zeros_initializer(),
              name="bias")
          out = tf.matmul(out, kernel)
          out = tf.nn.bias_add(out, bias)
        with tf.compat.v1.variable_scope("nested_scope"):
          with tf.compat.v1.variable_scope("dense_two"):
            kernel = tf.compat.v1.get_variable(
                shape=[out.shape[-1], self.units],
                regularizer=regularizers.L2(),
                initializer=tf.compat.v1.ones_initializer(),
                name="kernel")
            bias = tf.compat.v1.get_variable(
                shape=[self.units,],
                initializer=tf.compat.v1.zeros_initializer(),
                name="bias")
            out = tf.matmul(out, kernel)
            out = tf.nn.bias_add(out, bias)
        return out

    layer = WrappedDenseLayer(10)
    out = layer(tf.ones(shape=(5, 5)))
    weights = {x.name: x for x in layer.variables}

    # Verify the correct output, regularization losses, + variables were made
    self.assertEqual(weights.keys(), {"dense_one/bias:0",
                                      "dense_one/kernel:0",
                                      "nested_scope/dense_two/bias:0",
                                      "nested_scope/dense_two/kernel:0"})
    self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 50)
    self.assertAllEqual(tf.add_n(layer.losses), 1.5)

    # Verify reuse by updating the variables then re-running
    weights["dense_one/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
    weights["nested_scope/dense_two/kernel:0"].assign(
        tf.ones(shape=(10, 10)) * 2)
    out = layer(tf.ones(shape=(5, 5)))
    self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 200)
    self.assertAllEqual(tf.add_n(layer.losses), 6)
Example #3
    def __call__(self, score_outputs, labels):
        """Computes total RPN detection loss.

    Computes total RPN detection loss including box and score from all levels.
    Args:
      score_outputs: an OrderedDict with keys representing levels and values
        representing scores in [batch_size, height, width, num_anchors].
      labels: the dictionary returned from the dataloader that includes
        groundtruth targets.
    Returns:
      rpn_score_loss: a scalar tensor representing total score loss.
    """
        with tf.name_scope('rpn_loss'):
            levels = sorted(score_outputs.keys())

            score_losses = []
            for level in levels:
                score_targets_l = labels['score_targets_%d' % level]
                score_losses.append(
                    self._rpn_score_loss(
                        score_outputs[level],
                        score_targets_l,
                        normalizer=tf.cast(self._batch_size *
                                           self._rpn_batch_size_per_im,
                                           dtype=tf.float32)))

            # Sums per level losses to total loss.
            return tf.add_n(score_losses)
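The per-level reduction is just `tf.add_n` over a list built from a dict keyed by level. A toy sketch with made-up per-level losses:

import tensorflow as tf

# Hypothetical per-level scalar losses keyed by FPN level.
score_losses_by_level = {3: tf.constant(0.7), 4: tf.constant(0.4), 5: tf.constant(0.1)}

levels = sorted(score_losses_by_level.keys())
total_loss = tf.add_n([score_losses_by_level[level] for level in levels])
print(total_loss.numpy())  # ~1.2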
Example #4
def _weighted_sum(weights, list_of_states):
    """Computes a weighted sum of `list_of_states`.

  Args:
    weights: List of scalar tensors.
    list_of_states: List of states. Every element is assumed to be of the same
      structure of Tensors. Must be of the same length as `weights`.

  Returns:
    weighted_sum: A weighted sum of states in `list_of_states`. Has the same
      structure as elements of `list_of_states`.

  Raises:
    ValueError: If `list_of_states` is empty or length doesn't match `weights`.
  """
    with tf.name_scope('weighted_sum'):
        if not weights:
            raise ValueError(
                '`list_of_states` and `weights` must be non-empty')
        if len(weights) != len(list_of_states):
            raise ValueError(
                '`weights` and `list_of_states` must have same length')
        for state in list_of_states:
            tf.nest.assert_same_structure(state, list_of_states[-1])
        weights_and_states = zip(weights, list_of_states)
        weighted_states = [[
            w * s_component for s_component in tf.nest.flatten(s)
        ] for w, s in weights_and_states if _possibly_nonzero(w)]
        list_of_components = zip(
            *weighted_states)  # Put same components together.
        flat_final_state = [
            tf.add_n(component) for component in list_of_components
        ]
        return tf.nest.pack_sequence_as(list_of_states[0], flat_final_state)
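The same flatten / scale / `tf.add_n` / repack pattern can be sketched on its own, without the `_possibly_nonzero` helper, using hypothetical nested states:

import tensorflow as tf

weights = [tf.constant(0.5), tf.constant(2.0)]
# Two states with identical nested structure.
states = [
    {'pos': tf.constant([1.0, 2.0]), 'vel': tf.constant([3.0])},
    {'pos': tf.constant([4.0, 6.0]), 'vel': tf.constant([1.0])},
]

# Scale every flattened component of each state by its weight ...
weighted = [[w * part for part in tf.nest.flatten(s)] for w, s in zip(weights, states)]
# ... then add matching components together and restore the original structure.
flat_sum = [tf.add_n(parts) for parts in zip(*weighted)]
weighted_sum = tf.nest.pack_sequence_as(states[0], flat_sum)
print(weighted_sum)  # {'pos': [8.5, 13.0], 'vel': [3.5]}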
Example #5
    def eval(self, inputs, is_training=True, **kwargs):
        kwargs.update({'is_training': is_training})
        all_extras = []

        def _try_get_extra_results(layer):
            all_extras.append((
                getattr(layer, 'extra_loss', None),
                getattr(layer, 'extra_result', None),
            ))

        x = inputs
        for layer in self.layers[:-1]:
            _try_set_extra_results(layer, loss=None, result=None)
            x = _try_call(layer, [x], kwargs)
            _try_get_extra_results(layer)

        last_layer = self.layers[-1]
        _try_set_extra_results(last_layer, loss=None, result=None)
        last_layer_eval_fn = getattr(last_layer, 'eval', None)
        if not (callable(last_layer_eval_fn)
                and callable(getattr(last_layer, 'eval_final', None))):
            last_layer_eval_fn = last_layer
        x = _try_call(last_layer_eval_fn, [x], kwargs)
        _try_get_extra_results(last_layer)

        non_none_extra_losses = [
            loss for (loss, _) in all_extras if loss is not None
        ]
        sum_extra_losses_sans_last = (tf.add_n(non_none_extra_losses)
                                      if non_none_extra_losses else None)
        self._set_extra_loss(None)
        self._set_extra_result((sum_extra_losses_sans_last, all_extras))
        return x, self.extra_result
Example #6
 def body(i, state):
     del i
     if not params:
         return state
     sum_params = tf.add_n(params)
     state = [s * sum_params for s in state]
     return state
Example #7
    def __call__(self, box_outputs, labels, num_positives):
        """Computes box detection loss.

    Computes total detection loss including box and class loss from all levels.

    Args:
      box_outputs: an OrderedDict with keys representing levels and values
        representing box regression targets in [batch_size, height, width,
        num_anchors * 4].
      labels: the dictionary returned from the dataloader that includes
        box groundtruth targets.
      num_positives: number of positive examples in the minibatch.

    Returns:
      an integer tensor representing total box regression loss.
    """
        # Sums all positives in a batch for normalization and avoids zero
        # num_positives_sum, which would lead to inf loss during training
        num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0

        box_losses = []
        for level in box_outputs.keys():
            # Onehot encoding for classification labels.
            box_targets_l = labels[level]
            box_losses.append(
                self.box_loss(box_outputs[level], box_targets_l,
                              num_positives_sum))
        # Sums per level losses to total loss.
        return tf.add_n(box_losses)
Example #8
def _dist_jd_log_prob_ratio(p, x, q, y):
    """Distributed log-prob ratio for JDs."""
    tf.nest.assert_same_structure(x, y)
    if p.shard_axis_name != q.shard_axis_name:
        raise ValueError(
            'p and q must have the same shard_axis_name. '
            f'Saw: p: {p}, {p.shard_axis_name}, q: {q}, {q.shard_axis_name}')

    def log_prob_ratio_parts_fn(x_y):
        x = tf.nest.map_structure(lambda part: part[0], x_y)
        y = tf.nest.map_structure(lambda part: part[1], x_y)
        p_dists = p.sample_distributions(value=x, seed=jd_lib.dummy_seed())[0]
        q_dists = q.sample_distributions(value=y, seed=jd_lib.dummy_seed())[0]
        lp_diffs = tf.nest.map_structure(log_prob_ratio.log_prob_ratio,
                                         p_dists, x, q_dists, y)
        return lp_diffs

    return tf.add_n(
        tf.nest.flatten(
            distribute_lib.make_sharded_log_prob_parts(
                log_prob_ratio_parts_fn,
                # Stack, because make_sharded_log_prob_parts expects
                # inputs/outputs to be 1 to 1. TODO(b/175084455): revisit this
                # after the distributed bijectors are done, as it is likely that
                # make_sharded_log_prob_parts will be adjusted then to not have
                # this limitation.
                p.get_sharded_distributions(),
                axis_name=p.shard_axis_name)(tf.nest.map_structure(
                    lambda x, y: tf.stack([x, y], axis=0), x, y))))
Example #9
def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
    """Distributed log-prob ratio for JDs."""
    with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
        tf.nest.assert_same_structure(x, y)

        p_axis_names = p.experimental_shard_axis_names
        q_axis_names = q.experimental_shard_axis_names
        if p_axis_names != q_axis_names:
            raise ValueError(
                'p and q must use the same sharding. '
                f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

        def log_prob_ratio_parts_fn(x, y):
            p_dists = p.sample_distributions(value=x,
                                             seed=samplers.zeros_seed())[0]
            q_dists = q.sample_distributions(value=y,
                                             seed=samplers.zeros_seed())[0]
            # Ensure sharded distributions defer reductions.
            kwds = lambda a: {'reduce_over_shards': False} if a else {}
            return nest.map_structure_up_to(
                p_dists, lambda p, x, q, y, s: lp_ratio.log_prob_ratio(
                    p, x, q, y, **kwds(s)), p_dists, x, q_dists, y,
                p_axis_names)

        return tf.add_n(
            tf.nest.flatten(
                distribute_lib.make_psum_function(log_prob_ratio_parts_fn,
                                                  in_axes=(p_axis_names,
                                                           p_axis_names),
                                                  out_axes=p_axis_names,
                                                  out_dtype=x)(x, y)))
Example #10
        def step_fn(inputs):
            """Function to run on the device."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = self.model(images, training=True)

                prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_sum(prediction_loss) * (
                    1.0 / self.flags_obj.batch_size)
                num_replicas = self.strategy.num_replicas_in_sync

                if self.flags_obj.single_l2_loss_op:
                    l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
                        tf.nn.l2_loss(v)
                        for v in self.model.trainable_variables
                        if 'bn' not in v.name
                    ])

                    loss += (l2_loss / num_replicas)
                else:
                    loss += (tf.reduce_sum(self.model.losses) / num_replicas)

            grad_utils.minimize_using_explicit_allreduce(
                tape, self.optimizer, loss, self.model.trainable_variables)
            self.train_loss.update_state(loss)
            self.train_accuracy.update_state(labels, logits)
Example #11
  def _metric_fn(labels, predictions, weights=None):
    """Counts the number of trainable parameters.

    Args:
      labels: Unused.
      predictions: Unused.
      weights: Unused.

    Returns:
      dict with a single string key `num_parameters` that maps to a tuple
      containing two int32 0-D Tensors, both containing the number of trainable
      parameters.
    """

    del labels  # unused
    del predictions  # unused
    del weights  # unused

    trainable = tf.compat.v1.trainable_variables()
    if tower_name:
      counted_variables = [
          var for var in trainable
          if var.name.startswith("Phoenix/{}".format(tower_name))
      ]
    else:
      counted_variables = trainable

    if counted_variables:
      parameters = tf.add_n([tf.size(input=var) for var in counted_variables])
    else:
      parameters = tf.constant(0, dtype=tf.int32)

    return {"num_parameters": (parameters, parameters)}
Example #12
 def log_prob(*value):
     w, x = value
     sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
         log_prob_parts, [False, True, True],
         axis_name=self.axis_name)
     parts = sharded_log_prob_parts([w, x, data])
     return tf.add_n(parts)
Example #13
def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
  """Distributed log-prob ratio for JDs."""
  with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)

    p_axis_names = p.experimental_shard_axis_names
    q_axis_names = q.experimental_shard_axis_names
    if p_axis_names != q_axis_names:
      raise ValueError('p and q must use the same sharding. '
                       f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

    def log_prob_ratio_parts_fn(x_y):
      x = tf.nest.map_structure(lambda part: part[0], x_y)
      y = tf.nest.map_structure(lambda part: part[1], x_y)
      p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
      q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]
      # Ensure sharded distributions defer reductions.
      kwds = lambda a: {'reduce_over_shards': False} if a else {}
      return nest.map_structure_up_to(
          p_dists,
          lambda p, x, q, y, s: lp_ratio.log_prob_ratio(p, x, q, y, **kwds(s)),
          p_dists, x, q_dists, y, p_axis_names)

    return tf.add_n(
        tf.nest.flatten(
            distribute_lib.make_sharded_log_prob_parts(
                log_prob_ratio_parts_fn,
                # Stack, because make_sharded_log_prob_parts expects
                # inputs/outputs to be 1 to 1. TODO(b/175084455): revisit this
                # after the distributed bijectors are done, as it is likely that
                # make_sharded_log_prob_parts will be adjusted then to not have
                # this limitation.
                p_axis_names)(tf.nest.map_structure(
                    lambda x, y: tf.stack([x, y], axis=0), x, y))))
Example #14
 def log_prob(*value):
   w, x = value
   sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
       log_prob_parts, {'w': False, 'x': True, 'data': True},
       axis_name=self.axis_name)
   parts = sharded_log_prob_parts({'w': w, 'x': x, 'data': data})
   return tf.add_n(tf.nest.flatten(parts))
Example #15
 def weight_decay_loss(self, l2_weight_decay, keras_model):
   # TODO(yeqing): Correct the filter according to  cr/269707763.
   return l2_weight_decay * tf.add_n([
       tf.nn.l2_loss(v)
       for v in self._keras_model.trainable_variables
       if 'batch_normalization' not in v.name and 'bias' not in v.name
   ])
Example #16
def safe_sum(x, alt_value=-np.inf, name=None):
    """Elementwise adds list members, replacing non-finite results with alt_value.

  Typically the `alt_value` is chosen so the `MetropolisHastings`
  `TransitionKernel` always rejects the proposal.

  Args:
    x: Python `list` of `Tensors` to elementwise add.
    alt_value: Python scalar used to replace any elementwise sums which would
      otherwise be non-finite.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "safe_sum").

  Returns:
    safe_sum: `Tensor` representing the elementwise sum of list of `Tensor`s
      `x` or `alt_value` where sums are non-finite.

  Raises:
    TypeError: if `x` is not list-like.
    ValueError: if `x` is empty.
  """
    with tf.name_scope(name or 'safe_sum'):
        if not is_list_like(x):
            raise TypeError('Expected list input.')
        if not x:
            raise ValueError('Input should not be empty.')
        in_shape = x[0].shape
        x = tf.add_n(x)
        x = tf.where(tf.math.is_finite(x), x,
                     tf.constant(alt_value, dtype=x.dtype))
        x.set_shape(x.shape.merge_with(in_shape))
        return x
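The core trick in `safe_sum` is independent of its helper functions: add the tensors, then replace any non-finite entries with the fallback value. A standalone sketch with hypothetical inputs:

import numpy as np
import tensorflow as tf

x = [tf.constant([1.0, np.inf, 3.0]), tf.constant([2.0, 5.0, -np.inf])]

total = tf.add_n(x)
# Any elementwise sum that came out non-finite is replaced by the fallback value.
safe_total = tf.where(tf.math.is_finite(total), total,
                      tf.constant(-np.inf, dtype=total.dtype))
print(safe_total.numpy())  # [ 3. -inf -inf]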
Example #17
    def __call__(self, box_outputs, labels):
        """Computes total RPN detection loss.

    Computes total RPN detection loss including box and score from all levels.

    Args:
      box_outputs: an OrderedDict with keys representing levels and values
        representing box regression targets in
        [batch_size, height, width, num_anchors * 4].
      labels: the dictionary returned from the dataloader that includes
        groundtruth targets.

    Returns:
      rpn_box_loss: a scalar tensor representing total box regression loss.
    """
        with tf.name_scope('rpn_loss'):
            levels = sorted(box_outputs.keys())

            box_losses = []
            for level in levels:
                box_losses.append(
                    self._rpn_box_loss(box_outputs[level], labels[level]))

            # Sum per level losses to total loss.
            return tf.add_n(box_losses)
Example #18
 def get_gradients(x, y, log_batch_gradient=False, is_regularized=True):
   """Gets spars gradients and possibly logs some statistics."""
   is_grad_regularized = gradient_regularization != 0
   with tf.GradientTape(persistent=is_grad_regularized) as tape:
     predictions = model(x, training=True)
     batch_loss = loss_object(y, predictions)
     if is_regularized and is_grad_regularized:
       gradients = tape.gradient(batch_loss, trainable_vars)
       gradients = mask_gradients(model, gradients, trainable_vars)
       grad_vec = flatten_list_of_vars(gradients)
       batch_loss += tf.nn.l2_loss(grad_vec) * gradient_regularization
     # Regularization might have been disabled.
     reg_loss = tf.add_n(model.losses) if model.losses else 0
     if is_regularized:
       batch_loss += reg_loss
   gradients = tape.gradient(batch_loss, trainable_vars)
   # Gradients are dense; we mask them so that the updates (and the norm
   # calculation) stay sparse.
   gradients = mask_gradients(model, gradients, trainable_vars)
   # If requested, log batch-gradient statistics.
   if log_batch_gradient:
     tf.summary.scalar('train_batch_loss', batch_loss)
     tf.summary.scalar('train_batch_reg_loss', reg_loss)
     train_batch_accuracy.update_state(y, predictions)
     tf.summary.scalar('train_batch_accuracy', train_batch_accuracy.result())
     train_batch_accuracy.reset_states()
   return gradients
Example #19
  def bundle_logits(self, priors_logits_specs, search_logits_specs):
    """Bundles the priors and the search candidate."""

    assert search_logits_specs, "Cannot distill with no student model."
    assert len(search_logits_specs) == 1, "Search has more than one tower."

    if not priors_logits_specs:
      return DistillationLogits(
          train_logits_specs=search_logits_specs,
          eval_logits_spec=search_logits_specs[0],
          teacher_logits_spec=None)

    with tf.compat.v1.variable_scope("Phoenix/Distiller"):
      priors_logits = tf.add_n(
          [tf.stop_gradient(spec.logits) for spec in priors_logits_specs])

      assert self._distillation_spec.distillation_type, (
          "Invalid DistillationType specified.")
      if (self._distillation_spec.distillation_type ==
          distillation_spec_pb2.DistillationSpec.DistillationType.MSE_LOGITS):
        transformed_logits = priors_logits
      else:
        transformed_logits = tf.nn.softmax(priors_logits /
                                           self._distillation_spec.temperature)

      transformed_logits_specs = architecture_utils.LogitsSpec(
          logits=transformed_logits)

      # Use the logits from the student model (search) to train and evaluate,
      # but store the logits from the teacher model (combined priors) to
      # calculate the loss.
      return DistillationLogits(
          train_logits_specs=search_logits_specs,
          eval_logits_spec=search_logits_specs[0],
          teacher_logits_spec=transformed_logits_specs)
Example #20
def _shake_shake_block(layer_input,
                       output_filters,
                       stride,
                       weight_decay,
                       tag=""):
    """Builds a full Shake-Shake sub layer made of Shake-Shake branches.

  Args:
    layer_input: Input Keras layer.
    output_filters: Defines the number of output filters of the layer.
    stride: Defines the stride of the shake shake layer block.
    weight_decay: L2 weight decay coefficient passed to the Shake-Shake
      branches and the skip connection.
    tag: String. Name tag for this shake shake block.

  Returns:
    A Shake-Shake Keras layer block.
  """
    batch_size = tf.shape(layer_input)[0]
    rand_forward = [
        # pylint: disable=g-complex-comprehension
        tf.random.uniform([batch_size, 1, 1, 1],
                          minval=0,
                          maxval=1,
                          dtype=tf.float32,
                          name="{}_1_{}".format(tag, i)) for i in range(2)
    ]
    rand_backward = [
        # pylint: disable=g-complex-comprehension
        tf.random.uniform([batch_size, 1, 1, 1],
                          minval=0,
                          maxval=1,
                          dtype=tf.float32,
                          name="{}_2_{}".format(tag, i)) for i in range(2)
    ]

    total_forward = tf.add_n(rand_forward)
    total_backward = tf.add_n(rand_backward)
    rand_forward = [samp / total_forward for samp in rand_forward]
    rand_backward = [samp / total_backward for samp in rand_backward]
    zipped_rand = zip(rand_forward, rand_backward)
    branches = []
    for _, (r_forward, r_backward) in enumerate(zipped_rand):
        b = _shake_shake_branch(layer_input, output_filters, stride, r_forward,
                                r_backward, weight_decay)
        branches.append(b)
    res = _shake_shake_skip_connection(layer_input, output_filters, stride,
                                       weight_decay)
    return res + tf.add_n(branches)
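The role of `tf.add_n` here is to normalize the per-branch random weights so they sum to one for each example. A tiny standalone sketch of that normalization step (hypothetical shapes):

import tensorflow as tf

batch_size = 4
# One random weight per branch, broadcastable against [batch, H, W, C] activations.
rand_forward = [tf.random.uniform([batch_size, 1, 1, 1]) for _ in range(2)]

total_forward = tf.add_n(rand_forward)
rand_forward = [samp / total_forward for samp in rand_forward]

# The two normalized weights now sum to 1 for every example in the batch.
print(tf.reduce_max(tf.abs(tf.add_n(rand_forward) - 1.0)).numpy())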
Example #21
 def log_prob(x, y, z):
     sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
         log_prob_parts, [
             self.axis_name, other_axis_name,
             [self.axis_name, other_axis_name]
         ])
     parts = sharded_log_prob_parts([x, y, z])
     return tf.add_n(parts)
Example #22
    def kinetic_energy_fn(*args, **kwargs):
        def one_component(x):
            return tf.reduce_sum(tf.square(x),
                                 axis=tf.range(chain_ndims, tf.rank(x)))

        return (tf.add_n(
            [one_component(x)
             for x in tf.nest.flatten([args, kwargs])]) / 2.), ()
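Stripped of the chain-dimension handling, the kinetic energy is half the summed squared norm of every tensor in the (possibly nested) state. A minimal sketch with a hypothetical state:

import tensorflow as tf

state = {'x': tf.constant([1.0, 2.0]), 'y': tf.constant([[3.0], [4.0]])}

kinetic_energy = tf.add_n(
    [tf.reduce_sum(tf.square(part)) for part in tf.nest.flatten(state)]) / 2.
print(kinetic_energy.numpy())  # (1 + 4 + 9 + 16) / 2 = 15.0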
Example #23
    def test_compat_v1_layer(self):
        # Test the shim when using `compat.v1` layers

        class WrappedDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):
            def __init__(self, units, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.units = units

            def forward_pass(self, inputs, training=None):
                out = core_layers.dense(
                    inputs,
                    self.units,
                    name="dense_one",
                    kernel_initializer=tf.compat.v1.ones_initializer(),
                    kernel_regularizer="l2")
                with tf.compat.v1.variable_scope("nested_scope"):
                    out = core_layers.dense(
                        out,
                        self.units,
                        name="dense_two",
                        kernel_initializer=tf.compat.v1.ones_initializer(),
                        kernel_regularizer="l2")
                return out

        layer = WrappedDenseLayer(10)
        out = layer(tf.ones(shape=(5, 5)))
        weights = {x.name: x for x in layer.variables}

        # Verify the correct output, losses, + variables were made
        self.assertEqual(
            weights.keys(), {
                "dense_one/bias:0", "dense_one/kernel:0",
                "nested_scope/dense_two/bias:0",
                "nested_scope/dense_two/kernel:0"
            })
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 50)
        self.assertAllEqual(tf.add_n(layer.losses), 1.5)

        # Verify reuse by updating the variables then re-running
        weights["dense_one/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
        weights["nested_scope/dense_two/kernel:0"].assign(
            tf.ones(shape=(10, 10)) * 2)
        out = layer(tf.ones(shape=(5, 5)))
        self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 200)
        self.assertAllEqual(tf.add_n(layer.losses), 6)
Example #24
  def weight_decay_loss(self, trainable_variables):
    reg_variables = [
        v for v in trainable_variables
        if self._regularization_var_regex is None
        or re.match(self._regularization_var_regex, v.name)
    ]

    return self._l2_weight_decay * tf.add_n(
        [tf.nn.l2_loss(v) for v in reg_variables])
Example #25
def _jd_log_prob_ratio(p, x, q, y):
  tf.nest.assert_same_structure(x, y)
  ps, _ = p.sample_distributions(value=x)
  qs, _ = q.sample_distributions(value=y)
  tf.nest.assert_same_structure(ps, qs)
  parts = []
  for p_, x_, q_, y_ in zip(ps, x, qs, y):
    parts.append(log_prob_ratio.log_prob_ratio(p_, x_, q_, y_))
  return tf.add_n(parts)
Example #26
 def _mean(self):
     distribution_means = [d.mean() for d in self.components]
     cat_probs = self._cat_probs(log_probs=False)
     cat_probs = [self._expand_to_event_rank(c_p) for c_p in cat_probs]
     partial_means = [
         c_p * m for (c_p, m) in zip(cat_probs, distribution_means)
     ]
     # These should all be the same shape by virtue of matching
     # batch_shape and event_shape.
     return tf.add_n(partial_means)
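The mixture mean is a probability-weighted sum of component means, which `tf.add_n` accumulates. A self-contained sketch with hypothetical weights and means:

import tensorflow as tf

# Hypothetical categorical probabilities and per-component means (matching shapes).
cat_probs = [tf.constant(0.25), tf.constant(0.75)]
component_means = [tf.constant([0.0, 2.0]), tf.constant([4.0, 6.0])]

partial_means = [p * m for p, m in zip(cat_probs, component_means)]
mixture_mean = tf.add_n(partial_means)
print(mixture_mean.numpy())  # [3. 5.]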
Example #27
def nest_rms_norm(nest):
    """Computes root mean squared norm of nested structure of `Tensor`s.

  Args:
    nest: Possibly nested structure of `Tensor`s of which RMS norm is computed.
  Returns:
    norm: Scalar floating-point tensor equal to the RMS norm of `nest`.
  """
    sizes = tf.nest.map_structure(tf.size, nest)
    num_elements = tf.add_n(tf.nest.flatten(sizes))

    def averaged_sum_squares(input_tensor):
        num_elements_cast = tf.cast(num_elements,
                                    dtype=dtype_util.real_dtype(
                                        input_tensor.dtype))
        return tf.reduce_sum(abs_square(input_tensor)) / num_elements_cast

    squared_sums = tf.nest.map_structure(averaged_sum_squares, nest)
    norm = tf.math.sqrt(tf.add_n(tf.nest.flatten(squared_sums)))
    return norm
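Without the TFP-internal helpers (`abs_square`, `dtype_util`), the same RMS norm can be sketched directly for real-valued nests:

import tensorflow as tf

nest = {'a': tf.constant([3.0, 4.0]), 'b': tf.constant([[0.0]])}

# Total element count and total sum of squares over the flattened structure.
num_elements = tf.add_n([tf.size(t) for t in tf.nest.flatten(nest)])
sum_squares = tf.add_n(
    [tf.reduce_sum(tf.square(t)) for t in tf.nest.flatten(nest)])
rms_norm = tf.sqrt(sum_squares / tf.cast(num_elements, tf.float32))
print(rms_norm.numpy())  # sqrt((9 + 16 + 0) / 3) ~= 2.89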
Example #28
def _ildj_ratio_chain(p, x, q, y):
  """Sum-of-diffs ILDJRatio for Chains."""
  if len(p.bijectors) != len(q.bijectors):
    raise ValueError('Mismatched lengths of bijectors: `p` has '
                     f'{len(p.bijectors)} but `q` has {len(q.bijectors)}.')
  ratios = []
  for p, q in zip(p.bijectors, q.bijectors):
    ratios.append(ldj_ratio.inverse_log_det_jacobian_ratio(
        p, x, q, y, p.inverse_min_event_ndims))
    x, y = p.inverse(x), q.inverse(y)
  return tf.add_n(ratios)
Example #29
def _jd_log_prob_ratio(p, x, q, y, name=None):
    """Implements `log_prob_ratio` for tfd.JointDistribution*."""
    with tf.name_scope(name or 'jd_log_prob_ratio'):
        tf.nest.assert_same_structure(x, y)
        ps, _ = p.sample_distributions(value=x, seed=dummy_seed())
        qs, _ = q.sample_distributions(value=y, seed=dummy_seed())
        tf.nest.assert_same_structure(ps, qs)
        parts = []
        for p_, x_, q_, y_ in zip(ps, x, qs, y):
            parts.append(log_prob_ratio.log_prob_ratio(p_, x_, q_, y_))
        return tf.add_n(parts)
Example #30
def add_weight_decay(model):
    # Weight decay is taken care of by the optimizer in these cases,
    # except for the supervised head, which is added here.
    l2_losses = [
        tf.nn.l2_loss(v) for v in model.trainable_variables
        if 'head_supervised' in v.name and 'bias' not in v.name
    ]
    if l2_losses:
        return FLAGS.weight_decay * tf.add_n(l2_losses)
    else:
        return 0
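Examples #10, #15, #24 and #30 all use the same pattern: filter trainable variables by name, sum their `tf.nn.l2_loss`, and scale by a decay coefficient. A hedged standalone sketch of that pattern on a throwaway Keras model (names and coefficient are hypothetical):

import tensorflow as tf

weight_decay = 1e-4  # hypothetical coefficient

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, name='dense_one'),
    tf.keras.layers.BatchNormalization(name='bn'),
    tf.keras.layers.Dense(2, name='head'),
])
model.build(input_shape=(None, 4))

# Exclude batch-norm and bias variables, as the examples above do.
l2_losses = [
    tf.nn.l2_loss(v) for v in model.trainable_variables
    if 'bn' not in v.name and 'bias' not in v.name
]
reg_loss = weight_decay * tf.add_n(l2_losses) if l2_losses else tf.constant(0.0)
print(reg_loss.numpy())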