def train_step(self, x):
  if hasattr(self, "alpha"):
    self.alpha = self.force_alpha
  with tf.GradientTape() as tape:
    rates, distortions = self.train_losses(x)
    losses = rates + self.lmbda * distortions
    loss = tf.math.reduce_mean(losses)
  variables = self.trainable_variables
  gradients = tape.gradient(loss, variables)
  self.optimizer.apply_gradients(zip(gradients, variables))
  self.loss.update_state(losses)
  self.rate.update_state(rates)
  self.distortion.update_state(distortions)
  energy = []
  size = []
  for grad in gradients:
    if grad is None:
      continue
    energy.append(tf.reduce_sum(tf.square(tf.cast(grad, tf.float64))))
    size.append(tf.cast(tf.size(grad), tf.float64))
  self.grad_rms.update_state(tf.sqrt(tf.add_n(energy) / tf.add_n(size)))
  return {
      m.name: m.result()
      for m in [self.loss, self.rate, self.distortion, self.grad_rms]
  }
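# A minimal standalone sketch of the gradient-RMS aggregation used in
# `train_step` above, assuming TensorFlow is imported as `tf`. Per-tensor
# energies and element counts are combined with tf.add_n, yielding the RMS
# over all gradient elements.
grads = [tf.constant([3.0, 4.0]), tf.constant([[0.0, 0.0]])]
energy = [tf.reduce_sum(tf.square(tf.cast(g, tf.float64))) for g in grads]
size = [tf.cast(tf.size(g), tf.float64) for g in grads]
rms = tf.sqrt(tf.add_n(energy) / tf.add_n(size))
# rms == sqrt((9 + 16 + 0 + 0) / 4) == 2.5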
def test_get_variable(self):
  # Test the shim when using `get_variable` (and regularizers) directly

  class WrappedDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = inputs
      with tf.compat.v1.variable_scope("dense_one"):
        # The weights are created with a `regularizer`,
        # so the layer should track their regularization losses
        kernel = tf.compat.v1.get_variable(
            shape=[out.shape[-1], self.units],
            regularizer=regularizers.L2(),
            initializer=tf.compat.v1.ones_initializer(),
            name="kernel")
        bias = tf.compat.v1.get_variable(
            shape=[self.units,],
            initializer=tf.compat.v1.zeros_initializer(),
            name="bias")
        out = tf.matmul(out, kernel)
        out = tf.nn.bias_add(out, bias)
      with tf.compat.v1.variable_scope("nested_scope"):
        with tf.compat.v1.variable_scope("dense_two"):
          kernel = tf.compat.v1.get_variable(
              shape=[out.shape[-1], self.units],
              regularizer=regularizers.L2(),
              initializer=tf.compat.v1.ones_initializer(),
              name="kernel")
          bias = tf.compat.v1.get_variable(
              shape=[self.units,],
              initializer=tf.compat.v1.zeros_initializer(),
              name="bias")
          out = tf.matmul(out, kernel)
          out = tf.nn.bias_add(out, bias)
      return out

  layer = WrappedDenseLayer(10)
  out = layer(tf.ones(shape=(5, 5)))
  weights = {x.name: x for x in layer.variables}

  # Verify the correct output, regularization losses, + variables were made
  self.assertEqual(
      weights.keys(), {
          "dense_one/bias:0", "dense_one/kernel:0",
          "nested_scope/dense_two/bias:0", "nested_scope/dense_two/kernel:0"
      })
  self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 50)
  self.assertAllEqual(tf.add_n(layer.losses), 1.5)

  # Verify reuse by updating the variables then re-running
  weights["dense_one/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
  weights["nested_scope/dense_two/kernel:0"].assign(
      tf.ones(shape=(10, 10)) * 2)
  out = layer(tf.ones(shape=(5, 5)))
  self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 200)
  self.assertAllEqual(tf.add_n(layer.losses), 6)
def __call__(self, score_outputs, labels):
  """Computes total RPN score loss.

  Computes the total RPN score loss summed over all feature levels.

  Args:
    score_outputs: an OrderedDict with keys representing levels and values
      representing scores in [batch_size, height, width, num_anchors].
    labels: the dictionary returned from the dataloader that includes
      groundtruth targets.

  Returns:
    rpn_score_loss: a scalar tensor representing total score loss.
  """
  with tf.name_scope('rpn_loss'):
    levels = sorted(score_outputs.keys())

    score_losses = []
    for level in levels:
      score_targets_l = labels['score_targets_%d' % level]
      score_losses.append(
          self._rpn_score_loss(
              score_outputs[level],
              score_targets_l,
              normalizer=tf.cast(
                  self._batch_size * self._rpn_batch_size_per_im,
                  dtype=tf.float32)))

    # Sums per level losses to total loss.
    return tf.add_n(score_losses)
def _weighted_sum(weights, list_of_states):
  """Computes a weighted sum of `list_of_states`.

  Args:
    weights: List of scalar tensors.
    list_of_states: List of states. Every element is assumed to be of the
      same structure of Tensors. Must be of the same length as `weights`.

  Returns:
    weighted_sum: A weighted sum of states in `list_of_states`. Has the same
      structure as elements of `list_of_states`.

  Raises:
    ValueError: If `list_of_states` is empty or its length doesn't match
      `weights`.
  """
  with tf.name_scope('weighted_sum'):
    if not weights:
      raise ValueError('`list_of_states` and `weights` must be non-empty')
    if len(weights) != len(list_of_states):
      raise ValueError('`weights` and `list_of_states` must have same length')
    for state in list_of_states:
      tf.nest.assert_same_structure(state, list_of_states[-1])
    weights_and_states = zip(weights, list_of_states)
    weighted_states = [
        [w * s_component for s_component in tf.nest.flatten(s)]
        for w, s in weights_and_states
        if _possibly_nonzero(w)
    ]
    list_of_components = zip(*weighted_states)  # Put same components together.
    flat_final_state = [
        tf.add_n(component) for component in list_of_components
    ]
    return tf.nest.pack_sequence_as(list_of_states[0], flat_final_state)
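# A minimal usage sketch for `_weighted_sum`, assuming TensorFlow is imported
# as `tf` and `_possibly_nonzero` is defined alongside this helper (returning
# True for the nonzero weights used here). The result keeps the
# (scalar, vector) structure of the input states.
example_states = [(tf.constant(1.0), tf.constant([2.0, 3.0])),
                  (tf.constant(4.0), tf.constant([5.0, 6.0]))]
example_sum = _weighted_sum([0.5, 0.5], example_states)
# example_sum == (2.5, [3.5, 4.5])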
def eval(self, inputs, is_training=True, **kwargs):
  kwargs.update({'is_training': is_training})
  all_extras = []

  def _try_get_extra_results(layer):
    all_extras.append((
        getattr(layer, 'extra_loss', None),
        getattr(layer, 'extra_result', None),
    ))

  x = inputs
  for layer in self.layers[:-1]:
    _try_set_extra_results(layer, loss=None, result=None)
    x = _try_call(layer, [x], kwargs)
    _try_get_extra_results(layer)

  last_layer = self.layers[-1]
  _try_set_extra_results(last_layer, loss=None, result=None)
  last_layer_eval_fn = getattr(last_layer, 'eval', None)
  if not (callable(last_layer_eval_fn) and
          callable(getattr(last_layer, 'eval_final', None))):
    last_layer_eval_fn = last_layer
  x = _try_call(last_layer_eval_fn, [x], kwargs)
  _try_get_extra_results(last_layer)

  non_none_extra_losses = [
      loss for (loss, _) in all_extras if loss is not None
  ]
  sum_extra_losses_sans_last = (tf.add_n(non_none_extra_losses)
                                if non_none_extra_losses else None)
  self._set_extra_loss(None)
  self._set_extra_result((sum_extra_losses_sans_last, all_extras))
  return x, self.extra_result
def body(i, state):
  del i
  if not params:
    return state
  sum_params = tf.add_n(params)
  state = [s * sum_params for s in state]
  return state
def __call__(self, box_outputs, labels, num_positives):
  """Computes box detection loss.

  Computes the total box regression loss summed over all feature levels.

  Args:
    box_outputs: an OrderedDict with keys representing levels and values
      representing box regression targets in
      [batch_size, height, width, num_anchors * 4].
    labels: the dictionary returned from the dataloader that includes
      box groundtruth targets.
    num_positives: number of positive examples in the minibatch.

  Returns:
    a scalar tensor representing total box regression loss.
  """
  # Sums all positives in a batch for normalization and avoids zero
  # num_positives_sum, which would lead to inf loss during training.
  num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0

  box_losses = []
  for level in box_outputs.keys():
    # Box regression targets for this level.
    box_targets_l = labels[level]
    box_losses.append(
        self.box_loss(box_outputs[level], box_targets_l, num_positives_sum))

  # Sums per level losses to total loss.
  return tf.add_n(box_losses)
def _dist_jd_log_prob_ratio(p, x, q, y):
  """Distributed log-prob ratio for JDs."""
  tf.nest.assert_same_structure(x, y)
  if p.shard_axis_name != q.shard_axis_name:
    raise ValueError(
        'p and q must have the same shard_axis_name. '
        f'Saw: p: {p}, {p.shard_axis_name}, q: {q}, {q.shard_axis_name}')

  def log_prob_ratio_parts_fn(x_y):
    x = tf.nest.map_structure(lambda part: part[0], x_y)
    y = tf.nest.map_structure(lambda part: part[1], x_y)
    p_dists = p.sample_distributions(value=x, seed=jd_lib.dummy_seed())[0]
    q_dists = q.sample_distributions(value=y, seed=jd_lib.dummy_seed())[0]
    lp_diffs = tf.nest.map_structure(log_prob_ratio.log_prob_ratio, p_dists,
                                     x, q_dists, y)
    return lp_diffs

  return tf.add_n(
      tf.nest.flatten(
          distribute_lib.make_sharded_log_prob_parts(
              log_prob_ratio_parts_fn,
              # Stack, because make_sharded_log_prob_parts expects
              # inputs/outputs to be 1 to 1. TODO(b/175084455): revisit this
              # after the distributed bijectors are done, as it is likely
              # that make_sharded_log_prob_parts will be adjusted then to
              # not have this limitation.
              p.get_sharded_distributions(),
              axis_name=p.shard_axis_name)(tf.nest.map_structure(
                  lambda x, y: tf.stack([x, y], axis=0), x, y))))
def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
  """Distributed log-prob ratio for JDs."""
  with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)
    p_axis_names = p.experimental_shard_axis_names
    q_axis_names = q.experimental_shard_axis_names
    if p_axis_names != q_axis_names:
      raise ValueError(
          'p and q must use the same sharding. '
          f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

    def log_prob_ratio_parts_fn(x, y):
      p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
      q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]
      # Ensure sharded distributions defer reductions.
      kwds = lambda a: {'reduce_over_shards': False} if a else {}
      return nest.map_structure_up_to(
          p_dists,
          lambda p, x, q, y, s: lp_ratio.log_prob_ratio(p, x, q, y, **kwds(s)),
          p_dists, x, q_dists, y, p_axis_names)

    return tf.add_n(
        tf.nest.flatten(
            distribute_lib.make_psum_function(
                log_prob_ratio_parts_fn,
                in_axes=(p_axis_names, p_axis_names),
                out_axes=p_axis_names,
                out_dtype=x)(x, y)))
def step_fn(inputs):
  """Function to run on the device."""
  images, labels = inputs
  with tf.GradientTape() as tape:
    logits = self.model(images, training=True)

    prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits)
    loss = tf.reduce_sum(prediction_loss) * (1.0 / self.flags_obj.batch_size)
    num_replicas = self.strategy.num_replicas_in_sync

    if self.flags_obj.single_l2_loss_op:
      l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
          tf.nn.l2_loss(v)
          for v in self.model.trainable_variables
          if 'bn' not in v.name
      ])
      loss += (l2_loss / num_replicas)
    else:
      loss += (tf.reduce_sum(self.model.losses) / num_replicas)

  grad_utils.minimize_using_explicit_allreduce(
      tape, self.optimizer, loss, self.model.trainable_variables)
  self.train_loss.update_state(loss)
  self.train_accuracy.update_state(labels, logits)
def _metric_fn(labels, predictions, weights=None):
  """Counts the number of trainable parameters.

  Args:
    labels: Unused.
    predictions: Unused.
    weights: Unused.

  Returns:
    dict with a single string key `num_parameters` that maps to a tuple
    containing two int32 0-D Tensors, both containing the number of
    trainable parameters.
  """
  del labels  # unused
  del predictions  # unused
  del weights  # unused
  trainable = tf.compat.v1.trainable_variables()
  if tower_name:
    counted_variables = [
        var for var in trainable
        if var.name.startswith("Phoenix/{}".format(tower_name))
    ]
  else:
    counted_variables = trainable
  if counted_variables:
    parameters = tf.add_n([tf.size(input=var) for var in counted_variables])
  else:
    parameters = tf.constant(0, dtype=tf.int32)
  return {"num_parameters": (parameters, parameters)}
def log_prob(*value):
  w, x = value
  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [False, True, True], axis_name=self.axis_name)
  parts = sharded_log_prob_parts([w, x, data])
  return tf.add_n(parts)
def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
  """Distributed log-prob ratio for JDs."""
  with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)
    p_axis_names = p.experimental_shard_axis_names
    q_axis_names = q.experimental_shard_axis_names
    if p_axis_names != q_axis_names:
      raise ValueError(
          'p and q must use the same sharding. '
          f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

    def log_prob_ratio_parts_fn(x_y):
      x = tf.nest.map_structure(lambda part: part[0], x_y)
      y = tf.nest.map_structure(lambda part: part[1], x_y)
      p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
      q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]
      # Ensure sharded distributions defer reductions.
      kwds = lambda a: {'reduce_over_shards': False} if a else {}
      return nest.map_structure_up_to(
          p_dists,
          lambda p, x, q, y, s: lp_ratio.log_prob_ratio(p, x, q, y, **kwds(s)),
          p_dists, x, q_dists, y, p_axis_names)

    return tf.add_n(
        tf.nest.flatten(
            distribute_lib.make_sharded_log_prob_parts(
                log_prob_ratio_parts_fn,
                # Stack, because make_sharded_log_prob_parts expects
                # inputs/outputs to be 1 to 1. TODO(b/175084455): revisit
                # this after the distributed bijectors are done, as it is
                # likely that make_sharded_log_prob_parts will be adjusted
                # then to not have this limitation.
                p_axis_names)(tf.nest.map_structure(
                    lambda x, y: tf.stack([x, y], axis=0), x, y))))
def log_prob(*value):
  w, x = value
  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, {'w': False, 'x': True, 'data': True},
      axis_name=self.axis_name)
  parts = sharded_log_prob_parts({'w': w, 'x': x, 'data': data})
  return tf.add_n(tf.nest.flatten(parts))
def weight_decay_loss(self, l2_weight_decay, keras_model):
  # TODO(yeqing): Correct the filter according to cr/269707763.
  return l2_weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in self._keras_model.trainable_variables
      if 'batch_normalization' not in v.name and 'bias' not in v.name
  ])
def safe_sum(x, alt_value=-np.inf, name=None):
  """Elementwise adds list members, replacing non-finite results with alt_value.

  Typically the `alt_value` is chosen so the `MetropolisHastings`
  `TransitionKernel` always rejects the proposal.

  Args:
    x: Python `list` of `Tensors` to elementwise add.
    alt_value: Python scalar used to replace any elementwise sums which would
      otherwise be non-finite.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "safe_sum").

  Returns:
    safe_sum: `Tensor` representing the elementwise sum of list of `Tensor`s
      `x` or `alt_value` where sums are non-finite.

  Raises:
    TypeError: if `x` is not list-like.
    ValueError: if `x` is empty.
  """
  with tf.name_scope(name or 'safe_sum'):
    if not is_list_like(x):
      raise TypeError('Expected list input.')
    if not x:
      raise ValueError('Input should not be empty.')
    in_shape = x[0].shape
    x = tf.add_n(x)
    x = tf.where(tf.math.is_finite(x), x,
                 tf.constant(alt_value, dtype=x.dtype))
    x.set_shape(x.shape.merge_with(in_shape))
    return x
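# A minimal usage sketch for `safe_sum`, assuming `numpy` is imported as `np`
# and TensorFlow as `tf`. Sums that come out non-finite are replaced
# elementwise by `alt_value`.
example = safe_sum([tf.constant([1., np.inf]), tf.constant([2., 3.])])
# example == [3., -inf]: the finite sum is kept; the non-finite one is
# replaced by the default alt_value of -np.inf.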
def __call__(self, box_outputs, labels):
  """Computes total RPN box regression loss.

  Computes the total RPN box regression loss summed over all feature levels.

  Args:
    box_outputs: an OrderedDict with keys representing levels and values
      representing box regression targets in
      [batch_size, height, width, num_anchors * 4].
    labels: the dictionary returned from the dataloader that includes
      groundtruth targets.

  Returns:
    rpn_box_loss: a scalar tensor representing total box regression loss.
  """
  with tf.name_scope('rpn_loss'):
    levels = sorted(box_outputs.keys())

    box_losses = []
    for level in levels:
      box_losses.append(self._rpn_box_loss(box_outputs[level], labels[level]))

    # Sum per level losses to total loss.
    return tf.add_n(box_losses)
def get_gradients(x, y, log_batch_gradient=False, is_regularized=True):
  """Gets sparse gradients and possibly logs some statistics."""
  is_grad_regularized = gradient_regularization != 0
  with tf.GradientTape(persistent=is_grad_regularized) as tape:
    predictions = model(x, training=True)
    batch_loss = loss_object(y, predictions)
    if is_regularized and is_grad_regularized:
      gradients = tape.gradient(batch_loss, trainable_vars)
      gradients = mask_gradients(model, gradients, trainable_vars)
      grad_vec = flatten_list_of_vars(gradients)
      batch_loss += tf.nn.l2_loss(grad_vec) * gradient_regularization
    # Regularization might have been disabled.
    reg_loss = tf.add_n(model.losses) if model.losses else 0
    if is_regularized:
      batch_loss += reg_loss
  gradients = tape.gradient(batch_loss, trainable_vars)
  # Gradients are dense; mask them so that updates (and the norm
  # calculation) stay sparse.
  gradients = mask_gradients(model, gradients, trainable_vars)
  # If requested, log batch-gradient statistics.
  if log_batch_gradient:
    tf.summary.scalar('train_batch_loss', batch_loss)
    tf.summary.scalar('train_batch_reg_loss', reg_loss)
    train_batch_accuracy.update_state(y, predictions)
    tf.summary.scalar('train_batch_accuracy', train_batch_accuracy.result())
    train_batch_accuracy.reset_states()
  return gradients
def bundle_logits(self, priors_logits_specs, search_logits_specs):
  """Bundles the priors and the search candidate."""
  assert search_logits_specs, "Cannot distill with no student model."
  assert len(search_logits_specs) == 1, "Search has more than one tower."

  if not priors_logits_specs:
    return DistillationLogits(
        train_logits_specs=search_logits_specs,
        eval_logits_spec=search_logits_specs[0],
        teacher_logits_spec=None)

  with tf.compat.v1.variable_scope("Phoenix/Distiller"):
    priors_logits = tf.add_n(
        [tf.stop_gradient(spec.logits) for spec in priors_logits_specs])

    assert self._distillation_spec.distillation_type, (
        "Invalid DistillationType specified.")
    if (self._distillation_spec.distillation_type ==
        distillation_spec_pb2.DistillationSpec.DistillationType.MSE_LOGITS):
      transformed_logits = priors_logits
    else:
      transformed_logits = tf.nn.softmax(
          priors_logits / self._distillation_spec.temperature)

    transformed_logits_specs = architecture_utils.LogitsSpec(
        logits=transformed_logits)

    # Use the logits from the student model (search) to train and evaluate,
    # but store the logits from the teacher model (combined priors) to
    # calculate the loss.
    return DistillationLogits(
        train_logits_specs=search_logits_specs,
        eval_logits_spec=search_logits_specs[0],
        teacher_logits_spec=transformed_logits_specs)
def _shake_shake_block(layer_input, output_filters, stride, weight_decay,
                       tag=""):
  """Builds a full Shake-Shake sub layer made of Shake-Shake branches.

  Args:
    layer_input: Input Keras layer.
    output_filters: Defines the number of output filters of the layer.
    stride: Defines the stride of the shake shake layer block.
    weight_decay: L2 regularization coefficient applied to the branch and
      skip-connection convolutions.
    tag: String. Name tag for this shake shake block.

  Returns:
    A Shake-Shake Keras layer block.
  """
  batch_size = tf.shape(layer_input)[0]
  rand_forward = [  # pylint: disable=g-complex-comprehension
      tf.random.uniform(
          [batch_size, 1, 1, 1],
          minval=0,
          maxval=1,
          dtype=tf.float32,
          name="{}_1_{}".format(tag, i)) for i in range(2)
  ]
  rand_backward = [  # pylint: disable=g-complex-comprehension
      tf.random.uniform(
          [batch_size, 1, 1, 1],
          minval=0,
          maxval=1,
          dtype=tf.float32,
          name="{}_2_{}".format(tag, i)) for i in range(2)
  ]
  # Normalize so the random branch weights sum to one in both the forward
  # and backward passes.
  total_forward = tf.add_n(rand_forward)
  total_backward = tf.add_n(rand_backward)
  rand_forward = [samp / total_forward for samp in rand_forward]
  rand_backward = [samp / total_backward for samp in rand_backward]
  zipped_rand = zip(rand_forward, rand_backward)

  branches = []
  for _, (r_forward, r_backward) in enumerate(zipped_rand):
    b = _shake_shake_branch(layer_input, output_filters, stride, r_forward,
                            r_backward, weight_decay)
    branches.append(b)
  res = _shake_shake_skip_connection(layer_input, output_filters, stride,
                                     weight_decay)
  return res + tf.add_n(branches)
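# A minimal sketch of the branch-weight normalization used above, assuming
# TensorFlow is imported as `tf`. Dividing each uniform sample by their
# tf.add_n total makes the per-example branch weights a convex combination.
rand = [tf.random.uniform([4, 1, 1, 1]) for _ in range(2)]
total = tf.add_n(rand)
normalized = [r / total for r in rand]
# tf.add_n(normalized) is a tensor of ones: each example's weights sum to 1.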
def log_prob(x, y, z):
  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [
          self.axis_name, other_axis_name, [self.axis_name, other_axis_name]
      ])
  parts = sharded_log_prob_parts([x, y, z])
  return tf.add_n(parts)
def kinetic_energy_fn(*args, **kwargs):

  def one_component(x):
    return tf.reduce_sum(
        tf.square(x), axis=tf.range(chain_ndims, tf.rank(x)))

  return (tf.add_n(
      [one_component(x) for x in tf.nest.flatten([args, kwargs])]) / 2.), ()
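# A minimal numeric sketch of the kinetic energy above, assuming TensorFlow
# is imported as `tf` and `chain_ndims` == 1 (the first axis indexes chains):
# per-chain sums of squares over the remaining axes, halved.
chain_ndims = 1
state = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # 2 chains, event size 2.
ke = tf.reduce_sum(
    tf.square(state), axis=tf.range(chain_ndims, tf.rank(state))) / 2.
# ke == [2.5, 12.5]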
def test_compat_v1_layer(self):
  # Test the shim when using `compat.v1` layers

  class WrappedDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = core_layers.dense(
          inputs,
          self.units,
          name="dense_one",
          kernel_initializer=tf.compat.v1.ones_initializer(),
          kernel_regularizer="l2")
      with tf.compat.v1.variable_scope("nested_scope"):
        out = core_layers.dense(
            out,
            self.units,
            name="dense_two",
            kernel_initializer=tf.compat.v1.ones_initializer(),
            kernel_regularizer="l2")
      return out

  layer = WrappedDenseLayer(10)
  out = layer(tf.ones(shape=(5, 5)))
  weights = {x.name: x for x in layer.variables}

  # Verify the correct output, losses, + variables were made
  self.assertEqual(
      weights.keys(), {
          "dense_one/bias:0", "dense_one/kernel:0",
          "nested_scope/dense_two/bias:0", "nested_scope/dense_two/kernel:0"
      })
  self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 50)
  self.assertAllEqual(tf.add_n(layer.losses), 1.5)

  # Verify reuse by updating the variables then re-running
  weights["dense_one/kernel:0"].assign(tf.ones(shape=(5, 10)) * 2)
  weights["nested_scope/dense_two/kernel:0"].assign(
      tf.ones(shape=(10, 10)) * 2)
  out = layer(tf.ones(shape=(5, 5)))
  self.assertAllEqual(out, tf.ones(shape=(5, 10)) * 200)
  self.assertAllEqual(tf.add_n(layer.losses), 6)
def weight_decay_loss(self, trainable_variables):
  reg_variables = [
      v for v in trainable_variables
      if self._regularization_var_regex is None or
      re.match(self._regularization_var_regex, v.name)
  ]
  return self._l2_weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in reg_variables])
def _jd_log_prob_ratio(p, x, q, y):
  tf.nest.assert_same_structure(x, y)
  ps, _ = p.sample_distributions(value=x)
  qs, _ = q.sample_distributions(value=y)
  tf.nest.assert_same_structure(ps, qs)
  parts = []
  for p_, x_, q_, y_ in zip(ps, x, qs, y):
    parts.append(log_prob_ratio.log_prob_ratio(p_, x_, q_, y_))
  return tf.add_n(parts)
def _mean(self):
  distribution_means = [d.mean() for d in self.components]
  cat_probs = self._cat_probs(log_probs=False)
  cat_probs = [self._expand_to_event_rank(c_p) for c_p in cat_probs]
  partial_means = [
      c_p * m for (c_p, m) in zip(cat_probs, distribution_means)
  ]
  # These should all be the same shape by virtue of matching
  # batch_shape and event_shape.
  return tf.add_n(partial_means)
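# A minimal numeric sketch of the mixture mean above, assuming TensorFlow is
# imported as `tf`: E[X] = sum_k p_k * mu_k, the category-probability-
# weighted sum of component means.
probs = [tf.constant(0.25), tf.constant(0.75)]
means = [tf.constant(2.0), tf.constant(6.0)]
mixture_mean = tf.add_n([p * m for p, m in zip(probs, means)])
# mixture_mean == 5.0 (0.25 * 2.0 + 0.75 * 6.0).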
def nest_rms_norm(nest):
  """Computes root mean squared norm of nested structure of `Tensor`s.

  Args:
    nest: Possibly nested structure of `Tensor`s of which RMS norm is
      computed.

  Returns:
    norm: Scalar floating tensor equal to the RMS norm of `nest`.
  """
  sizes = tf.nest.map_structure(tf.size, nest)
  num_elements = tf.add_n(tf.nest.flatten(sizes))

  def averaged_sum_squares(input_tensor):
    num_elements_cast = tf.cast(
        num_elements, dtype=dtype_util.real_dtype(input_tensor.dtype))
    return tf.reduce_sum(abs_square(input_tensor)) / num_elements_cast

  squared_sums = tf.nest.map_structure(averaged_sum_squares, nest)
  norm = tf.math.sqrt(tf.add_n(tf.nest.flatten(squared_sums)))
  return norm
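# A minimal usage sketch for `nest_rms_norm`, assuming TensorFlow is imported
# as `tf` and `abs_square`/`dtype_util` are available as in the surrounding
# module. The RMS norm is sqrt(sum(x_i^2) / n) over all elements of the nest.
example_nest = {'a': tf.constant([3.0, 4.0]), 'b': tf.constant(0.0)}
example_norm = nest_rms_norm(example_nest)
# example_norm == sqrt((9 + 16 + 0) / 3) ~= 2.887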
def _ildj_ratio_chain(p, x, q, y):
  """Sum-of-diffs ILDJRatio for Chains."""
  if len(p.bijectors) != len(q.bijectors):
    raise ValueError('Mismatched lengths of bijectors: `p` has '
                     f'{len(p.bijectors)} but `q` has {len(q.bijectors)}.')
  ratios = []
  for p, q in zip(p.bijectors, q.bijectors):
    ratios.append(
        ldj_ratio.inverse_log_det_jacobian_ratio(p, x, q, y,
                                                 p.inverse_min_event_ndims))
    x, y = p.inverse(x), q.inverse(y)
  return tf.add_n(ratios)
def _jd_log_prob_ratio(p, x, q, y, name=None):
  """Implements `log_prob_ratio` for tfd.JointDistribution*."""
  with tf.name_scope(name or 'jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)
    ps, _ = p.sample_distributions(value=x, seed=dummy_seed())
    qs, _ = q.sample_distributions(value=y, seed=dummy_seed())
    tf.nest.assert_same_structure(ps, qs)
    parts = []
    for p_, x_, q_, y_ in zip(ps, x, qs, y):
      parts.append(log_prob_ratio.log_prob_ratio(p_, x_, q_, y_))
    return tf.add_n(parts)
def add_weight_decay(model):
  # Weight decay is taken care of by the optimizer in these cases, except
  # for the supervised head, which is added here.
  l2_losses = [
      tf.nn.l2_loss(v)
      for v in model.trainable_variables
      if 'head_supervised' in v.name and 'bias' not in v.name
  ]
  if l2_losses:
    return FLAGS.weight_decay * tf.add_n(l2_losses)
  else:
    return 0