def _build_update_ops_second_moment(self, mean, second_moment, is_training):

  def build_update_ops():
    update_mean_op = moving_averages.assign_moving_average(
        variable=self._moving_mean,
        value=mean,
        decay=self._decay_rate,
        name="update_moving_mean").op

    update_second_moment_op = moving_averages.assign_moving_average(
        variable=self._moving_second_moment,
        value=second_moment,
        decay=self._decay_rate,
        name="update_moving_second_moment").op

    return update_mean_op, update_second_moment_op

  def build_no_ops():
    return (tf.no_op(), tf.no_op())

  is_training_const = utils.constant_value(is_training)
  if is_training_const is None or is_training_const:
    update_mean_op, update_second_moment_op = utils.smart_cond(
        is_training,
        build_update_ops,
        build_no_ops,
    )
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_second_moment_op)
def test_value(self):
  fn1 = lambda: 'fn1'
  fn2 = lambda: 'fn2'
  expected = lambda v: 'fn1' if v else 'fn2'
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(tf.constant(v), fn1, fn2)
    self.assertEqual(o, expected(v))
def _build_update_ops_variance(self, mean, variance, is_training):

  def build_update_ops():
    update_mean_op = moving_averages.assign_moving_average(
        variable=self._moving_mean,
        value=mean,
        decay=self._decay_rate,
        name="update_moving_mean").op

    update_variance_op = moving_averages.assign_moving_average(
        variable=self._moving_variance,
        value=variance,
        decay=self._decay_rate,
        name="update_moving_variance").op

    return update_mean_op, update_variance_op

  def build_no_ops():
    return (tf.no_op(), tf.no_op())

  # Only make the ops if we know that `is_training=True`, or the
  # value of `is_training` is unknown.
  is_training_const = utils.constant_value(is_training)
  if is_training_const is None or is_training_const:
    update_mean_op, update_variance_op = utils.smart_cond(
        is_training,
        build_update_ops,
        build_no_ops,
    )

    # Every new connection creates a new op which adds its contribution
    # to the running average when run.
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_variance_op)
def test_value(self):
  fn1 = lambda: 'fn1'
  fn2 = lambda: 'fn2'
  expected = lambda v: 'fn1' if v else 'fn2'
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
    self.assertEqual(o, expected(v))
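# --- Illustration (not from the original sources): a minimal TF 1.x sketch of
# the smart_cond behaviour the tests above exercise. With a constant predicate
# the branch is chosen at graph-construction time; with a placeholder a
# tf.cond is built and the choice is deferred to run time. The import path is
# the tf.contrib one used in these snippets; treat this as an assumption-laden
# sketch rather than canonical usage.
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import utils

static_out = utils.smart_cond(tf.constant(True),
                              lambda: tf.constant('train'),
                              lambda: tf.constant('test'))

is_training = tf.placeholder(tf.bool, [])
dynamic_out = utils.smart_cond(is_training,
                               lambda: tf.constant('train'),
                               lambda: tf.constant('test'))

with tf.Session() as sess:
  print(sess.run(static_out))                                   # b'train'
  print(sess.run(dynamic_out, feed_dict={is_training: False}))  # b'test'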
def _update_renorm_variable(var, weight, value):
  """Updates a moving average and weight, returns the unbiased value."""
  value = array_ops.identity(value)

  def _do_update():
    # Update the variables without zero debiasing. The debiasing will be
    # accomplished by dividing the exponential moving average by the weight.
    # For example, after a single update, the moving average would be
    # (1 - decay) * value, and the weight will be 1 - decay, with their ratio
    # giving the value.
    # Make sure the weight is not updated until before r and d computation.
    with ops.control_dependencies([value]):
      weight_value = array_ops.constant(1., dtype=weight.dtype)
    new_var = moving_averages.assign_moving_average(
        var, value, renorm_params.renorm_momentum, zero_debias=False)
    new_weight = moving_averages.assign_moving_average(
        weight, weight_value, renorm_params.renorm_momentum, zero_debias=False)
    return new_var / new_weight

  def _fake_update():
    return array_ops.identity(var)

  return utils.smart_cond(training, _do_update, _fake_update)
def batch_norm(input, gammas, betas, epsilon, is_training):
  """BatchNorm implementation with sample-specific beta and gamma parameters,
  i.e. the shift and scaling parameters differ across a batch of examples.

  :param input: feature map input. 3-D vector (+ batch size)
  :param gammas: BN gamma parameters. 1-D vector (+ batch size)
  :param betas: BN beta parameters. 1-D vector (+ batch size)
  :param epsilon: BN epsilon for stability
  :param is_training: compute (True) or use (False) moving mean and variance
  :return: input after BN
  """
  assert len(input.get_shape()) == 4
  num_channels = int(input.get_shape()[3])

  # use the cbn_input scope to not initialize the variables with resnet values
  with tf.variable_scope("cbn_input"):
    moving_mean = tf.get_variable("moving_mean", [num_channels],
                                  dtype=tf.float32, trainable=False)
    moving_variance = tf.get_variable("moving_variance", [num_channels],
                                      dtype=tf.float32, trainable=False)

  def _training():
    """Internal function that delays updates to moving_vars if is_training."""
    mean, variance = tf.nn.moments(input, [0, 1, 2])
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, 0.99, zero_debias=True)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, 0.99, zero_debias=False)
    return mean, variance, update_moving_mean, update_moving_variance

  def _inference():
    return moving_mean, moving_variance, moving_mean, moving_variance

  # Collect mean/variance to prepare moving mean/variance.
  means, variances, update_mean, update_variance = tf_utils.smart_cond(
      is_training, _training, _inference)

  # Add moving mean/variance to the update_ops (cf. tensorflow batchnorm documentation).
  updates_collections = ops.GraphKeys.UPDATE_OPS
  ops.add_to_collections(updates_collections, update_mean)
  ops.add_to_collections(updates_collections, update_variance)

  # apply batch norm
  inv = gammas * tf.expand_dims(tf.rsqrt(variances + epsilon), 0)
  expanded_inv = tf.reshape(inv, [-1, 1, 1, num_channels])
  expanded_mean = tf.reshape(means, [-1, 1, 1, num_channels])
  expanded_betas = tf.reshape(betas, [-1, 1, 1, num_channels])
  out = expanded_inv * (input - expanded_mean) + expanded_betas

  return out
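# --- Illustration (not from the original source): a hypothetical call to the
# per-sample batch_norm above, assuming the snippet's own imports. Tensor names
# and shapes are assumptions; the point is that gammas/betas carry a leading
# batch dimension, matching the [-1, 1, 1, num_channels] reshapes.
feature_map = tf.placeholder(tf.float32, [None, 14, 14, 256])
gammas = tf.placeholder(tf.float32, [None, 256])
betas = tf.placeholder(tf.float32, [None, 256])
is_training = tf.placeholder(tf.bool, [])
normalized = batch_norm(feature_map, gammas, betas,
                        epsilon=1e-5, is_training=is_training)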
def dropout(x, keep_prob=0.5, is_training=False, name='drop'):
  with tf.variable_scope(name, 'dropout', [x]) as sc:
    x = utils.smart_cond(is_training,
                         lambda: tf.nn.dropout(x, keep_prob, name=name),
                         lambda: x,
                         name=name)
    return utils.collect_named_outputs(tf.GraphKeys.ACTIVATIONS,
                                       sc.original_name_scope, x)
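# --- Illustration (not from the original source): a hypothetical use of the
# dropout wrapper above with a boolean placeholder, so the same graph can be
# run in training and evaluation mode. Names are assumptions; the snippet's
# own imports (tf, utils) are assumed to be in scope.
x = tf.random_normal([8, 128])
is_training = tf.placeholder(tf.bool, [], name='is_training')
y = dropout(x, keep_prob=0.5, is_training=is_training, name='drop1')

with tf.Session() as sess:
  y_train = sess.run(y, feed_dict={is_training: True})   # dropout applied
  y_eval = sess.run(y, feed_dict={is_training: False})   # identity pass-through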
def test_constant(self):
  fn1 = lambda: constant_op.constant('fn1')
  fn2 = lambda: constant_op.constant('fn2')
  expected = lambda v: b'fn1' if v else b'fn2'
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
    with self.test_session():
      self.assertEqual(o.eval(), expected(v))
def test_tensors(self):
  fn1 = lambda: constant_op.constant(0) - constant_op.constant(1)
  fn2 = lambda: constant_op.constant(0) - constant_op.constant(2)
  expected = lambda v: -1 if v else -2
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
    with self.test_session():
      self.assertEqual(o.eval(), expected(v))
def test_constant(self):
  fn1 = lambda: tf.constant('fn1')
  fn2 = lambda: tf.constant('fn2')
  expected = lambda v: b'fn1' if v else b'fn2'
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(tf.constant(v), fn1, fn2)
    with self.test_session():
      self.assertEqual(o.eval(), expected(v))
def test_tensors(self):
  fn1 = lambda: tf.constant(0) - tf.constant(1)
  fn2 = lambda: tf.constant(0) - tf.constant(2)
  expected = lambda v: -1 if v else -2
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(tf.constant(v), fn1, fn2)
    with self.test_session():
      self.assertEqual(o.eval(), expected(v))
def _create_softmax_layer(self, proj, dec_outs, targets, weights,
                          scope_name, args):
  with tf.variable_scope(scope_name):
    w_t, b = proj

    # is_training = False
    def get_llh_test():
      dec_outs_flt = tf.reshape(dec_outs, [-1, args.num_units])
      logits_flt = tf.add(
          tf.matmul(dec_outs_flt, w_t, transpose_b=True), b[None, :])
      logits = tf.reshape(
          logits_flt, [tf.shape(dec_outs)[0], -1, args.vocab_size])

      llh_precise = tf.contrib.seq2seq.sequence_loss(
          logits=logits,
          targets=targets,
          weights=weights,
          average_across_timesteps=False,
          average_across_batch=False,
          softmax_loss_function=None)
      return llh_precise

    # sampled softmax loss, used in the training branch
    def sampled_loss(inputs, labels):
      labels = tf.reshape(labels, [-1, 1])
      # use 32bit float to avoid numerical instabilities
      #w_t = tf.transpose(w)
      local_w_t = tf.cast(w_t, tf.float32)
      local_b = tf.cast(b, tf.float32)
      local_inputs = tf.cast(inputs, tf.float32)
      return tf.nn.sampled_softmax_loss(
          weights=local_w_t,
          biases=local_b,
          inputs=local_inputs,
          labels=labels,
          num_sampled=args.num_samples,
          num_classes=args.vocab_size,
          partition_strategy="div")

    # is_training = True
    def get_llh_train():
      # if use sampled_softmax
      if (args.use_sampled_softmax and args.num_samples > 0
          and args.num_samples < args.vocab_size):
        llh_train = tf.contrib.seq2seq.sequence_loss(
            logits=dec_outs,
            targets=targets,
            weights=weights,
            average_across_timesteps=False,
            average_across_batch=False,
            softmax_loss_function=sampled_loss)
        self._logger.info('Use sampled softmax during training')
      else:
        llh_train = get_llh_test()
        self._logger.info('Use precise softmax during training')
      return llh_train

    loss = smart_cond(self.is_training_plh, get_llh_train, get_llh_test)
    return loss
def test_tensors(self):
  fn1 = lambda: tf.constant(0) - tf.constant(1)
  fn2 = lambda: tf.constant(0) - tf.constant(2)
  expected = lambda v: -1 if v else -2
  p = tf.placeholder(tf.bool, [])
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(p, fn1, fn2)
    with self.test_session():
      self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_variable(self):
  fn1 = lambda: variables.Variable('fn1')
  fn2 = lambda: variables.Variable('fn2')
  expected = lambda v: b'fn1' if v else b'fn2'
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      self.assertEqual(o.eval(), expected(v))
def test_value(self):
  fn1 = lambda: ops.convert_to_tensor('fn1')
  fn2 = lambda: ops.convert_to_tensor('fn2')
  expected = lambda v: b'fn1' if v else b'fn2'
  p = array_ops.placeholder(dtypes.bool, [])
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(p, fn1, fn2)
    with self.test_session():
      self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_variable(self):
  fn1 = lambda: tf.Variable('fn1')
  fn2 = lambda: tf.Variable('fn2')
  expected = lambda v: b'fn1' if v else b'fn2'
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(tf.constant(v), fn1, fn2)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      self.assertEqual(o.eval(), expected(v))
def test_constant(self):
  fn1 = lambda: tf.constant('fn1')
  fn2 = lambda: tf.constant('fn2')
  expected = lambda v: b'fn1' if v else b'fn2'
  p = tf.placeholder(tf.bool, [])
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(p, fn1, fn2)
    with self.test_session():
      self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_tensors(self):
  fn1 = lambda: constant_op.constant(0) - constant_op.constant(1)
  fn2 = lambda: constant_op.constant(0) - constant_op.constant(2)
  expected = lambda v: -1 if v else -2
  p = array_ops.placeholder(dtypes.bool, [])
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(p, fn1, fn2)
    with self.test_session():
      self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def _build_statistics_second_moment(self, input_batch,
                                    reduction_indices, use_batch_stats):
  self._moving_mean = tf.get_variable(
      "moving_mean",
      shape=self._mean_shape,
      collections=[
          tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
          tf.GraphKeys.VARIABLES,
      ],
      initializer=tf.zeros_initializer,
      trainable=False)

  self._moving_second_moment = tf.get_variable(
      "moving_second_moment",
      shape=self._mean_shape,
      collections=[
          tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
          tf.GraphKeys.VARIABLES,
      ],
      initializer=tf.ones_initializer(),
      trainable=False)

  self._moving_variance = tf.sub(self._moving_second_moment,
                                 tf.square(self._moving_mean),
                                 name="moving_variance")

  def build_batch_stats():
    shift = tf.add(self._moving_mean, 0)
    counts, shifted_sum_x, shifted_sum_x2, _ = tf.nn.sufficient_statistics(
        input_batch,
        reduction_indices,
        keep_dims=True,
        shift=shift,
        name="batch_norm_ss")

    mean, variance = tf.nn.normalize_moments(counts,
                                             shifted_sum_x,
                                             shifted_sum_x2,
                                             shift,
                                             name="normalize_moments")
    second_moment = variance + tf.square(mean)

    return mean, variance, second_moment

  def build_moving_stats():
    return (
        tf.identity(self._moving_mean),
        tf.identity(self._moving_variance),
        tf.identity(self._moving_second_moment),
    )

  mean, variance, second_moment = utils.smart_cond(
      use_batch_stats,
      build_batch_stats,
      build_moving_stats,
  )

  return mean, variance, second_moment
def test_variable(self):
  fn1 = lambda: variables.Variable('fn1')
  fn2 = lambda: variables.Variable('fn2')
  expected = lambda v: b'fn1' if v else b'fn2'
  p = array_ops.placeholder(dtypes.bool, [])
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(p, fn1, fn2)
    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_variable(self):
  fn1 = lambda: tf.Variable('fn1')
  fn2 = lambda: tf.Variable('fn2')
  expected = lambda v: b'fn1' if v else b'fn2'
  p = tf.placeholder(tf.bool, [])
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(p, fn1, fn2)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def batch_normalization(x):
  beta = tf.get_variable('beta', [x.get_shape()[-1]], dtype,
                         tf.zeros_initializer())
  gamma = tf.get_variable('gamma', [x.get_shape()[-1]], dtype,
                          tf.ones_initializer())
  mv_mean = tf.get_variable('mv_mean', [x.get_shape()[-1]], dtype=dtype,
                            initializer=tf.zeros_initializer(),
                            trainable=False)
  mv_var = tf.get_variable('mv_var', [x.get_shape()[-1]], dtype=dtype,
                           initializer=tf.ones_initializer(),
                           trainable=False)
  mean, variance = tf.nn.moments(x, [0], name='moments')
  tf.add_to_collection(
      tf.GraphKeys.UPDATE_OPS,
      assign_moving_average(mv_mean, mean, decay, zero_debias=True))
  tf.add_to_collection(
      tf.GraphKeys.UPDATE_OPS,
      assign_moving_average(mv_var, variance, decay, zero_debias=False))
  mean, variance = utils.smart_cond(is_training,
                                    lambda: (mean, variance),
                                    lambda: (mv_mean, mv_var))
  return tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-6)
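# --- Illustration (not from the original source): the standard TF 1.x pattern
# for running the UPDATE_OPS that the closure above registers, so the moving
# averages are refreshed on every training step. The toy loss and optimizer
# below are assumptions and do not exercise batch_normalization itself; only
# the control-dependency pattern is the point.
x = tf.random_normal([16, 8])
w = tf.get_variable('w', [8, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)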
def test_variable(self):
  fn1 = lambda: tf.Variable('fn1')
  fn2 = lambda: tf.Variable('fn2')
  expected = lambda v: b'fn1' if v else b'fn2'
  p = tf.placeholder(tf.bool, [])
  for v in [True, False, 1, 0]:
    o = utils.smart_cond(p, fn1, fn2)
    with self.test_session() as sess:
      sess.run(tf.initialize_all_variables())
      self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def model_setup(self, args):
  with tf.variable_scope(args.log_prefix):
    self.init_global_step()
    self._create_placeholders()
    self._logger.info("Created placeholders.")
    self._create_embedding_matrix(args)

    self.kl_w = tf.log(1. + tf.exp((self.global_step - args.klw_b) * args.klw_w))
    self.kl_w = tf.minimum(self.kl_w, 1.) / 100.0  # scale reweighted

    self.loss_l = self.get_loss_l(args)

    self.train_unlabel = tf.greater(self.global_step, args.num_pretrain_steps)
    self.loss_u = smart_cond(self.train_unlabel,
                             lambda: self.get_loss_u(args),
                             lambda: tf.constant(0.))
    tf.summary.scalar('train_unlabel', tf.to_int64(self.train_unlabel))
    tf.summary.scalar('loss_u', self.loss_u)

    self.loss = self.loss_l + self.loss_u
    tf.summary.scalar('loss', self.loss)

  with tf.variable_scope(args.log_prefix):
    # optimizer
    #embd_var = self.embedding_matrix
    #other_var_list = [v for v in tf.trainable_variables() if v.name != embd_var.name]
    learning_rate = tf.train.exponential_decay(args.learning_rate,
                                               self.global_step,
                                               args.decay_steps,
                                               args.decay_rate,
                                               staircase=True)
    self.train_op = self.training_op(
        self.loss,
        #tf.trainable_variables(),
        tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES,
                          scope=args.log_prefix),
        grad_clip=args.grad_clip,
        max_norm=args.max_norm,
        train_embd=True,
        learning_rate=args.learning_rate,)
    self._logger.info("Created SemiClassifier Model.")

    self._create_saver(args)
    self._logger.info('Created Saver')

    self.merged = tf.summary.merge_all()

    """ Create beam search layer
    self.beam_output_cur, self.beam_scores_cur = self._create_beam_search_layer(
        init_state=yz,
        dec_step_func=cur_dec_func,
        cell=cur_cell,
        embedding_matrix=self.embedding_matrix,
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,)
    self._logger.info('Created Beam Search Layer')
    """

  vt = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope=args.log_prefix)
  vs = tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES, scope=args.log_prefix)
  return vt, vs
def dropout_selu(x, rate, alpha=-1.7580993408473766, fixedPointMean=0.0,
                 fixedPointVar=1.0, noise_shape=None, seed=None, name=None,
                 training=False):
  """Dropout to a value with rescaling."""

  def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
    keep_prob = 1.0 - rate
    x = ops.convert_to_tensor(x, name="x")
    if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
      raise ValueError("keep_prob must be a scalar tensor or a float in the "
                       "range (0, 1], got %g" % keep_prob)
    keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob")
    keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
    alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar())

    if tensor_util.constant_value(keep_prob) == 1:
      return x

    noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
    random_tensor = keep_prob
    random_tensor += random_ops.random_uniform(noise_shape, seed=seed, dtype=x.dtype)
    binary_tensor = math_ops.floor(random_tensor)
    ret = x * binary_tensor + alpha * (1 - binary_tensor)

    a = math_ops.sqrt(fixedPointVar / (keep_prob * (
        (1 - keep_prob) * math_ops.pow(alpha - fixedPointMean, 2) + fixedPointVar)))

    b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
    ret = a * ret + b
    ret.set_shape(x.get_shape())
    return ret

  with ops.name_scope(name, "dropout", [x]) as name:
    return utils.smart_cond(
        training,
        lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name),
        lambda: array_ops.identity(x))
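# --- Illustration (not from the original source): a hypothetical call to
# dropout_selu above, assuming the snippet's own imports are in scope.
# `training` may be a Python bool or a bool tensor; smart_cond short-circuits
# to the identity branch when it is a constant False.
x_in = tf.random_normal([32, 64])
training_flag = tf.placeholder_with_default(False, [])
y_out = dropout_selu(x_in, rate=0.1, training=training_flag)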
def _get_elbo_label(self, inp, tgt, msk, y, args):
  """ Build encoder and decoders """
  xlen = tf.to_int32(tf.reduce_sum(msk, axis=1))
  enc_state = self._create_encoder(tgt, seqlen=xlen, scope_name='enc', args=args)

  with tf.variable_scope('latent'):
    y_enc_in = tf.contrib.layers.fully_connected(y, args.dim_z, scope='y_enc_in')
    pst_in = tf.concat([y_enc_in, enc_state], axis=1)
    mu_pst = tf.contrib.layers.fully_connected(pst_in, args.dim_z,
                                               tf.nn.tanh, scope='mu_posterior')
    logvar_pst = tf.contrib.layers.fully_connected(pst_in, args.dim_z,
                                                   tf.nn.tanh, scope='logvar_posterior')
    mu_pri = tf.zeros_like(mu_pst)
    logvar_pri = tf.ones_like(logvar_pst)

    dist_pri = tf.contrib.distributions.Normal(mu=mu_pri, sigma=tf.exp(logvar_pri))
    dist_pst = tf.contrib.distributions.Normal(mu=mu_pst, sigma=tf.exp(logvar_pst))
    kl_loss = tf.contrib.distributions.kl(dist_pst, dist_pri)
    kl_loss = tf.reduce_sum(kl_loss, axis=1)

    with st.value_type(st.SampleValue(stop_gradient=False)):
      z_st_pri = st.StochasticTensor(dist_pri, name='z_pri')
      z_st_pst = st.StochasticTensor(dist_pst, name='z_pst')
      z = smart_cond(self.is_training, lambda: z_st_pst, lambda: z_st_pri)
    z_ext = tf.contrib.layers.fully_connected(tf.reshape(z, [-1, args.dim_z]),
                                              args.num_units, scope='extend_z')

  xlen = tf.to_int32(tf.reduce_sum(msk, axis=1))
  outs, proj, dec_func, cell = self._create_decoder(
      inp,
      seqlen=xlen,
      label_oh=y,
      init_state=z_ext,
      scope_name='dec',
      args=args)

  # build loss layers
  recons_loss = self._create_softmax_layer(
      proj=proj,
      dec_outs=outs,
      targets=tgt,
      weights=msk,
      scope_name='loss',
      args=args)

  return recons_loss, kl_loss
def _build_update_ops(self, mean, variance, is_training):
  """Builds the moving average update ops when using moving variance.

  Args:
    mean: The mean value to update with.
    variance: The variance value to update with.
    is_training: Boolean Tensor to indicate if we're currently in
      training mode.

  Returns:
    Tuple of `(update_mean_op, update_variance_op)` when `is_training` is or
    could be `True`. Returns `None` when `is_training=False`.
  """

  def build_update_ops():
    """Builds the exponential moving average update ops."""

    update_mean_op = moving_averages.assign_moving_average(
        variable=self._moving_mean,
        value=mean,
        decay=self._decay_rate,
        zero_debias=False,
        name="update_moving_mean").op

    update_variance_op = moving_averages.assign_moving_average(
        variable=self._moving_variance,
        value=variance,
        decay=self._decay_rate,
        zero_debias=False,
        name="update_moving_variance").op

    return update_mean_op, update_variance_op

  def build_no_ops():
    return (tf.no_op(), tf.no_op())

  # Only make the ops if we know that `is_training=True`, or the value of
  # `is_training` is unknown.
  is_training_const = utils.constant_value(is_training)
  if is_training_const is None or is_training_const:
    update_mean_op, update_variance_op = utils.smart_cond(
        is_training,
        build_update_ops,
        build_no_ops,
    )
    return (update_mean_op, update_variance_op)
  else:
    return None
def _build_update_ops_second_moment(self, mean, second_moment, is_training):
  """Builds the moving average update ops when using the moving second moment.

  Args:
    mean: The mean value to update with.
    second_moment: The second_moment value to update with.
    is_training: Boolean Tensor to indicate if we're currently in
      training mode.
  """

  def build_update_ops():
    """Builds the exponential moving average update ops."""

    update_mean_op = moving_averages.assign_moving_average(
        variable=self._moving_mean,
        value=mean,
        decay=self._decay_rate,
        name="update_moving_mean").op

    update_second_moment_op = moving_averages.assign_moving_average(
        variable=self._moving_second_moment,
        value=second_moment,
        decay=self._decay_rate,
        name="update_moving_second_moment").op

    return update_mean_op, update_second_moment_op

  def build_no_ops():
    return (tf.no_op(), tf.no_op())

  # Only make the ops if we know that `is_training=True`, or the value of
  # `is_training` is unknown.
  is_training_const = utils.constant_value(is_training)
  if is_training_const is None or is_training_const:
    update_mean_op, update_second_moment_op = utils.smart_cond(
        is_training,
        build_update_ops,
        build_no_ops,
    )

    # Every new connection creates a new op which adds its contribution
    # to the running average when run.
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_second_moment_op)
def _fused_batch_norm_op(self, input_batch, mean, variance, use_batch_stats):
  """Creates a fused batch normalization op."""
  # The fused batch norm expects the mean, variance, gamma and beta
  # tensors to have dimension 1, so we flatten them to remove the
  # extra dimensions.
  gamma_flatten = tf.reshape(self._gamma, shape=(-1,))
  beta_flatten = tf.reshape(self._beta, shape=(-1,))
  flatten_mean = tf.reshape(mean, shape=(-1,))
  flatten_variance = tf.reshape(variance, shape=(-1,))
  use_batch_stats = tf.convert_to_tensor(use_batch_stats)

  common_args = {
      "scale": gamma_flatten,
      "offset": beta_flatten,
      "epsilon": self._eps,
      "data_format": self._infer_fused_data_format(input_batch),
      "name": "batch_norm"
  }

  def use_batch_stats_fused_batch_norm():
    return tf.nn.fused_batch_norm(input_batch, mean=None, variance=None,
                                  is_training=True, **common_args)

  def moving_average_fused_batch_norm():
    return tf.nn.fused_batch_norm(input_batch, mean=flatten_mean,
                                  variance=flatten_variance,
                                  is_training=False, **common_args)

  batch_norm_op, mean, variance = utils.smart_cond(
      use_batch_stats, use_batch_stats_fused_batch_norm,
      moving_average_fused_batch_norm)

  return batch_norm_op, mean, variance
def batch_norm_mine_old(inputs, decay=0.999, center=True, scale=False, epsilon=0.001, activation_fn=None, param_initializers=None, param_regularizers=None, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, batch_weights=None, fused=False, data_format=DATA_FORMAT_NHWC, zero_debias_moving_mean=False, scope=None, renorm=False, renorm_clipping=None, renorm_decay=0.99): """ This earlier version of my modification to batch norm uses current_mean and current_variance if is_training is True and moving_mean and moving_variance otherwise. This was leading a large divergence between the results depending upon whether the is_training set to True or not. I think ideally it should always use moving_mean and moving_variance. batch_norm_mine does this. Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. copy of tensorflow.contrib.layers Args: inputs: A tensor with 2 or more dimensions, where the first dimension has `batch_size`. The normalization is over all but the last dimension if `data_format` is `NHWC` and the second dimension if `data_format` is `NCHW`. decay: Decay for the moving average. Reasonable values for `decay` are close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. Lower `decay` value (recommend trying `decay`=0.9) if model experiences reasonably good training performance but poor validation and/or test performance. Try zero_debias_moving_mean=True for improved stability. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. epsilon: Small float added to variance to avoid dividing by zero. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. param_initializers: Optional initializers for beta, gamma, moving mean and moving variance. param_regularizers: Optional regularizer for beta and gamma. updates_collections: Collections to collect the update ops for computation. The updates_ops need to be executed with the train_op. If None, a control dependency would be added to make sure the updates are computed in place. is_training: Whether or not the layer is in training mode. In training mode it would accumulate the statistics of the moments into `moving_mean` and `moving_variance` using an exponential moving average with the given `decay`. When it is not in training mode then it would use the values of the `moving_mean` and the `moving_variance`. reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. variables_collections: Optional collections for the variables. outputs_collections: Collections to add the outputs. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). batch_weights: An optional tensor of shape `[batch_size]`, containing a frequency weight for each batch item. If present, then the batch normalization uses weighted mean and variance. (This can be used to correct for bias in training example selection.) fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise. data_format: A string. `NHWC` (default) and `NCHW` are supported. zero_debias_moving_mean: Use zero_debias for moving_mean. 
It creates a new pair of variables 'moving_mean/biased' and 'moving_mean/local_step'. scope: Optional scope for `variable_scope`. renorm: Whether to use Batch Renormalization (https://arxiv.org/abs/1702.03275). This adds extra variables during training. The inference is the same for either value of this parameter. renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to scalar `Tensors` used to clip the renorm correction. The correction `(r, d)` is used as `corrected_value = normalized_value * r + d`, with `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, dmax are set to inf, 0, inf, respectively. renorm_decay: Momentum used to update the moving means and standard deviations with renorm. Unlike `momentum`, this affects training and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `decay` is still applied to get the means and variances for inference. Returns: A `Tensor` representing the output of the operation. Raises: ValueError: If `batch_weights` is not None and `fused` is True. ValueError: If `param_regularizers` is not None and `fused` is True. ValueError: If `data_format` is neither `NHWC` nor `NCHW`. ValueError: If the rank of `inputs` is undefined. ValueError: If rank or channels dimension of `inputs` is undefined. """ if fused: if batch_weights is not None: raise ValueError('Weighted mean and variance is not currently ' 'supported for fused batch norm.') if param_regularizers is not None: raise ValueError('Regularizers are not currently ' 'supported for fused batch norm.') if renorm: raise ValueError('Renorm is not supported for fused batch norm.') return _fused_batch_norm( inputs, decay=decay, center=center, scale=scale, epsilon=epsilon, activation_fn=activation_fn, param_initializers=param_initializers, updates_collections=updates_collections, is_training=is_training, reuse=reuse, variables_collections=variables_collections, outputs_collections=outputs_collections, trainable=trainable, data_format=data_format, zero_debias_moving_mean=zero_debias_moving_mean, scope=scope) if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') layer_variable_getter = _build_variable_getter() with variable_scope.variable_scope( scope, 'BatchNorm', [inputs], reuse=reuse, custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) # Determine whether we can use the core layer class. if (batch_weights is None and updates_collections is ops.GraphKeys.UPDATE_OPS and not zero_debias_moving_mean): # Use the core layer class. 
axis = 1 if data_format == DATA_FORMAT_NCHW else -1 if not param_initializers: param_initializers = {} beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) moving_mean_initializer = param_initializers.get( 'moving_mean', init_ops.zeros_initializer()) moving_variance_initializer = param_initializers.get( 'moving_variance', init_ops.ones_initializer()) if not param_regularizers: param_regularizers = {} beta_regularizer = param_regularizers.get('beta') gamma_regularizer = param_regularizers.get('gamma') layer = normalization_layers.BatchNormalization( axis=axis, momentum=decay, epsilon=epsilon, center=center, scale=scale, beta_initializer=beta_initializer, gamma_initializer=gamma_initializer, moving_mean_initializer=moving_mean_initializer, moving_variance_initializer=moving_variance_initializer, beta_regularizer=beta_regularizer, gamma_regularizer=gamma_regularizer, trainable=trainable, renorm=renorm, renorm_clipping=renorm_clipping, renorm_momentum=renorm_decay, name=sc.name, _scope=sc, _reuse=reuse) outputs = layer.apply(inputs, training=is_training) # Add variables to collections. _add_variable_to_collections( layer.moving_mean, variables_collections, 'moving_mean') _add_variable_to_collections( layer.moving_variance, variables_collections, 'moving_variance') if layer.beta: _add_variable_to_collections(layer.beta, variables_collections, 'beta') if layer.gamma: _add_variable_to_collections( layer.gamma, variables_collections, 'gamma') if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.original_name_scope, outputs) # Not supported by layer class: batch_weights argument, # and custom updates_collections. In that case, use the legacy BN # implementation. # Custom updates collections are not supported because the update logic # is different in this case, in particular w.r.t. "forced updates" and # update op reuse. if renorm: raise ValueError('renorm is not supported with batch_weights, ' 'updates_collections or zero_debias_moving_mean') inputs_shape = inputs.get_shape() inputs_rank = inputs_shape.ndims if inputs_rank is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) dtype = inputs.dtype.base_dtype if batch_weights is not None: batch_weights = ops.convert_to_tensor(batch_weights) inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape()) # Reshape batch weight values so they broadcast across inputs. nshape = [-1] + [1 for _ in range(inputs_rank - 1)] batch_weights = array_ops.reshape(batch_weights, nshape) if data_format == DATA_FORMAT_NCHW: moments_axes = [0] + list(range(2, inputs_rank)) params_shape = inputs_shape[1:2] # For NCHW format, rather than relying on implicit broadcasting, we # explicitly reshape the params to params_shape_broadcast when computing # the moments and the batch normalization. params_shape_broadcast = list( [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)]) else: moments_axes = list(range(inputs_rank - 1)) params_shape = inputs_shape[-1:] params_shape_broadcast = None if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined channels dimension %s.' % ( inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. 
beta, gamma = None, None if not param_initializers: param_initializers = {} if center: beta_collections = utils.get_variable_collections(variables_collections, 'beta') beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) beta = variables.model_variable('beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable) if scale: gamma_collections = utils.get_variable_collections(variables_collections, 'gamma') gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. We disable variable partitioning while creating # them, because assign_moving_average is not yet supported for partitioned # variables. partitioner = variable_scope.get_variable_scope().partitioner try: variable_scope.get_variable_scope().set_partitioner(None) moving_mean_collections = utils.get_variable_collections( variables_collections, 'moving_mean') moving_mean_initializer = param_initializers.get( 'moving_mean', init_ops.zeros_initializer()) moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections( variables_collections, 'moving_variance') moving_variance_initializer = param_initializers.get( 'moving_variance', init_ops.ones_initializer()) moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) finally: variable_scope.get_variable_scope().set_partitioner(partitioner) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and # `needs_moments` will be true. is_training_value = utils.constant_value(is_training) need_moments = is_training_value is None or is_training_value if need_moments: # Calculate the moments based on the individual batch. 
if batch_weights is None: if data_format == DATA_FORMAT_NCHW: mean, _ = nn.moments(inputs, moments_axes, keep_dims=True) variance,_ = nn.moments( (inputs-moving_mean)**2, moments_axes, keep_dims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: mean, _ = nn.moments(inputs, moments_axes) variance, _ = nn.moments( (inputs-moving_mean)**2, moments_axes) else: if data_format == DATA_FORMAT_NCHW: mean, _ = nn.weighted_moments(inputs, moments_axes, batch_weights, keep_dims=True) variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes, batch_weights, keep_dims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: mean, _ = nn.weighted_moments(inputs, moments_axes, batch_weights) variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes, batch_weights) moving_vars_fn = lambda: (moving_mean, moving_variance) if updates_collections is None: def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) with ops.control_dependencies([update_moving_mean, update_moving_variance]): return array_ops.identity(mean), array_ops.identity(variance) mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn) else: def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance update_mean, update_variance = utils.smart_cond(is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) # Use computed moments during training and moving_vars otherwise. vars_fn = lambda: (mean, variance) mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn) else: mean, variance = moving_mean, moving_variance if data_format == DATA_FORMAT_NCHW: mean = array_ops.reshape(mean, params_shape_broadcast) variance = array_ops.reshape(variance, params_shape_broadcast) beta = array_ops.reshape(beta, params_shape_broadcast) if gamma is not None: gamma = array_ops.reshape(gamma, params_shape_broadcast) # Compute batch_normalization. outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.original_name_scope, outputs)
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               initializers={},
               updates_collections=None,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=False,
               scope=None):
  """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

  "Batch Normalization: Accelerating Deep Network Training by Reducing
  Internal Covariate Shift", Sergey Ioffe, Christian Szegedy

  Can be used as a normalizer function for conv2d and fully_connected.

  Note: when is_training is True the moving_mean and moving_variance need to
  be updated. By default the update_ops are placed in tf.GraphKeys.UPDATE_OPS,
  so they need to be added as a dependency to the train_op, for example:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      updates = tf.group(*update_ops)
      total_loss = control_flow_ops.with_dependencies([updates], total_loss)

  One can set updates_collections=None to force the updates in place, but that
  can have a speed penalty, especially in distributed settings.

  Args:
    inputs: a tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension.
    decay: decay for the moving average.
    center: If True, subtract `beta`. If False, `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: activation function, default set to None to skip it and
      maintain a linear activation.
    initializers: optional initializers for moving_mean and moving_variance.
    updates_collections: collections to collect the update ops for computation.
      The updates_ops need to be executed with the train_op. If None, a control
      dependency would be added to make sure the updates are computed in place.
    is_training: whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: if rank or last dimension of `inputs` is undefined.
  """
  with variable_scope.variable_scope(scope, 'BatchNorm', [inputs],
                                     reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    axis = list(range(inputs_rank - 1))
    params_shape = inputs_shape[-1:]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined last dimension %s.' %
                       (inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None

    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections.
    moving_mean_initializer = initializers.get('moving_mean',
                                               init_ops.zeros_initializer)
    moving_mean = variables.model_variable(
        'moving_mean',
        shape=params_shape,
        dtype=dtype,
        initializer=moving_mean_initializer,
        trainable=False)
    moving_variance_initializer = initializers.get('moving_variance',
                                                   init_ops.ones_initializer)
    moving_variance = variables.model_variable(
        'moving_variance',
        shape=params_shape,
        dtype=dtype,
        initializer=moving_variance_initializer,
        trainable=False)

    # If `is_training` doesn't have a constant value, because it is a `Tensor`,
    # a `Variable` or `Placeholder` then is_training_value will be None and
    # `need_moments` will be true.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      # Calculate the moments based on the individual batch.
      # Use a copy of moving_mean as a shift to compute more reliable moments.
      shift = math_ops.add(moving_mean, 0)
      mean, variance = nn.moments(inputs, axis, shift=shift)
      moving_vars_fn = lambda: (moving_mean, moving_variance)
      if updates_collections is None:

        def _force_updates():
          """Internal function forces updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay)
          with ops.control_dependencies(
              [update_moving_mean, update_moving_variance]):
            return array_ops.identity(mean), array_ops.identity(variance)

        mean, variance = utils.smart_cond(is_training, _force_updates,
                                          moving_vars_fn)
      else:

        def _delay_updates():
          """Internal function that delays updates to moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay)
          return update_moving_mean, update_moving_variance

        update_mean, update_variance = utils.smart_cond(
            is_training, _delay_updates, moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)
        # Use computed moments during training and moving_vars otherwise.
        vars_fn = lambda: (mean, variance)
        mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
    else:
      mean, variance = moving_mean, moving_variance

    # Compute batch_normalization.
    outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                     epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
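# --- Illustration (not from the original source): a hypothetical call to the
# batch_norm above with a dynamic `is_training` tensor, assuming the
# module-level imports the function itself relies on. Because
# updates_collections defaults to None here, the moving-average updates run as
# control dependencies of the output, so no extra update ops need be fetched.
inputs = tf.placeholder(tf.float32, [None, 64])
is_training = tf.placeholder(tf.bool, [])
outputs = batch_norm(inputs, is_training=is_training, scope='bn')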
def call(self, inputs, training=False):
  # First, compute the axes along which to reduce the mean / variance,
  # as well as the broadcast shape to be used for all parameters.
  input_shape = inputs.get_shape()
  ndim = len(input_shape)
  reduction_axes = list(range(len(input_shape)))
  del reduction_axes[self.axis]
  broadcast_shape = [1] * len(input_shape)
  broadcast_shape[self.axis] = input_shape[self.axis].value

  # Determines whether broadcasting is needed.
  needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1])

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = utils.constant_value(training)

  if needs_broadcasting:
    # In this case we must explicitly broadcast all parameters.
    if self.center:
      broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
    else:
      broadcast_beta = None
    if self.scale:
      broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
    else:
      broadcast_gamma = None

  # Determines moments.
  if training_value is not False:
    if needs_broadcasting:
      broadcast_mean, broadcast_variance = nn.moments(inputs, reduction_axes,
                                                      keep_dims=True)
      mean = array_ops.reshape(broadcast_mean, [-1])
      variance = array_ops.reshape(broadcast_variance, [-1])
    else:
      mean, variance = nn.moments(inputs, reduction_axes)

    # Prepare updates if necessary.
    if not self.updates:
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, mean, self.momentum, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, variance, self.momentum, zero_debias=False)
      # In the future this should be refactored into a self.add_update
      # method in order to allow for instance-based BN layer sharing
      # across unrelated input streams (e.g. like in Keras).
      self.updates.append(mean_update)
      self.updates.append(variance_update)

  # Normalize batch. We do this inside separate functions for training
  # and inference so as to avoid evaluating both branches.
  def normalize_in_test():
    if needs_broadcasting:
      broadcast_moving_mean = array_ops.reshape(self.moving_mean,
                                                broadcast_shape)
      broadcast_moving_variance = array_ops.reshape(self.moving_variance,
                                                    broadcast_shape)
    arg_mean = broadcast_moving_mean if needs_broadcasting else self.moving_mean
    arg_variance = (broadcast_moving_variance if needs_broadcasting
                    else self.moving_variance)
    arg_beta = broadcast_beta if needs_broadcasting else (
        self.beta if self.center else None)
    arg_gamma = broadcast_gamma if needs_broadcasting else (
        self.gamma if self.scale else None)
    if self.quantizer is None:
      return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                    arg_beta, arg_gamma, self.epsilon)
    else:
      return qbatch_normalization(inputs, arg_mean, arg_variance,
                                  arg_beta, arg_gamma, self.epsilon,
                                  self.quantizer)

  def normalize_in_training():
    arg_mean = broadcast_mean if needs_broadcasting else mean
    arg_variance = broadcast_variance if needs_broadcasting else variance
    arg_beta = broadcast_beta if needs_broadcasting else (
        self.beta if self.center else None)
    arg_gamma = broadcast_gamma if needs_broadcasting else (
        self.gamma if self.scale else None)
    if self.quantizer is None:
      return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                    arg_beta, arg_gamma, self.epsilon)
    else:
      return qbatch_normalization(inputs, arg_mean, arg_variance,
                                  arg_beta, arg_gamma, self.epsilon,
                                  self.quantizer)

  return utils.smart_cond(training, normalize_in_training, normalize_in_test)
def _get_elbo_label(self, inp, tgt, msk, label, args):
  """ Build encoder and decoders """
  xlen = tf.to_int32(tf.reduce_sum(msk, axis=1))
  enc_state = self._create_encoder(tgt, seqlen=xlen, scope_name='enc', args=args)
  enc_state = tf.nn.dropout(enc_state, self.keep_prob_plh)
  enc_state = tf.contrib.layers.fully_connected(
      enc_state, num_outputs=args.num_units, activation_fn=None, scope='x_to_a')
  enc_state = tf.contrib.layers.batch_norm(
      enc_state, center=True, scale=True,
      is_training=self.is_training_plh, scope='bn_a')
  enc_state = tf.tanh(enc_state)
  enc_state = tf.nn.dropout(enc_state, self.keep_prob_plh)

  label_oh = tf.gather(tf.eye(args.num_classes), label)

  with tf.variable_scope('latent'):
    y_enc_in = tf.contrib.layers.fully_connected(label_oh, args.dim_z,
                                                 scope='y_enc_in')
    y_enc_in = tf.nn.dropout(y_enc_in, self.keep_prob_plh)
    pst_in = tf.concat([y_enc_in, enc_state], axis=1)
    pst_in = tf.contrib.layers.fully_connected(pst_in, args.num_units, None,
                                               scope='pst_in_dense')
    pst_in = tf.contrib.layers.batch_norm(
        pst_in, center=True, scale=True,
        is_training=self.is_training_plh, scope='pst_in_bn')
    pst_in = tf.tanh(pst_in)
    pst_in = tf.nn.dropout(pst_in, self.keep_prob_plh)

    mu_pst = tf.contrib.layers.fully_connected(pst_in, args.dim_z,
                                               tf.nn.tanh, scope='mu_posterior')
    logvar_pst = tf.contrib.layers.fully_connected(
        pst_in, args.dim_z, tf.nn.tanh, scope='logvar_posterior')
    mu_pri = tf.zeros_like(mu_pst)
    logvar_pri = tf.ones_like(logvar_pst)

    dist_pri = tf.contrib.distributions.Normal(
        mu=mu_pri, sigma=tf.exp(logvar_pri))
    dist_pst = tf.contrib.distributions.Normal(
        mu=mu_pst, sigma=tf.exp(logvar_pst))
    kl_loss = tf.contrib.distributions.kl(dist_pst, dist_pri)
    kl_loss = tf.reduce_sum(kl_loss, axis=1)

    with st.value_type(st.SampleValue(stop_gradient=False)):
      z_st_pri = st.StochasticTensor(dist_pri, name='z_pri')
      z_st_pst = st.StochasticTensor(dist_pst, name='z_pst')
      z = smart_cond(self.is_training_plh, lambda: z_st_pst, lambda: z_st_pri)
    z_ext = tf.contrib.layers.fully_connected(
        tf.reshape(z, [-1, args.dim_z]), args.num_units, scope='extend_z')
    z_ext = tf.nn.dropout(z_ext, self.keep_prob_plh)

  yz = tf.concat([z_ext, label_oh], axis=1)
  yz = tf.contrib.layers.fully_connected(yz, args.num_units, None,
                                         scope='yz_dense')
  yz = tf.contrib.layers.batch_norm(yz, center=True, scale=True,
                                    is_training=self.is_training_plh,
                                    scope='yz_bn')
  yz = tf.tanh(yz)
  yz = tf.nn.dropout(yz, self.keep_prob_plh)

  xlen = tf.to_int32(tf.reduce_sum(msk, axis=1))
  outs, proj, dec_func, cell = self._create_decoder(
      inp,
      mask=msk,
      label_oh=label_oh,
      init_state=yz,  #tf.contrib.rnn.LSTMStateTuple(yz, yz),
      scope_name='dec',
      args=args)
  outs = tf.nn.dropout(outs, self.keep_prob_plh)

  # build loss layers
  recons_loss = self._create_softmax_layer(proj=proj,
                                           dec_outs=outs,
                                           targets=tgt,
                                           weights=msk,
                                           scope_name='loss',
                                           args=args)
  recons_loss = recons_loss * msk

  return recons_loss, kl_loss
def batch_norm_backbone(inputs, decay=0.999, center=True, scale=False, epsilon=0.001, activation_fn=None, param_initializers=None, param_regularizers=None, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, batch_weights=None, fused=None, data_format=DATA_FORMAT_NHWC, zero_debias_moving_mean=False, scope=None, renorm=False, renorm_clipping=None, renorm_decay=0.99, adjustment=None, tower_config=None): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" Sergey Ioffe, Christian Szegedy Can be used as a normalizer function for conv2d and fully_connected. The normalization is over all but the last dimension if `data_format` is `NHWC` and all but the second dimension if `data_format` is `NCHW`. In case of a 2D tensor this corresponds to the batch dimension, while in case of a 4D tensor this corresponds to the batch and space dimensions. Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they need to be added as a dependency to the `train_op`. For example: ```python update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss) ``` One can set updates_collections=None to force the updates in place, but that can have a speed penalty, especially in distributed settings. Args: inputs: A tensor with 2 or more dimensions, where the first dimension has `batch_size`. The normalization is over all but the last dimension if `data_format` is `NHWC` and the second dimension if `data_format` is `NCHW`. decay: Decay for the moving average. Reasonable values for `decay` are close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. Lower `decay` value (recommend trying `decay`=0.9) if model experiences reasonably good training performance but poor validation and/or test performance. Try zero_debias_moving_mean=True for improved stability. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. epsilon: Small float added to variance to avoid dividing by zero. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. param_initializers: Optional initializers for beta, gamma, moving mean and moving variance. param_regularizers: Optional regularizer for beta and gamma. updates_collections: Collections to collect the update ops for computation. The updates_ops need to be executed with the train_op. If None, a control dependency would be added to make sure the updates are computed in place. is_training: Whether or not the layer is in training mode. In training mode it would accumulate the statistics of the moments into `moving_mean` and `moving_variance` using an exponential moving average with the given `decay`. When it is not in training mode then it would use the values of the `moving_mean` and the `moving_variance`. reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. variables_collections: Optional collections for the variables. outputs_collections: Collections to add the outputs. 
trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). batch_weights: An optional tensor of shape `[batch_size]`, containing a frequency weight for each batch item. If present, then the batch normalization uses weighted mean and variance. (This can be used to correct for bias in training example selection.) fused: if `None` or `True`, use a faster, fused implementation if possible. If `False`, use the system recommended implementation. data_format: A string. `NHWC` (default) and `NCHW` are supported. zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new pair of variables 'moving_mean/biased' and 'moving_mean/local_step'. scope: Optional scope for `variable_scope`. renorm: Whether to use Batch Renormalization (https://arxiv.org/abs/1702.03275). This adds extra variables during training. The inference is the same for either value of this parameter. renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to scalar `Tensors` used to clip the renorm correction. The correction `(r, d)` is used as `corrected_value = normalized_value * r + d`, with `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, dmax are set to inf, 0, inf, respectively. renorm_decay: Momentum used to update the moving means and standard deviations with renorm. Unlike `momentum`, this affects training and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `decay` is still applied to get the means and variances for inference. adjustment: A function taking the `Tensor` containing the (dynamic) shape of the input tensor and returning a pair (scale, bias) to apply to the normalized values (before gamma and beta), only during training. For example, `adjustment = lambda shape: ( tf.random_uniform(shape[-1:], 0.93, 1.07), tf.random_uniform(shape[-1:], -0.1, 0.1))` will scale the normalized value by up to 7% up or down, then shift the result by up to 0.1 (with independent scaling and bias for each feature but shared across all examples), and finally apply gamma and/or beta. If `None`, no adjustment is applied. Returns: A `Tensor` representing the output of the operation. Raises: ValueError: If `data_format` is neither `NHWC` nor `NCHW`. ValueError: If the rank of `inputs` is undefined. ValueError: If rank or channels dimension of `inputs` is undefined. """ # if fused is None: # fused = True # Only use _fused_batch_norm if all of the following three # conditions are true: # (1) fused is set True; # (2) it is possible to use (currently it doesn't support batch weights, # renorm, and the case when rank is neither 2 nor 4); # (3) it is used with zero_debias_moving_mean, or an input shape of rank 2, # or non-default updates_collections (not implemented in # normalization_layers.BatchNormalization yet); otherwise use the fused # implementation in normalization_layers.BatchNormalization. 
# inputs = ops.convert_to_tensor(inputs) # rank = inputs.get_shape().ndims # possible_to_fuse = ( # batch_weights is None and not renorm and rank in [2, 4] and # adjustment is None) # if fused and possible_to_fuse and ( # zero_debias_moving_mean or rank == 2 or # updates_collections is not ops.GraphKeys.UPDATE_OPS): # return _fused_batch_norm( # inputs, # decay=decay, # center=center, # scale=scale, # epsilon=epsilon, # activation_fn=activation_fn, # param_initializers=param_initializers, # param_regularizers=param_regularizers, # updates_collections=updates_collections, # is_training=is_training, # reuse=reuse, # variables_collections=variables_collections, # outputs_collections=outputs_collections, # trainable=trainable, # data_format=data_format, # zero_debias_moving_mean=zero_debias_moving_mean, # scope=scope) if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') layer_variable_getter = _build_variable_getter() with variable_scope.variable_scope( scope, 'BatchNorm', [inputs], reuse=reuse, custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) # # Determine whether we can use the core layer class. # if (batch_weights is None and # updates_collections is ops.GraphKeys.UPDATE_OPS and # not zero_debias_moving_mean): # print("F**K !!!!") # # Use the core layer class. # axis = 1 if data_format == DATA_FORMAT_NCHW else -1 # if not param_initializers: # param_initializers = {} # beta_initializer = param_initializers.get('beta', # init_ops.zeros_initializer()) # gamma_initializer = param_initializers.get('gamma', # init_ops.ones_initializer()) # moving_mean_initializer = param_initializers.get( # 'moving_mean', init_ops.zeros_initializer()) # moving_variance_initializer = param_initializers.get( # 'moving_variance', init_ops.ones_initializer()) # if not param_regularizers: # param_regularizers = {} # beta_regularizer = param_regularizers.get('beta') # gamma_regularizer = param_regularizers.get('gamma') # layer = normalization_layers.BatchNormalization( # axis=axis, # momentum=decay, # epsilon=epsilon, # center=center, # scale=scale, # beta_initializer=beta_initializer, # gamma_initializer=gamma_initializer, # moving_mean_initializer=moving_mean_initializer, # moving_variance_initializer=moving_variance_initializer, # beta_regularizer=beta_regularizer, # gamma_regularizer=gamma_regularizer, # trainable=trainable, # renorm=renorm, # renorm_clipping=renorm_clipping, # renorm_momentum=renorm_decay, # adjustment=adjustment, # name=sc.name, # _scope=sc, # _reuse=reuse, # fused=fused) # outputs = layer.apply(inputs, training=is_training) # # # Add variables to collections. # _add_variable_to_collections(layer.moving_mean, variables_collections, # 'moving_mean') # _add_variable_to_collections(layer.moving_variance, variables_collections, # 'moving_variance') # if layer.beta is not None: # _add_variable_to_collections(layer.beta, variables_collections, 'beta') # if layer.gamma is not None: # _add_variable_to_collections(layer.gamma, variables_collections, # 'gamma') # # if activation_fn is not None: # outputs = activation_fn(outputs) # return utils.collect_named_outputs(outputs_collections, sc.name, outputs) # Not supported by layer class: batch_weights argument, # and custom updates_collections. In that case, use the legacy BN # implementation. # Custom updates collections are not supported because the update logic # is different in this case, in particular w.r.t. "forced updates" and # update op reuse. 
if renorm: raise ValueError('renorm is not supported with batch_weights, ' 'updates_collections or zero_debias_moving_mean') inputs_shape = inputs.get_shape() inputs_rank = inputs_shape.ndims if inputs_rank is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) dtype = inputs.dtype.base_dtype if batch_weights is not None: batch_weights = ops.convert_to_tensor(batch_weights) inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape()) # Reshape batch weight values so they broadcast across inputs. nshape = [-1] + [1 for _ in range(inputs_rank - 1)] batch_weights = array_ops.reshape(batch_weights, nshape) if data_format == DATA_FORMAT_NCHW: moments_axes = [0] + list(range(2, inputs_rank)) params_shape = inputs_shape[1:2] # For NCHW format, rather than relying on implicit broadcasting, we # explicitly reshape the params to params_shape_broadcast when computing # the moments and the batch normalization. params_shape_broadcast = list( [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)]) else: moments_axes = list(range(inputs_rank - 1)) params_shape = inputs_shape[-1:] params_shape_broadcast = None if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined channels dimension %s.' % (inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None if not param_initializers: param_initializers = {} if center: beta_collections = utils.get_variable_collections(variables_collections, 'beta') beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) beta = variables.model_variable( 'beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable) if scale: gamma_collections = utils.get_variable_collections( variables_collections, 'gamma') gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) gamma = variables.model_variable( 'gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. We disable variable partitioning while creating # them, because assign_moving_average is not yet supported for partitioned # variables (this needs to be handled carefully, as it may break # the checkpoint backward compatibility). with variable_scope.variable_scope( variable_scope.get_variable_scope()) as local_scope: local_scope.set_partitioner(None) moving_mean_collections = utils.get_variable_collections( variables_collections, 'moving_mean') moving_mean_initializer = param_initializers.get( 'moving_mean', init_ops.zeros_initializer()) moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections( variables_collections, 'moving_variance') moving_variance_initializer = param_initializers.get( 'moving_variance', init_ops.ones_initializer()) moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and # `needs_moments` will be true. 
is_training_value = utils.constant_value(is_training) need_moments = is_training_value is None or is_training_value if need_moments: # Calculate the moments based on the individual batch. if batch_weights is None: if data_format == DATA_FORMAT_NCHW: mean, variance = moments(inputs, moments_axes, tower_config=tower_config, keep_dims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: mean, variance = moments(inputs, moments_axes, tower_config=tower_config) else: if data_format == DATA_FORMAT_NCHW: mean, variance = weighted_moments( inputs, moments_axes, batch_weights, tower_config, keep_dims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: mean, variance = weighted_moments(inputs, moments_axes, batch_weights, tower_config=tower_config) moving_vars_fn = lambda: (moving_mean, moving_variance) if updates_collections is None: def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) with ops.control_dependencies( [update_moving_mean, update_moving_variance]): return array_ops.identity(mean), array_ops.identity(variance) mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn) else: def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance update_mean, update_variance = utils.smart_cond( is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) # Use computed moments during training and moving_vars otherwise. vars_fn = lambda: (mean, variance) mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn) else: mean, variance = moving_mean, moving_variance if data_format == DATA_FORMAT_NCHW: mean = array_ops.reshape(mean, params_shape_broadcast) variance = array_ops.reshape(variance, params_shape_broadcast) if beta is not None: beta = array_ops.reshape(beta, params_shape_broadcast) if gamma is not None: gamma = array_ops.reshape(gamma, params_shape_broadcast) # Compute batch_normalization. outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
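# --- Usage sketch (not part of the library code above) ----------------------
# With the fused path commented out, every call falls through to the legacy
# branch above, and when `updates_collections` is left at its default the
# moving-average updates are only collected into tf.GraphKeys.UPDATE_OPS.
# A minimal, assumed wiring of those updates into the train op; the optimizer
# and learning rate here are illustrative choices, not taken from this module.
import tensorflow as tf

def build_train_op(loss, learning_rate=1e-3):
  """Makes the train op depend on the collected moving-average updates."""
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  optimizer = tf.train.AdamOptimizer(learning_rate)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
  return train_op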
def _build_statistics_second_moment(self, input_batch, reduction_indices, use_batch_stats): """Builds the statistics part of the graph when using moving second moment. Args: input_batch: Input batch Tensor. reduction_indices: Indices of `input_batch` to reduce over. use_batch_stats: Boolean to indicate if batch statistics should be calculated, otherwise moving averages are returned. Returns: Tuple of (mean, variance, second_moment). """ # Set up our moving statistics. When connecting in parallel, this is shared. self._moving_mean = tf.get_variable( "moving_mean", shape=self._mean_shape, collections=[tf.GraphKeys.MOVING_AVERAGE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES], initializer=tf.zeros_initializer(), trainable=False) self._moving_second_moment = tf.get_variable( "moving_second_moment", shape=self._mean_shape, collections=[tf.GraphKeys.MOVING_AVERAGE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES], initializer=tf.ones_initializer(), trainable=False) self._moving_variance = tf.subtract(self._moving_second_moment, tf.square(self._moving_mean), name="moving_variance") def build_batch_stats(): """Builds the batch statistics calculation ops.""" # Copy for better stability. # We use the moving mean as an estimate of the mean in order to perform # a more numerically stable calculation of the batch mean. shift = tf.add(self._moving_mean, 0) counts, shifted_sum_x, shifted_sum_x2, _ = tf.nn.sufficient_statistics( input_batch, reduction_indices, keep_dims=True, shift=shift, name="batch_norm_ss") mean, variance = tf.nn.normalize_moments(counts, shifted_sum_x, shifted_sum_x2, shift, name="normalize_moments") second_moment = variance + tf.square(mean) return mean, variance, second_moment def build_moving_stats(): return ( tf.identity(self._moving_mean), tf.identity(self._moving_variance), tf.identity(self._moving_second_moment), ) mean, variance, second_moment = utils.smart_cond( use_batch_stats, build_batch_stats, build_moving_stats, ) return mean, variance, second_moment
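# --- Sanity check (standalone, not part of the module above) ----------------
# The second-moment variant relies on the identity Var[x] = E[x^2] - (E[x])^2,
# both to recover `moving_variance` from the two tracked statistics and to
# form `second_moment = variance + mean**2` in build_batch_stats. A quick
# NumPy check of that identity:
import numpy as np

x = np.random.randn(1000, 8).astype(np.float32)
mean = x.mean(axis=0)
variance = x.var(axis=0)
second_moment = (x ** 2).mean(axis=0)

# E[x^2] - (E[x])^2 == Var[x], the direction used for moving_variance.
np.testing.assert_allclose(second_moment - mean ** 2, variance,
                           rtol=1e-3, atol=1e-5)
# Var[x] + (E[x])^2 == E[x^2], the direction used in build_batch_stats.
np.testing.assert_allclose(variance + mean ** 2, second_moment,
                           rtol=1e-3, atol=1e-5)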
def fused_batch_norm(inputs,
                     renorm=False,
                     RMAX=None,
                     DMAX=None,
                     decay=0.999,
                     center=True,
                     scale=False,
                     epsilon=0.001,
                     activation_fn=None,
                     param_initializers=None,
                     is_training=True,
                     reuse=None,
                     variables_collections=None,
                     outputs_collections=None,
                     trainable=True,
                     data_format=DATA_FORMAT_NHWC,
                     zero_debias_moving_mean=False,
                     scope=None):
  """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

  "Batch Normalization: Accelerating Deep Network Training by Reducing
  Internal Covariate Shift", Sergey Ioffe, Christian Szegedy.

  Can be used as a normalizer function for conv2d and fully_connected.

  Note: when `is_training` is True, `moving_mean` and `moving_variance` need
  to be updated. The update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so
  they need to be added as a dependency to the `train_op`. For example:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      updates = tf.group(*update_ops)
      total_loss = control_flow_ops.with_dependencies([updates], total_loss)

  Args:
    inputs: a tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and over all but the second dimension if
      `data_format` is `NCHW`.
    renorm: whether to apply the Batch Renormalization correction
      (https://arxiv.org/abs/1702.03275) during training. Inference is
      unchanged.
    RMAX: scalar (or scalar `Tensor`) bound on the renorm scale correction
      `r`, which is clipped to `[1/RMAX, RMAX]`. Only used when `renorm` is
      True.
    DMAX: scalar (or scalar `Tensor`) bound on the renorm shift correction
      `d`, which is clipped to `[-DMAX, DMAX]`. Only used when `renorm` is
      True.
    decay: decay for the moving average. Reasonable values for `decay` are
      close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9,
      etc. Lower the `decay` value (try `decay=0.9`) if the model shows good
      training performance but poor validation and/or test performance.
    center: If True, add offset of `beta` to the normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: optional initializers for beta, gamma, moving mean and
      moving variance.
    is_training: whether or not the layer is in training mode. In training
      mode it accumulates the statistics of the moments into `moving_mean`
      and `moving_variance` using an exponential moving average with the
      given `decay`. When it is not in training mode it uses the values of
      `moving_mean` and `moving_variance`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    zero_debias_moving_mean: Use zero_debias for moving_mean.
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: if `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: if the rank of `inputs` is undefined.
    ValueError: if the rank of `inputs` is neither 2 nor 4.
    ValueError: if rank or `C` dimension of `inputs` is undefined.
""" if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') with tf.variable_scope( scope, 'BatchNorm', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) original_shape = inputs.get_shape() original_rank = original_shape.ndims if original_rank is None: raise ValueError('Inputs %s has undefined rank' % inputs.name) elif original_rank not in [2, 4]: raise ValueError('Inputs %s has unsupported rank.' ' Expected 2 or 4 but got %d' % ( inputs.name, original_rank)) if original_rank == 2: channels = inputs.get_shape()[-1].value if channels is None: raise ValueError('`C` dimension must be known but is None') new_shape = [-1, 1, 1, channels] if data_format == DATA_FORMAT_NCHW: new_shape = [-1, channels, 1, 1] inputs = array_ops.reshape(inputs, new_shape) inputs_shape = inputs.get_shape() dtype = inputs.dtype.base_dtype if data_format == DATA_FORMAT_NHWC: params_shape = inputs_shape[-1:] else: params_shape = inputs_shape[1:2] if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined `C` dimension %s.' % (inputs.name, params_shape)) if not param_initializers: param_initializers = {} # Allocate parameters for the beta and gamma of the normalization. trainable_beta = trainable and center if trainable_beta: beta_collections = utils.get_variable_collections(variables_collections, 'beta') beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) real_beta = variables.model_variable( 'beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable_beta) beta = tf.zeros(params_shape, name='fakebeta') else: real_beta = tf.zeros(params_shape, name='beta') beta = tf.zeros(params_shape, name='fakebeta') trainable_gamma = trainable and scale if trainable_gamma: gamma_collections = utils.get_variable_collections(variables_collections, 'gamma') gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) gamma = variables.model_variable( 'gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable_gamma) else: gamma = tf.ones(params_shape, name='gamma') # Create moving_mean and moving_variance variables and add them to the # appropiate collections. 
moving_mean_collections = utils.get_variable_collections( variables_collections, 'moving_mean') moving_mean_initializer = param_initializers.get( 'moving_mean', init_ops.zeros_initializer()) moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections( variables_collections, 'moving_variance') moving_variance_initializer = param_initializers.get( 'moving_variance', init_ops.ones_initializer()) moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) def _fused_batch_norm_training(): outputs, mean, variance = nn.fused_batch_norm( inputs, gamma, beta, epsilon=epsilon, data_format=data_format) if renorm: moving_inv = math_ops.rsqrt(moving_variance + epsilon) r = tf.stop_gradient(tf.clip_by_value(tf.sqrt(variance + epsilon) * moving_inv, 1/RMAX, RMAX)) d = tf.stop_gradient(tf.clip_by_value((mean - moving_mean) * moving_inv, -DMAX, DMAX)) outputs = outputs * r + d return outputs, mean, variance def _fused_batch_norm_inference(): return nn.fused_batch_norm( inputs, gamma, beta, mean=moving_mean, variance=moving_variance, epsilon=epsilon, is_training=False, data_format=data_format) outputs, mean, variance = utils.smart_cond(is_training, _fused_batch_norm_training, _fused_batch_norm_inference) outputs = tf.nn.bias_add(outputs, real_beta) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and # `need_updates` will be true. is_training_value = utils.constant_value(is_training) need_updates = is_training_value is None or is_training_value if need_updates: moving_vars_fn = lambda: (moving_mean, moving_variance) def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance update_mean, update_variance = utils.smart_cond(is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_mean) ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_variance) outputs.set_shape(inputs_shape) if original_shape.ndims == 2: outputs = array_ops.reshape(outputs, original_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.original_name_scope, outputs)
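# --- Schedule sketch (assumed, not taken from this module) ------------------
# With `renorm=True`, RMAX and DMAX bound the corrections: `r` is clipped to
# [1/RMAX, RMAX] and `d` to [-DMAX, DMAX]. The Batch Renormalization paper
# ramps these bounds up during training instead of fixing them from step 0;
# the ramp lengths and targets below (rmax -> 3, dmax -> 5) are illustrative
# defaults, not values prescribed by the code above.
import tensorflow as tf

def renorm_bounds(global_step, ramp_start=5000, ramp_length=40000,
                  rmax_target=3.0, dmax_target=5.0):
  """Linear ramp for the renorm clipping bounds (illustrative constants)."""
  step = tf.cast(global_step, tf.float32)
  progress = tf.clip_by_value((step - ramp_start) / float(ramp_length), 0., 1.)
  rmax = 1.0 + (rmax_target - 1.0) * progress  # starts at 1: no scale correction
  dmax = dmax_target * progress                # starts at 0: no shift correction
  return rmax, dmax

# Usage sketch:
#   global_step = tf.train.get_or_create_global_step()
#   rmax, dmax = renorm_bounds(global_step)
#   outputs = fused_batch_norm(inputs, renorm=True, RMAX=rmax, DMAX=dmax,
#                              is_training=is_training)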
def spectral_normalization(weights, num_iterations=1, epsilon=1e-12, u_initializer=tf.random_normal_initializer(), updates_collections=tf.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, scope=None): with tf.variable_scope(scope, 'SpectralNorm', [weights], reuse=reuse) as sc: weights = tf.convert_to_tensor(weights) dtype = weights.dtype.base_dtype w_t = tf.reshape(weights, [-1, weights.shape.as_list()[-1]]) w = tf.transpose(w_t) m, n = w.shape.as_list() u_collections = utils.get_variable_collections(variables_collections, 'u') u = tf.get_variable( "u", shape=[m, 1], dtype=dtype, initializer=u_initializer, trainable=False, collections=u_collections, ) sigma_collections = utils.get_variable_collections( variables_collections, 'sigma') sigma = tf.get_variable('sigma', shape=[], dtype=dtype, initializer=tf.zeros_initializer(), trainable=False, collections=sigma_collections) def _power_iteration(i, u, v): v_ = tf.nn.l2_normalize(tf.matmul(w_t, u), epsilon=epsilon) u_ = tf.nn.l2_normalize(tf.matmul(w, v_), epsilon=epsilon) return i + 1, u_, v_ _, u_, v_ = tf.while_loop(cond=lambda i, _1, _2: i < num_iterations, body=_power_iteration, loop_vars=[ tf.constant(0), u, tf.zeros(shape=[n, 1], dtype=tf.float32) ]) u_ = tf.stop_gradient(u_) v_ = tf.stop_gradient(v_) sigma_ = tf.matmul(tf.transpose(u_), tf.matmul(w, v_))[0, 0] update_u = u.assign(u_) update_sigma = sigma.assign(sigma_) if updates_collections is None: def _force_update(): with tf.control_dependencies([update_u, update_sigma]): return tf.identity(sigma_) sigma_ = utils.smart_cond(is_training, _force_update, lambda: sigma) weights_sn = weights / sigma_ else: sigma_ = utils.smart_cond(is_training, lambda: sigma_, lambda: sigma) weights_sn = weights / sigma_ tf.add_to_collections(updates_collections, update_u) tf.add_to_collections(updates_collections, update_sigma) return utils.collect_named_outputs(outputs_collections, sc.name, weights_sn)
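# --- Usage sketch (illustrative, not part of the module above) --------------
# Typical use of `spectral_normalization`: normalize a layer's kernel before
# the matmul/convolution so the applied weight has spectral norm close to 1.
# The layer below is an assumed example; like batch norm, the `u`/`sigma`
# updates land in tf.GraphKeys.UPDATE_OPS by default and need the same
# train-op dependency shown earlier.
import tensorflow as tf

def sn_dense(x, units, is_training, scope="sn_dense"):
  """Dense layer with a spectrally normalized kernel (illustrative)."""
  with tf.variable_scope(scope):
    in_dim = x.get_shape().as_list()[-1]
    w = tf.get_variable("kernel", shape=[in_dim, units],
                        initializer=tf.glorot_uniform_initializer())
    b = tf.get_variable("bias", shape=[units],
                        initializer=tf.zeros_initializer())
    # Uses the spectral_normalization() defined above; one power iteration
    # per graph construction.
    w_sn = spectral_normalization(w, num_iterations=1, is_training=is_training)
    return tf.matmul(x, w_sn) + b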
def _build_statistics(self, input_batch, axis, use_batch_stats, dtype): """Builds the statistics part of the graph when using moving variance. Args: input_batch: Input batch Tensor. axis: Indices of `input_batch` to reduce over. use_batch_stats: Boolean to indicate if batch statistics should be calculated, otherwise moving averages are returned. dtype: TensorFlow datatype to use for the moving mean and variance. Returns: Tuple of (mean, variance). """ # Set up our moving statistics. When connecting in parallel, this is shared. if self.MOVING_MEAN not in self._initializers: self._initializers[self.MOVING_MEAN] = create_mean_initializer() self._moving_mean = tf.get_variable( "moving_mean", dtype=dtype, shape=self._mean_shape, collections=[ tf.GraphKeys.MOVING_AVERAGE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES, ], initializer=self._initializers[self.MOVING_MEAN], trainable=False) if self.MOVING_VARIANCE not in self._initializers: self._initializers[ self.MOVING_VARIANCE] = create_variance_initializer() self._moving_variance = tf.get_variable( "moving_variance", dtype=dtype, shape=self._mean_shape, collections=[ tf.GraphKeys.MOVING_AVERAGE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES, ], initializer=self._initializers[self.MOVING_VARIANCE], trainable=False) def build_batch_stats(): """Builds the batch statistics calculation ops.""" mean, variance = tf.nn.moments(input_batch, axis, keep_dims=True, name="normalize_moments") return mean, variance def build_moving_stats(): return ( tf.identity(self._moving_mean), tf.identity(self._moving_variance), ) mean, variance = utils.smart_cond( use_batch_stats, build_batch_stats, build_moving_stats, ) return mean, variance
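# --- Behaviour sketch (approximation, not the library source) ---------------
# All of the builders above funnel through `utils.smart_cond`, which only
# emits a tf.cond when the predicate is not statically known; for a Python
# bool or a constant tensor it builds just the branch that will actually run,
# which is why the moving-average update ops can be skipped entirely when
# the training flag is a constant False. A rough sketch of that dispatch:
import tensorflow as tf
from tensorflow.python.framework import tensor_util

def smart_cond_sketch(pred, true_fn, false_fn):
  """Approximation of the smart_cond behaviour relied on above."""
  if isinstance(pred, bool):
    pred_value = pred
  else:
    # None when the tensor's value is unknown at graph-construction time.
    pred_value = tensor_util.constant_value(tf.convert_to_tensor(pred))
  if pred_value is None:
    # Dynamic predicate: build both branches, select at runtime.
    return tf.cond(pred, true_fn, false_fn)
  # Static predicate: build only the branch that will run.
  return true_fn() if pred_value else false_fn()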