Example #1
    def _build_update_ops_second_moment(self, mean, second_moment,
                                        is_training):
        def build_update_ops():
            update_mean_op = moving_averages.assign_moving_average(
                variable=self._moving_mean,
                value=mean,
                decay=self._decay_rate,
                name="update_moving_mean").op

            update_second_moment_op = moving_averages.assign_moving_average(
                variable=self._moving_second_moment,
                value=second_moment,
                decay=self._decay_rate,
                name="update_moving_second_moment").op

            return update_mean_op, update_second_moment_op

        def build_no_ops():
            return (tf.no_op(), tf.no_op())

        is_training_const = utils.constant_value(is_training)
        if is_training_const is None or is_training_const:
            update_mean_op, update_second_moment_op = utils.smart_cond(
                is_training,
                build_update_ops,
                build_no_ops,
            )

            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_second_moment_op)
Example #2
 def test_value(self):
     fn1 = lambda: 'fn1'
     fn2 = lambda: 'fn2'
     expected = lambda v: 'fn1' if v else 'fn2'
     for v in [True, False, 1, 0]:
         o = utils.smart_cond(tf.constant(v), fn1, fn2)
         self.assertEqual(o, expected(v))
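
The behavior these tests rely on is that smart_cond calls fn1 or fn2 immediately when the predicate can be resolved at graph-construction time, and only builds a tf.cond when it cannot. Below is a rough illustrative sketch of such a helper for TF 1.x, not the library source (the real utils.smart_cond also accepts a name argument):

import tensorflow as tf
from tensorflow.python.framework import tensor_util

def smart_cond_sketch(pred, fn1, fn2, name=None):
    # Try to resolve the predicate while building the graph.
    pred = tf.convert_to_tensor(pred)
    pred_value = tensor_util.constant_value(pred)  # None if not statically known
    if pred_value is not None:
        # Statically known: call the chosen branch directly. No tf.cond node is
        # created, and the result may be a plain Python value, as in test_value.
        return fn1() if pred_value else fn2()
    # Unknown until run time (e.g. a placeholder): emit a real conditional.
    return tf.cond(pred, fn1, fn2, name=name)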
Example #3
    def _build_update_ops_variance(self, mean, variance, is_training):
        def build_update_ops():
            update_mean_op = moving_averages.assign_moving_average(
                variable=self._moving_mean,
                value=mean,
                decay=self._decay_rate,
                name="update_moving_mean").op

            update_variance_op = moving_averages.assign_moving_average(
                variable=self._moving_variance,
                value=variance,
                decay=self._decay_rate,
                name="update_moving_variance").op

            return update_mean_op, update_variance_op

        def build_no_ops():
            return (tf.no_op(), tf.no_op())

        # Only make the ops if we know that `is_training=True`, or the
        # value of `is_training` is unknown.
        is_training_const = utils.constant_value(is_training)
        if is_training_const is None or is_training_const:
            update_mean_op, update_variance_op = utils.smart_cond(
                is_training,
                build_update_ops,
                build_no_ops,
            )

            # Every new connection creates a new op which adds its contribution
            # to the running average when run.
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_variance_op)
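
Since these update ops only land in the tf.GraphKeys.UPDATE_OPS collection, the training step has to run them explicitly. A common TF 1.x pattern, sketched with a toy loss (the variable, loss, and optimizer here are made up):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
w = tf.get_variable("w", [4, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))         # toy loss
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)   # moving-average updates
with tf.control_dependencies(update_ops):
    # Running train_op now also advances the moving mean / moving variance.
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)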
Example #4
 def test_value(self):
   fn1 = lambda: 'fn1'
   fn2 = lambda: 'fn2'
   expected = lambda v: 'fn1' if v else 'fn2'
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
     self.assertEqual(o, expected(v))
Example #5
    def _update_renorm_variable(var, weight, value):
        """Updates a moving average and weight, returns the unbiased value."""
        value = array_ops.identity(value)

        def _do_update():
            # Update the variables without zero debiasing. The debiasing will be
            # accomplished by dividing the exponential moving average by the weight.
            # For example, after a single update, the moving average would be
            # (1 - decay) * value and the weight will be (1 - decay), with their
            # ratio giving back the original value.
            # Make sure the weight is not updated before r and d are computed.
            with ops.control_dependencies([value]):
                weight_value = array_ops.constant(1., dtype=weight.dtype)
            new_var = moving_averages.assign_moving_average(
                var, value, renorm_params.renorm_momentum, zero_debias=False)
            new_weight = moving_averages.assign_moving_average(
                weight,
                weight_value,
                renorm_params.renorm_momentum,
                zero_debias=False)
            return new_var / new_weight

        def _fake_update():
            return array_ops.identity(var)

        return utils.smart_cond(training, _do_update, _fake_update)
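
A quick worked illustration of the debiasing-by-weight trick described in the comments above. With zero_debias=False, assign_moving_average computes decay * old + (1 - decay) * new, so (numbers are made up):

decay, value = 0.9, 4.0
var = decay * 0.0 + (1 - decay) * value    # 0.4, biased towards the zero init
weight = decay * 0.0 + (1 - decay) * 1.0   # 0.1
unbiased = var / weight                    # 4.0, the bias cancels in the ratio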
Example #6
def batch_norm(input, gammas, betas, epsilon, is_training):
    """

    BatchNorm implementation with sample-specific beta and gamma parameters
    i.e. the shift and scaling parameters are different across a batch of examples

    :param input: feature map input. 3-D vector (+ batch size)
    :param gammas: BN gamma parameters. 1-D vector (+ batch size)
    :param betas: BN betas parameters. 1-D vector (+ batch size)
    :param epsilon: BN epsilon for stability
    :param is_training: compute (True) or use (False) moving mean and variance
    :return: input after BN
    """

    assert (len(input.get_shape()) == 4)
    num_channels = int(input.get_shape()[3])

    # use a separate cbn_input scope so the variables are not initialized with resnet values
    with tf.variable_scope("cbn_input"):
        moving_mean = tf.get_variable("moving_mean", [num_channels],
                                      dtype=tf.float32,
                                      trainable=False)
        moving_variance = tf.get_variable("moving_variance", [num_channels],
                                          dtype=tf.float32,
                                          trainable=False)

    def _training():
        """
        Internal function that delays updating moving_vars when is_training is True.
        """
        mean, variance = tf.nn.moments(input, [0, 1, 2])

        update_moving_mean = moving_averages.assign_moving_average(
            moving_mean, mean, 0.99, zero_debias=True)
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, variance, 0.99, zero_debias=False)

        return mean, variance, update_moving_mean, update_moving_variance

    def _inference():
        return moving_mean, moving_variance, moving_mean, moving_variance

    # Collect mean/variance to prepare moving mean/variance
    means, variances, update_mean, update_variance = tf_utils.smart_cond(
        is_training, _training, _inference)

    # Add the moving mean/variance update ops to UPDATE_OPS (cf. tensorflow batch norm documentation)
    updates_collections = ops.GraphKeys.UPDATE_OPS
    ops.add_to_collections(updates_collections, update_mean)
    ops.add_to_collections(updates_collections, update_variance)

    # apply batch norm
    inv = gammas * tf.expand_dims(tf.rsqrt(variances + epsilon), 0)
    expanded_inv = tf.reshape(inv, [-1, 1, 1, num_channels])
    expanded_mean = tf.reshape(means, [-1, 1, 1, num_channels])
    expanded_betas = tf.reshape(betas, [-1, 1, 1, num_channels])
    out = expanded_inv * (input - expanded_mean) + expanded_betas

    return out
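
A hedged sketch of how this helper might be wired up; shapes and tensor names here are invented for illustration, with gammas and betas carrying the leading batch dimension as the docstring describes:

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 7, 7, 64])
gammas = tf.placeholder(tf.float32, [None, 64])   # per-example scale
betas = tf.placeholder(tf.float32, [None, 64])    # per-example shift
is_training = tf.placeholder(tf.bool, [], name="is_training")

normalized = batch_norm(images, gammas, betas, epsilon=1e-5,
                        is_training=is_training)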
Example #7
def dropout(x, keep_prob=0.5, is_training=False, name='drop'):
    with tf.variable_scope(name, 'dropout', [x]) as sc:
        x = utils.smart_cond(is_training,
                             lambda: tf.nn.dropout(x, keep_prob, name=name),
                             lambda: x,
                             name=name)
        return utils.collect_named_outputs(tf.GraphKeys.ACTIVATIONS,
                                           sc.original_name_scope, x)
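
When is_training is a Python bool (the default False), smart_cond resolves it at graph-construction time and simply returns x, so no dropout op is added at all; with a boolean tensor, both branches are built and tf.cond selects one per run. A small hypothetical call site:

import tensorflow as tf

h = tf.placeholder(tf.float32, [None, 128], name="hidden")
is_training = tf.placeholder_with_default(False, [], name="is_training")
h_drop = dropout(h, keep_prob=0.5, is_training=is_training, name="drop1")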
Example #8
 def test_constant(self):
   fn1 = lambda: constant_op.constant('fn1')
   fn2 = lambda: constant_op.constant('fn2')
   expected = lambda v: b'fn1' if v else b'fn2'
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
     with self.test_session():
       self.assertEqual(o.eval(), expected(v))
Example #9
 def test_tensors(self):
   fn1 = lambda: constant_op.constant(0) - constant_op.constant(1)
   fn2 = lambda: constant_op.constant(0) - constant_op.constant(2)
   expected = lambda v: -1 if v else -2
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
     with self.test_session():
       self.assertEqual(o.eval(), expected(v))
Example #10
 def test_constant(self):
     fn1 = lambda: tf.constant('fn1')
     fn2 = lambda: tf.constant('fn2')
     expected = lambda v: b'fn1' if v else b'fn2'
     for v in [True, False, 1, 0]:
         o = utils.smart_cond(tf.constant(v), fn1, fn2)
         with self.test_session():
             self.assertEqual(o.eval(), expected(v))
Example #11
 def test_tensors(self):
     fn1 = lambda: tf.constant(0) - tf.constant(1)
     fn2 = lambda: tf.constant(0) - tf.constant(2)
     expected = lambda v: -1 if v else -2
     for v in [True, False, 1, 0]:
         o = utils.smart_cond(tf.constant(v), fn1, fn2)
         with self.test_session():
             self.assertEqual(o.eval(), expected(v))
Example #12
    def _create_softmax_layer(self, proj, dec_outs, targets, weights,
                              scope_name, args):
        with tf.variable_scope(scope_name):
            w_t, b = proj

            # is_training = False
            def get_llh_test():
                dec_outs_flt = tf.reshape(dec_outs, [-1, args.num_units])
                logits_flt = tf.add(
                    tf.matmul(dec_outs_flt, w_t, transpose_b=True), b[None, :])
                logits = tf.reshape(
                    logits_flt, [tf.shape(dec_outs)[0], -1, args.vocab_size])

                llh_precise = tf.contrib.seq2seq.sequence_loss(
                    logits=logits,
                    targets=targets,
                    weights=weights,
                    average_across_timesteps=False,
                    average_across_batch=False,
                    softmax_loss_function=None)
                return llh_precise

            # is_training = True
            def sampled_loss(inputs, labels):
                labels = tf.reshape(labels, [-1, 1])
                # use 32-bit floats to avoid numerical instabilities
                #w_t = tf.transpose(w)
                local_w_t = tf.cast(w_t, tf.float32)
                local_b = tf.cast(b, tf.float32)
                local_inputs = tf.cast(inputs, tf.float32)
                return tf.nn.sampled_softmax_loss(weights=local_w_t,
                                                  biases=local_b,
                                                  inputs=local_inputs,
                                                  labels=labels,
                                                  num_sampled=args.num_samples,
                                                  num_classes=args.vocab_size,
                                                  partition_strategy="div")

            # is_training = True
            def get_llh_train():
                # if use sampled_softmax
                if args.use_sampled_softmax and args.num_samples > 0 and args.num_samples < args.vocab_size:
                    llh_train = tf.contrib.seq2seq.sequence_loss(
                        logits=dec_outs,
                        targets=targets,
                        weights=weights,
                        average_across_timesteps=False,
                        average_across_batch=False,
                        softmax_loss_function=sampled_loss)
                    self._logger.info('Use sampled softmax during training')
                else:
                    llh_train = get_llh_test()
                    self._logger.info('Use precise softmax during training')
                return llh_train

            loss = smart_cond(self.is_training_plh, get_llh_train,
                              get_llh_test)
        return loss
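
Because self.is_training_plh is a placeholder, smart_cond cannot resolve it while building the graph, so both get_llh_train and get_llh_test are constructed and a tf.cond chooses between them per step. A hypothetical run (session, model, and feed names are assumptions):

train_loss = sess.run(loss, feed_dict={model.is_training_plh: True})   # sampled softmax
eval_loss = sess.run(loss, feed_dict={model.is_training_plh: False})   # precise softmax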
Example #13
 def test_tensors(self):
     fn1 = lambda: tf.constant(0) - tf.constant(1)
     fn2 = lambda: tf.constant(0) - tf.constant(2)
     expected = lambda v: -1 if v else -2
     p = tf.placeholder(tf.bool, [])
     for v in [True, False, 1, 0]:
         o = utils.smart_cond(p, fn1, fn2)
         with self.test_session():
             self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
Example #14
 def test_variable(self):
   fn1 = lambda: variables.Variable('fn1')
   fn2 = lambda: variables.Variable('fn2')
   expected = lambda v: b'fn1' if v else b'fn2'
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
     with self.test_session() as sess:
       sess.run(variables.global_variables_initializer())
       self.assertEqual(o.eval(), expected(v))
Example #15
 def test_value(self):
   fn1 = lambda: ops.convert_to_tensor('fn1')
   fn2 = lambda: ops.convert_to_tensor('fn2')
   expected = lambda v: b'fn1' if v else b'fn2'
   p = array_ops.placeholder(dtypes.bool, [])
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(p, fn1, fn2)
     with self.test_session():
       self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
Example #16
 def test_variable(self):
     fn1 = lambda: tf.Variable('fn1')
     fn2 = lambda: tf.Variable('fn2')
     expected = lambda v: b'fn1' if v else b'fn2'
     for v in [True, False, 1, 0]:
         o = utils.smart_cond(tf.constant(v), fn1, fn2)
         with self.test_session() as sess:
             sess.run(tf.global_variables_initializer())
             self.assertEqual(o.eval(), expected(v))
Example #17
 def test_value(self):
   fn1 = lambda: ops.convert_to_tensor('fn1')
   fn2 = lambda: ops.convert_to_tensor('fn2')
   expected = lambda v: b'fn1' if v else b'fn2'
   p = array_ops.placeholder(dtypes.bool, [])
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(p, fn1, fn2)
     with self.test_session():
       self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
Example #18
 def test_constant(self):
   fn1 = lambda: tf.constant('fn1')
   fn2 = lambda: tf.constant('fn2')
   expected = lambda v: b'fn1' if v else b'fn2'
   p = tf.placeholder(tf.bool, [])
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(p, fn1, fn2)
     with self.test_session():
       self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
Example #19
 def test_tensors(self):
   fn1 = lambda: constant_op.constant(0) - constant_op.constant(1)
   fn2 = lambda: constant_op.constant(0) - constant_op.constant(2)
   expected = lambda v: -1 if v else -2
   p = array_ops.placeholder(dtypes.bool, [])
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(p, fn1, fn2)
     with self.test_session():
       self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
Example #20
 def test_constant(self):
     fn1 = lambda: tf.constant('fn1')
     fn2 = lambda: tf.constant('fn2')
     expected = lambda v: b'fn1' if v else b'fn2'
     p = tf.placeholder(tf.bool, [])
     for v in [True, False, 1, 0]:
         o = utils.smart_cond(p, fn1, fn2)
         with self.test_session():
             self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
Example #21
    def _build_statistics_second_moment(self, input_batch, reduction_indices,
                                        use_batch_stats):
        self._moving_mean = tf.get_variable(
            "moving_mean",
            shape=self._mean_shape,
            collections=[
                tf.GraphKeys.MOVING_AVERAGE_VARIABLES, tf.GraphKeys.VARIABLES
            ],
            initializer=tf.zeros_initializer,
            trainable=False)

        self._moving_second_moment = tf.get_variable(
            "moving_second_moment",
            shape=self._mean_shape,
            collections=[
                tf.GraphKeys.MOVING_AVERAGE_VARIABLES, tf.GraphKeys.VARIABLES
            ],
            initializer=tf.ones_initializer(),
            trainable=False)

        self._moving_variance = tf.sub(self._moving_second_moment,
                                       tf.square(self._moving_mean),
                                       name="moving_variance")

        def build_batch_stats():
            shift = tf.add(self._moving_mean, 0)
            counts, shifted_sum_x, shifted_sum_x2, _ = tf.nn.sufficient_statistics(
                input_batch,
                reduction_indices,
                keep_dims=True,
                shift=shift,
                name="batch_norm_ss")

            mean, variance = tf.nn.normalize_moments(counts,
                                                     shifted_sum_x,
                                                     shifted_sum_x2,
                                                     shift,
                                                     name="normalize_moments")
            second_moment = variance + tf.square(mean)

            return mean, variance, second_moment

        def build_moving_stats():
            return (
                tf.identity(self._moving_mean),
                tf.identity(self._moving_variance),
                tf.identity(self._moving_second_moment),
            )

        mean, variance, second_moment = utils.smart_cond(
            use_batch_stats,
            build_batch_stats,
            build_moving_stats,
        )

        return mean, variance, second_moment
Example #22
 def test_variable(self):
   fn1 = lambda: variables.Variable('fn1')
   fn2 = lambda: variables.Variable('fn2')
   expected = lambda v: b'fn1' if v else b'fn2'
   p = array_ops.placeholder(dtypes.bool, [])
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(p, fn1, fn2)
     with self.test_session() as sess:
       sess.run(variables.global_variables_initializer())
       self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
Example #23
 def test_variable(self):
     fn1 = lambda: tf.Variable('fn1')
     fn2 = lambda: tf.Variable('fn2')
     expected = lambda v: b'fn1' if v else b'fn2'
     p = tf.placeholder(tf.bool, [])
     for v in [True, False, 1, 0]:
         o = utils.smart_cond(p, fn1, fn2)
         with self.test_session() as sess:
             sess.run(tf.global_variables_initializer())
             self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
Example #24
 def batch_normalization(x):
     beta = tf.get_variable('beta', [x.get_shape()[-1]], dtype, tf.zeros_initializer())
     gamma = tf.get_variable('gamma', [x.get_shape()[-1]], dtype, tf.ones_initializer())
     mv_mean = tf.get_variable('mv_mean', [x.get_shape()[-1]], dtype=dtype, initializer=tf.zeros_initializer(), trainable=False)
     mv_var = tf.get_variable('mv_var', [x.get_shape()[-1]], dtype=dtype, initializer=tf.ones_initializer(), trainable=False)
     mean, variance = tf.nn.moments(x, [0], name='moments')
     tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, assign_moving_average(mv_mean, mean, decay, zero_debias=True))
     tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, assign_moving_average(mv_var, variance, decay, zero_debias=False))
     mean, variance = utils.smart_cond(is_training, lambda: (mean, variance), lambda: (mv_mean, mv_var))
     return tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-6)
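
This snippet leans on names from an enclosing scope (dtype, decay, is_training, assign_moving_average, utils). A hedged guess at the surroundings it assumes, for TF 1.x:

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.training.moving_averages import assign_moving_average

dtype = tf.float32          # assumed outer-scope names
decay = 0.99
is_training = tf.placeholder(tf.bool, [], name="is_training")

x = tf.placeholder(dtype, [None, 256], name="features")
with tf.variable_scope("bn"):
    y = batch_normalization(x)   # the helper defined above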
Example #25
 def test_variable(self):
   fn1 = lambda: tf.Variable('fn1')
   fn2 = lambda: tf.Variable('fn2')
   expected = lambda v: b'fn1' if v else b'fn2'
   p = tf.placeholder(tf.bool, [])
   for v in [True, False, 1, 0]:
     o = utils.smart_cond(p, fn1, fn2)
     with self.test_session() as sess:
       sess.run(tf.initialize_all_variables())
       self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
Example #26
File: semireg.py Project: wead-hsu/sssp
    def model_setup(self, args):
        with tf.variable_scope(args.log_prefix):
            self.init_global_step()
            self._create_placeholders()
            self._logger.info("Created placeholders.")
            self._create_embedding_matrix(args)

            self.kl_w = tf.log(1. + tf.exp((self.global_step - args.klw_b) * args.klw_w))
            self.kl_w = tf.minimum(self.kl_w, 1.) / 100.0  # scale down the reweighted KL weight
        
        self.loss_l = self.get_loss_l(args)
        self.train_unlabel = tf.greater(self.global_step, args.num_pretrain_steps)
        self.loss_u = smart_cond(self.train_unlabel, lambda: self.get_loss_u(args), lambda: tf.constant(0.))
        tf.summary.scalar('train_unlabel', tf.to_int64(self.train_unlabel))
        tf.summary.scalar('loss_u', self.loss_u)

        self.loss = self.loss_l + self.loss_u
        tf.summary.scalar('loss', self.loss)

        with tf.variable_scope(args.log_prefix):
            # optimizer
            #embd_var = self.embedding_matrix
            #other_var_list = [v for v in tf.trainable_variables() if v.name != embd_var.name]
            learning_rate = tf.train.exponential_decay(args.learning_rate, self.global_step, 
                    args.decay_steps,
                    args.decay_rate,
                    staircase=True)
            self.train_op = self.training_op(self.loss, #tf.trainable_variables(),
                    tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope=args.log_prefix),
                    grad_clip=args.grad_clip,
                    max_norm=args.max_norm,
                    train_embd=True,
                    learning_rate=args.learning_rate,)
            self._logger.info("Created SemiClassifier Model.")

            self._create_saver(args)
            self._logger.info('Created Saver')

            self.merged = tf.summary.merge_all()

            """ Create beam search layer
            self.beam_output_cur, self.beam_scores_cur = self._create_beam_search_layer(
                    init_state=yz,
                    dec_step_func=cur_dec_func,
                    cell=cur_cell,
                    embedding_matrix=self.embedding_matrix,
                    vocab_size=args.vocab_size,
                    num_layers=args.num_layers,)
            self._logger.info('Created Beam Search Layer')
            """

            vt = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope=args.log_prefix)
            vs = tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES, scope=args.log_prefix)
        return vt, vs
Example #27
def dropout_selu(x,
                 rate,
                 alpha=-1.7580993408473766,
                 fixedPointMean=0.0,
                 fixedPointVar=1.0,
                 noise_shape=None,
                 seed=None,
                 name=None,
                 training=False):
    """Dropout to a value with rescaling."""
    def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
        keep_prob = 1.0 - rate
        x = ops.convert_to_tensor(x, name="x")
        if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
            raise ValueError(
                "keep_prob must be a scalar tensor or a float in the "
                "range (0, 1], got %g" % keep_prob)
        keep_prob = ops.convert_to_tensor(keep_prob,
                                          dtype=x.dtype,
                                          name="keep_prob")
        keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

        alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
        alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar())

        if tensor_util.constant_value(keep_prob) == 1:
            return x

        noise_shape = noise_shape if noise_shape is not None else array_ops.shape(
            x)
        random_tensor = keep_prob
        random_tensor += random_ops.random_uniform(noise_shape,
                                                   seed=seed,
                                                   dtype=x.dtype)
        binary_tensor = math_ops.floor(random_tensor)
        ret = x * binary_tensor + alpha * (1 - binary_tensor)

        a = math_ops.sqrt(
            fixedPointVar /
            (keep_prob *
             ((1 - keep_prob) * math_ops.pow(alpha - fixedPointMean, 2) +
              fixedPointVar)))

        b = fixedPointMean - a * (keep_prob * fixedPointMean +
                                  (1 - keep_prob) * alpha)
        ret = a * ret + b
        ret.set_shape(x.get_shape())
        return ret

    with ops.name_scope(name, "dropout", [x]) as name:
        return utils.smart_cond(
            training,
            lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name),
            lambda: array_ops.identity(x))
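
A hypothetical call site (tensor shape, rate, and names are illustrative). With a boolean tensor for training, both branches are built and selected at run time, while a plain Python False would skip the dropout graph entirely:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 256], name="activations")
training = tf.placeholder_with_default(False, [], name="training")
y = dropout_selu(x, rate=0.1, training=training, name="selu_dropout")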
Example #28
File: semireg.py Project: wead-hsu/sssp
    def _get_elbo_label(self, inp, tgt, msk, y, args):
        """ Build encoder and decoders """
        xlen = tf.to_int32(tf.reduce_sum(msk, axis=1))
        enc_state = self._create_encoder(
                tgt,
                seqlen=xlen,
                scope_name='enc',
                args=args)

        with tf.variable_scope('latent'):
            y_enc_in = tf.contrib.layers.fully_connected(y, args.dim_z, scope='y_enc_in')
            pst_in = tf.concat([y_enc_in, enc_state], axis=1)
            mu_pst = tf.contrib.layers.fully_connected(pst_in, args.dim_z, tf.nn.tanh,
                    scope='mu_posterior')
            logvar_pst = tf.contrib.layers.fully_connected(pst_in, args.dim_z, tf.nn.tanh,
                    scope='logvar_posterior')
            mu_pri = tf.zeros_like(mu_pst)
            logvar_pri = tf.ones_like(logvar_pst)
            dist_pri = tf.contrib.distributions.Normal(mu=mu_pri, sigma=tf.exp(logvar_pri))
            dist_pst = tf.contrib.distributions.Normal(mu=mu_pst, sigma=tf.exp(logvar_pst))
            kl_loss = tf.contrib.distributions.kl(dist_pst, dist_pri)
            kl_loss = tf.reduce_sum(kl_loss, axis=1)

        with st.value_type(st.SampleValue(stop_gradient=False)):
            z_st_pri = st.StochasticTensor(dist_pri, name='z_pri')
            z_st_pst = st.StochasticTensor(dist_pst, name='z_pst')
            z = smart_cond(self.is_training, lambda: z_st_pst, lambda: z_st_pri)
       
        z_ext = tf.contrib.layers.fully_connected(tf.reshape(z, [-1, args.dim_z]), args.num_units, scope='extend_z')
        xlen = tf.to_int32(tf.reduce_sum(msk, axis=1))
        outs, proj, dec_func, cell  = self._create_decoder(
                inp,
                seqlen=xlen,
                label_oh=y,
                init_state=z_ext,
                scope_name='dec',
                args=args)

        # build loss layers
        recons_loss = self._create_softmax_layer(
                proj=proj,
                dec_outs=outs,
                targets=tgt,
                weights=msk,
                scope_name='loss',
                args=args)
        
        return recons_loss, kl_loss
Example #29
    def _build_update_ops(self, mean, variance, is_training):
        """Builds the moving average update ops when using moving variance.

    Args:
      mean: The mean value to update with.
      variance: The variance value to update with.
      is_training: Boolean Tensor to indicate if we're currently in
        training mode.

    Returns:
      Tuple of `(update_mean_op, update_variance_op)` when `is_training` is or
      could be `True`. Returns `None` when `is_training=False`.
    """
        def build_update_ops():
            """Builds the exponential moving average update ops."""

            update_mean_op = moving_averages.assign_moving_average(
                variable=self._moving_mean,
                value=mean,
                decay=self._decay_rate,
                zero_debias=False,
                name="update_moving_mean").op

            update_variance_op = moving_averages.assign_moving_average(
                variable=self._moving_variance,
                value=variance,
                decay=self._decay_rate,
                zero_debias=False,
                name="update_moving_variance").op

            return update_mean_op, update_variance_op

        def build_no_ops():
            return (tf.no_op(), tf.no_op())

        # Only make the ops if we know that `is_training=True`, or the value of
        # `is_training` is unknown.
        is_training_const = utils.constant_value(is_training)
        if is_training_const is None or is_training_const:
            update_mean_op, update_variance_op = utils.smart_cond(
                is_training,
                build_update_ops,
                build_no_ops,
            )
            return (update_mean_op, update_variance_op)
        else:
            return None
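
Per the docstring, a caller has to handle both the tuple and the None return. A minimal hypothetical caller inside the module's build method:

update_ops = self._build_update_ops(mean, variance, is_training)
if update_ops is not None:
    # Only register the ops when they were actually created.
    for update_op in update_ops:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)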
Example #30
    def _build_update_ops_second_moment(self, mean, second_moment,
                                        is_training):
        """Builds the moving average update ops when using the moving second moment.

    Args:
      mean: The mean value to update with.
      second_moment: The second_moment value to update with.
      is_training: Boolean Tensor to indicate if we're currently in
        training mode.
    """
        def build_update_ops():
            """Builds the exponential moving average update ops."""

            update_mean_op = moving_averages.assign_moving_average(
                variable=self._moving_mean,
                value=mean,
                decay=self._decay_rate,
                name="update_moving_mean").op

            update_second_moment_op = moving_averages.assign_moving_average(
                variable=self._moving_second_moment,
                value=second_moment,
                decay=self._decay_rate,
                name="update_moving_second_moment").op

            return update_mean_op, update_second_moment_op

        def build_no_ops():
            return (tf.no_op(), tf.no_op())

        # Only make the ops if we know that `is_training=True`, or the value of
        # `is_training` is unknown.
        is_training_const = utils.constant_value(is_training)
        if is_training_const is None or is_training_const:
            update_mean_op, update_second_moment_op = utils.smart_cond(
                is_training,
                build_update_ops,
                build_no_ops,
            )

            # Every new connection creates a new op which adds its contribution
            # to the running average when run.
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                                 update_second_moment_op)
Example #31
    def _fused_batch_norm_op(self, input_batch, mean, variance,
                             use_batch_stats):
        """Creates a fused batch normalization op."""
        # The fused batch norm expects the mean, variance, gamma and beta
        # tensors to have dimension 1, so we flatten them to remove the
        # extra dimensions.
        gamma_flatten = tf.reshape(self._gamma, shape=(-1, ))
        beta_flatten = tf.reshape(self._beta, shape=(-1, ))
        flatten_mean = tf.reshape(mean, shape=(-1, ))
        flatten_variance = tf.reshape(variance, shape=(-1, ))
        use_batch_stats = tf.convert_to_tensor(use_batch_stats)

        common_args = {
            "scale": gamma_flatten,
            "offset": beta_flatten,
            "epsilon": self._eps,
            "data_format": self._infer_fused_data_format(input_batch),
            "name": "batch_norm"
        }

        def use_batch_stats_fused_batch_norm():
            return tf.nn.fused_batch_norm(input_batch,
                                          mean=None,
                                          variance=None,
                                          is_training=True,
                                          **common_args)

        def moving_average_fused_batch_norm():
            return tf.nn.fused_batch_norm(input_batch,
                                          mean=flatten_mean,
                                          variance=flatten_variance,
                                          is_training=False,
                                          **common_args)

        batch_norm_op, mean, variance = utils.smart_cond(
            use_batch_stats, use_batch_stats_fused_batch_norm,
            moving_average_fused_batch_norm)

        return batch_norm_op, mean, variance
Example #32
File: batch_norm.py Project: mkabra/poseTF
def batch_norm_mine_old(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               param_initializers=None,
               param_regularizers=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               batch_weights=None,
               fused=False,
               data_format=DATA_FORMAT_NHWC,
               zero_debias_moving_mean=False,
               scope=None,
               renorm=False,
               renorm_clipping=None,
               renorm_decay=0.99):
  """
  This earlier version of my modification to batch norm uses
  current_mean and current_variance if is_training is True, and
  moving_mean and moving_variance otherwise. This was leading to a large
  divergence between the results depending on whether is_training was set to
  True or not.

  I think ideally it should always use moving_mean and moving_variance;
  batch_norm_mine does this.

  Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167
  (a copy of the tensorflow.contrib.layers implementation).
  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    decay: Decay for the moving average. Reasonable values for `decay` are close
      to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
      Lower `decay` value (recommend trying `decay`=0.9) if model experiences
      reasonably good training performance but poor validation and/or test
      performance. Try zero_debias_moving_mean=True for improved stability.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    param_regularizers: Optional regularizer for beta and gamma.
    updates_collections: Collections to collect the update ops for computation.
      The updates_ops need to be executed with the train_op.
      If None, a control dependency would be added to make sure the updates are
      computed in place.
    is_training: Whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    batch_weights: An optional tensor of shape `[batch_size]`,
      containing a frequency weight for each batch item. If present,
      then the batch normalization uses weighted mean and
      variance. (This can be used to correct for bias in training
      example selection.)
    fused:  Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
      pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
    scope: Optional scope for `variable_scope`.
    renorm: Whether to use Batch Renormalization
      (https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction
      `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
      `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
      dmax are set to inf, 0, inf, respectively.
    renorm_decay: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training
      and should be neither too small (which would add noise) nor too large
      (which would give stale estimates). Note that `decay` is still applied
      to get the means and variances for inference.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If `batch_weights` is not None and `fused` is True.
    ValueError: If `param_regularizers` is not None and `fused` is True.
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
  """
  if fused:
    if batch_weights is not None:
      raise ValueError('Weighted mean and variance is not currently '
                       'supported for fused batch norm.')
    if param_regularizers is not None:
      raise ValueError('Regularizers are not currently '
                       'supported for fused batch norm.')
    if renorm:
      raise ValueError('Renorm is not supported for fused batch norm.')
    return _fused_batch_norm(
        inputs,
        decay=decay,
        center=center,
        scale=scale,
        epsilon=epsilon,
        activation_fn=activation_fn,
        param_initializers=param_initializers,
        updates_collections=updates_collections,
        is_training=is_training,
        reuse=reuse,
        variables_collections=variables_collections,
        outputs_collections=outputs_collections,
        trainable=trainable,
        data_format=data_format,
        zero_debias_moving_mean=zero_debias_moving_mean,
        scope=scope)

  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  layer_variable_getter = _build_variable_getter()
  with variable_scope.variable_scope(
      scope, 'BatchNorm', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)

    # Determine whether we can use the core layer class.
    if (batch_weights is None and
        updates_collections is ops.GraphKeys.UPDATE_OPS and
        not zero_debias_moving_mean):
      # Use the core layer class.
      axis = 1 if data_format == DATA_FORMAT_NCHW else -1
      if not param_initializers:
        param_initializers = {}
      beta_initializer = param_initializers.get('beta',
                                                init_ops.zeros_initializer())
      gamma_initializer = param_initializers.get('gamma',
                                                 init_ops.ones_initializer())
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      if not param_regularizers:
        param_regularizers = {}
      beta_regularizer = param_regularizers.get('beta')
      gamma_regularizer = param_regularizers.get('gamma')
      layer = normalization_layers.BatchNormalization(
          axis=axis,
          momentum=decay,
          epsilon=epsilon,
          center=center,
          scale=scale,
          beta_initializer=beta_initializer,
          gamma_initializer=gamma_initializer,
          moving_mean_initializer=moving_mean_initializer,
          moving_variance_initializer=moving_variance_initializer,
          beta_regularizer=beta_regularizer,
          gamma_regularizer=gamma_regularizer,
          trainable=trainable,
          renorm=renorm,
          renorm_clipping=renorm_clipping,
          renorm_momentum=renorm_decay,
          name=sc.name,
          _scope=sc,
          _reuse=reuse)
      outputs = layer.apply(inputs, training=is_training)

      # Add variables to collections.
      _add_variable_to_collections(
          layer.moving_mean, variables_collections, 'moving_mean')
      _add_variable_to_collections(
          layer.moving_variance, variables_collections, 'moving_variance')
      if layer.beta:
        _add_variable_to_collections(layer.beta, variables_collections, 'beta')
      if layer.gamma:
        _add_variable_to_collections(
            layer.gamma, variables_collections, 'gamma')

      if activation_fn is not None:
        outputs = activation_fn(outputs)
      return utils.collect_named_outputs(outputs_collections,
                                         sc.original_name_scope, outputs)

    # Not supported by layer class: batch_weights argument,
    # and custom updates_collections. In that case, use the legacy BN
    # implementation.
    # Custom updates collections are not supported because the update logic
    # is different in this case, in particular w.r.t. "forced updates" and
    # update op reuse.
    if renorm:
      raise ValueError('renorm is not supported with batch_weights, '
                       'updates_collections or zero_debias_moving_mean')
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    if batch_weights is not None:
      batch_weights = ops.convert_to_tensor(batch_weights)
      inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
      # Reshape batch weight values so they broadcast across inputs.
      nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
      batch_weights = array_ops.reshape(batch_weights, nshape)

    if data_format == DATA_FORMAT_NCHW:
      moments_axes = [0] + list(range(2, inputs_rank))
      params_shape = inputs_shape[1:2]
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      moments_axes = list(range(inputs_rank - 1))
      params_shape = inputs_shape[-1:]
      params_shape_broadcast = None
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if not param_initializers:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(variables_collections,
                                                        'beta')
      beta_initializer = param_initializers.get('beta',
                                                init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(variables_collections,
                                                         'gamma')
      gamma_initializer = param_initializers.get('gamma',
                                                 init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)

    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections. We disable variable partitioning while creating
    # them, because assign_moving_average is not yet supported for partitioned
    # variables.
    partitioner = variable_scope.get_variable_scope().partitioner
    try:
      variable_scope.get_variable_scope().set_partitioner(None)
      moving_mean_collections = utils.get_variable_collections(
          variables_collections, 'moving_mean')
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_mean = variables.model_variable(
          'moving_mean',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_mean_initializer,
          trainable=False,
          collections=moving_mean_collections)
      moving_variance_collections = utils.get_variable_collections(
          variables_collections, 'moving_variance')
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      moving_variance = variables.model_variable(
          'moving_variance',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_variance_initializer,
          trainable=False,
          collections=moving_variance_collections)
    finally:
      variable_scope.get_variable_scope().set_partitioner(partitioner)

    # If `is_training` doesn't have a constant value, because it is a `Tensor`,
    # a `Variable` or `Placeholder` then is_training_value will be None and
    # `needs_moments` will be true.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      # Calculate the moments based on the individual batch.
      if batch_weights is None:
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.moments(inputs, moments_axes, keep_dims=True)
          variance,_ = nn.moments( (inputs-moving_mean)**2, moments_axes, keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.moments(inputs, moments_axes)
          variance, _ = nn.moments( (inputs-moving_mean)**2, moments_axes)
      else:
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.weighted_moments(inputs, moments_axes,
                                               batch_weights, keep_dims=True)
          variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes,
                                               batch_weights, keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.weighted_moments(inputs, moments_axes,
                                               batch_weights)
          variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes,
                                               batch_weights)

      moving_vars_fn = lambda: (moving_mean, moving_variance)
      if updates_collections is None:
        def _force_updates():
          """Internal function forces updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          with ops.control_dependencies([update_moving_mean,
                                         update_moving_variance]):
            return array_ops.identity(mean), array_ops.identity(variance)
        mean, variance = utils.smart_cond(is_training,
                                          _force_updates,
                                          moving_vars_fn)
      else:
        def _delay_updates():
          """Internal function that delay updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          return update_moving_mean, update_moving_variance

        update_mean, update_variance = utils.smart_cond(is_training,
                                                        _delay_updates,
                                                        moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)
        # Use computed moments during training and moving_vars otherwise.
        vars_fn = lambda: (mean, variance)
        mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
    else:
      mean, variance = moving_mean, moving_variance
    if data_format == DATA_FORMAT_NCHW:
      mean = array_ops.reshape(mean, params_shape_broadcast)
      variance = array_ops.reshape(variance, params_shape_broadcast)
      beta = array_ops.reshape(beta, params_shape_broadcast)
      if gamma is not None:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Compute batch_normalization.
    outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                     epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
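
The non-standard piece above is that the "variance" is taken around moving_mean rather than the batch mean. A tiny numpy illustration of how the two quantities relate (values are made up):

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(1000).astype(np.float32) + 2.0          # toy activations
moving_mean = 0.0                                     # e.g. early in training

batch_variance = np.mean((x - x.mean()) ** 2)         # stock batch_norm
around_moving_mean = np.mean((x - moving_mean) ** 2)  # batch_norm_mine_old
# E[(x - m)^2] = Var(x) + (batch_mean - m)^2, so the two agree only when
# moving_mean tracks the batch mean.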
Example #33
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               initializers={},
               updates_collections=None,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=False,
               scope=None):
    """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

    "Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift"

    Sergey Ioffe, Christian Szegedy

  Can be used as a normalizer function for conv2d and fully_connected.

  Note: when is_training is True the moving_mean and moving_variance need to be
  updated; by default the update_ops are placed in tf.GraphKeys.UPDATE_OPS, so
  they need to be added as a dependency to the train_op. For example:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      updates = tf.group(*update_ops)
      total_loss = control_flow_ops.with_dependencies([updates], total_loss)

  One can set updates_collections=None to force the updates in place, but that
  can have a speed penalty, especially in distributed settings.

  Args:
    inputs: a tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension.
    decay: decay for the moving average.
    center: If True, subtract `beta`. If False, `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: activation function, default set to None to skip it and
      maintain a linear activation.
    updates_collections: collections to collect the update ops for computation.
      The update_ops need to be executed with the train_op.
      If None, a control dependency would be added to make sure the updates are
      computed in place.
    is_training: whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: if rank or last dimension of `inputs` is undefined.
  """

    with variable_scope.variable_scope(scope,
                                       'BatchNorm', [inputs],
                                       reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        inputs_shape = inputs.get_shape()
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype
        axis = list(range(inputs_rank - 1))
        params_shape = inputs_shape[-1:]
        if not params_shape.is_fully_defined():
            raise ValueError('Inputs %s has undefined last dimension %s.' %
                             (inputs.name, params_shape))
        # Allocate parameters for the beta and gamma of the normalization.
        beta, gamma = None, None

        # Create moving_mean and moving_variance variables and add them to the
        # appropriate collections.
        moving_mean_initializer = initializers.get('moving_mean',
                                                   init_ops.zeros_initializer)
        moving_mean = variables.model_variable(
            'moving_mean',
            shape=params_shape,
            dtype=dtype,
            initializer=moving_mean_initializer,
            trainable=False)
        moving_variance_initializer = initializers.get(
            'moving_variance', init_ops.ones_initializer)
        moving_variance = variables.model_variable(
            'moving_variance',
            shape=params_shape,
            dtype=dtype,
            initializer=moving_variance_initializer,
            trainable=False)

        # If `is_training` doesn't have a constant value, because it is a `Tensor`,
        # a `Variable` or `Placeholder` then is_training_value will be None and
        # `needs_moments` will be true.
        is_training_value = utils.constant_value(is_training)
        need_moments = is_training_value is None or is_training_value
        if need_moments:
            # Calculate the moments based on the individual batch.
            # Use a copy of moving_mean as a shift to compute more reliable moments.
            shift = math_ops.add(moving_mean, 0)
            mean, variance = nn.moments(inputs, axis, shift=shift)
            moving_vars_fn = lambda: (moving_mean, moving_variance)
            if updates_collections is None:

                def _force_updates():
                    """Internal function forces updates moving_vars if is_training."""
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, mean, decay)
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, variance, decay)
                    with ops.control_dependencies(
                        [update_moving_mean, update_moving_variance]):
                        return array_ops.identity(mean), array_ops.identity(
                            variance)

                mean, variance = utils.smart_cond(is_training, _force_updates,
                                                  moving_vars_fn)
            else:

                def _delay_updates():
                    """Internal function that delay updates moving_vars if is_training."""
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, mean, decay)
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, variance, decay)
                    return update_moving_mean, update_moving_variance

                update_mean, update_variance = utils.smart_cond(
                    is_training, _delay_updates, moving_vars_fn)
                ops.add_to_collections(updates_collections, update_mean)
                ops.add_to_collections(updates_collections, update_variance)
                # Use computed moments during training and moving_vars otherwise.
                vars_fn = lambda: (mean, variance)
                mean, variance = utils.smart_cond(is_training, vars_fn,
                                                  moving_vars_fn)
        else:
            mean, variance = moving_mean, moving_variance
        # Compute batch_normalization.
        outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                         epsilon)
        outputs.set_shape(inputs_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections,
                                           sc.original_name_scope, outputs)
Example #34
    def call(self, inputs, training=False):
        # First, compute the axes along which to reduce the mean / variance,
        # as well as the broadcast shape to be used for all parameters.
        input_shape = inputs.get_shape()
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis].value

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = utils.constant_value(training)

        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            if self.center:
                broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
            else:
                broadcast_beta = None
            if self.scale:
                broadcast_gamma = array_ops.reshape(self.gamma,
                                                    broadcast_shape)
            else:
                broadcast_gamma = None

        # Determines moments
        if training_value is not False:
            if needs_broadcasting:
                broadcast_mean, broadcast_variance = nn.moments(inputs,
                                                                reduction_axes,
                                                                keep_dims=True)
                mean = array_ops.reshape(broadcast_mean, [-1])
                variance = array_ops.reshape(broadcast_variance, [-1])
            else:
                mean, variance = nn.moments(inputs, reduction_axes)

            # Prepare updates if necessary.
            if not self.updates:
                mean_update = moving_averages.assign_moving_average(
                    self.moving_mean, mean, self.momentum, zero_debias=False)
                variance_update = moving_averages.assign_moving_average(
                    self.moving_variance,
                    variance,
                    self.momentum,
                    zero_debias=False)
                # In the future this should be refactored into a self.add_update
                # method in order to allow for instance-based BN layer sharing
                # across unrelated input streams (e.g. like in Keras).
                self.updates.append(mean_update)
                self.updates.append(variance_update)

        # Normalize batch. We do this inside separate functions for training
        # and inference so as to avoid evaluating both branches.
        def normalize_in_test():
            if needs_broadcasting:
                broadcast_moving_mean = array_ops.reshape(
                    self.moving_mean, broadcast_shape)
                broadcast_moving_variance = array_ops.reshape(
                    self.moving_variance, broadcast_shape)
            arg_mean = broadcast_moving_mean if needs_broadcasting else self.moving_mean
            arg_variance = broadcast_moving_variance if needs_broadcasting else self.moving_variance
            arg_beta = broadcast_beta if needs_broadcasting else (
                self.beta if self.center else None)
            arg_gamma = broadcast_gamma if needs_broadcasting else (
                self.gamma if self.scale else None)
            if self.quantizer is None:
                return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                              arg_beta, arg_gamma,
                                              self.epsilon)
            else:
                return qbatch_normalization(inputs, arg_mean, arg_variance,
                                            arg_beta, arg_gamma, self.epsilon,
                                            self.quantizer)

        def normalize_in_training():
            arg_mean = broadcast_mean if needs_broadcasting else mean
            arg_variance = broadcast_variance if needs_broadcasting else variance
            arg_beta = broadcast_beta if needs_broadcasting else (
                self.beta if self.center else None)
            arg_gamma = broadcast_gamma if needs_broadcasting else (
                self.gamma if self.scale else None)
            if self.quantizer is None:
                return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                              arg_beta, arg_gamma,
                                              self.epsilon)
            else:
                return qbatch_normalization(inputs, arg_mean, arg_variance,
                                            arg_beta, arg_gamma, self.epsilon,
                                            self.quantizer)

        return utils.smart_cond(training, normalize_in_training,
                                normalize_in_test)
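A small stand-alone sketch of the branch selection used in the final `return`, assuming `utils` is the TF 1.x layers utility module (e.g. `tensorflow.contrib.layers.python.layers.utils`); the two lambdas stand in for the normalize functions above:

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import utils

training = tf.placeholder(tf.bool)                 # value unknown at graph-build time
out = utils.smart_cond(training,
                       lambda: tf.constant(1.0),   # stand-in for normalize_in_training
                       lambda: tf.constant(0.0))   # stand-in for normalize_in_test
# With a bool placeholder a tf.cond is built and the branch is chosen at run time;
# with a Python bool or a constant tensor, only the selected branch is ever built.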
Example #35
    def _get_elbo_label(self, inp, tgt, msk, label, args):
        """ Build encoder and decoders """
        xlen = tf.to_int32(tf.reduce_sum(msk, axis=1))
        enc_state = self._create_encoder(tgt,
                                         seqlen=xlen,
                                         scope_name='enc',
                                         args=args)
        enc_state = tf.nn.dropout(enc_state, self.keep_prob_plh)
        enc_state = tf.contrib.layers.fully_connected(
            enc_state,
            num_outputs=args.num_units,
            activation_fn=None,
            scope='x_to_a')
        enc_state = tf.contrib.layers.batch_norm(
            enc_state,
            center=True,
            scale=True,
            is_training=self.is_training_plh,
            scope='bn_a')
        enc_state = tf.tanh(enc_state)
        enc_state = tf.nn.dropout(enc_state, self.keep_prob_plh)

        label_oh = tf.gather(tf.eye(args.num_classes), label)
        with tf.variable_scope('latent'):
            y_enc_in = tf.contrib.layers.fully_connected(label_oh,
                                                         args.dim_z,
                                                         scope='y_enc_in')
            y_enc_in = tf.nn.dropout(y_enc_in, self.keep_prob_plh)
            pst_in = tf.concat([y_enc_in, enc_state], axis=1)
            pst_in = tf.contrib.layers.fully_connected(pst_in,
                                                       args.num_units,
                                                       None,
                                                       scope='pst_in_dense')
            pst_in = tf.contrib.layers.batch_norm(
                pst_in,
                center=True,
                scale=True,
                is_training=self.is_training_plh,
                scope='pst_in_bn')
            pst_in = tf.tanh(pst_in)
            pst_in = tf.nn.dropout(pst_in, self.keep_prob_plh)
            mu_pst = tf.contrib.layers.fully_connected(pst_in,
                                                       args.dim_z,
                                                       tf.nn.tanh,
                                                       scope='mu_posterior')
            logvar_pst = tf.contrib.layers.fully_connected(
                pst_in, args.dim_z, tf.nn.tanh, scope='logvar_posterior')
            mu_pri = tf.zeros_like(mu_pst)
            logvar_pri = tf.ones_like(logvar_pst)
            dist_pri = tf.contrib.distributions.Normal(
                mu=mu_pri, sigma=tf.exp(logvar_pri))
            dist_pst = tf.contrib.distributions.Normal(
                mu=mu_pst, sigma=tf.exp(logvar_pst))
            kl_loss = tf.contrib.distributions.kl(dist_pst, dist_pri)
            kl_loss = tf.reduce_sum(kl_loss, axis=1)

        with st.value_type(st.SampleValue(stop_gradient=False)):
            z_st_pri = st.StochasticTensor(dist_pri, name='z_pri')
            z_st_pst = st.StochasticTensor(dist_pst, name='z_pst')
            z = smart_cond(self.is_training_plh, lambda: z_st_pst,
                           lambda: z_st_pri)

        z_ext = tf.contrib.layers.fully_connected(tf.reshape(
            z, [-1, args.dim_z]),
                                                  args.num_units,
                                                  scope='extend_z')
        z_ext = tf.nn.dropout(z_ext, self.keep_prob_plh)
        yz = tf.concat([z_ext, label_oh], axis=1)
        yz = tf.contrib.layers.fully_connected(yz,
                                               args.num_units,
                                               None,
                                               scope='yz_dense')
        yz = tf.contrib.layers.batch_norm(yz,
                                          center=True,
                                          scale=True,
                                          is_training=self.is_training_plh,
                                          scope='yz_bn')
        yz = tf.tanh(yz)
        yz = tf.nn.dropout(yz, self.keep_prob_plh)
        xlen = tf.to_int32(tf.reduce_sum(msk, axis=1))
        outs, proj, dec_func, cell = self._create_decoder(
            inp,
            mask=msk,
            label_oh=label_oh,
            init_state=yz,  #tf.contrib.rnn.LSTMStateTuple(yz, yz),
            scope_name='dec',
            args=args)
        outs = tf.nn.dropout(outs, self.keep_prob_plh)

        # build loss layers
        recons_loss = self._create_softmax_layer(proj=proj,
                                                 dec_outs=outs,
                                                 targets=tgt,
                                                 weights=msk,
                                                 scope_name='loss',
                                                 args=args)
        recons_loss = recons_loss * msk

        return recons_loss, kl_loss
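A hedged sketch (not part of the original model) of how the two returned terms might be combined into a mean negative ELBO; `kl_weight` is a hypothetical annealing coefficient, and `model`, `inp`, `tgt`, `msk`, `label`, `args` are assumed to exist in the caller's graph:

import tensorflow as tf

recons_loss, kl_loss = model._get_elbo_label(inp, tgt, msk, label, args)
token_nll = tf.reduce_sum(recons_loss, axis=1)        # sum masked per-token losses
neg_elbo = token_nll + kl_weight * kl_loss            # per-example objective
loss = tf.reduce_mean(neg_elbo)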
Example #36
File: batch_norm.py  Project: skang29/GANs
    def batch_norm_backbone(inputs,
                            decay=0.999,
                            center=True,
                            scale=False,
                            epsilon=0.001,
                            activation_fn=None,
                            param_initializers=None,
                            param_regularizers=None,
                            updates_collections=ops.GraphKeys.UPDATE_OPS,
                            is_training=True,
                            reuse=None,
                            variables_collections=None,
                            outputs_collections=None,
                            trainable=True,
                            batch_weights=None,
                            fused=None,
                            data_format=DATA_FORMAT_NHWC,
                            zero_debias_moving_mean=False,
                            scope=None,
                            renorm=False,
                            renorm_clipping=None,
                            renorm_decay=0.99,
                            adjustment=None,
                            tower_config=None):

        """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
          "Batch Normalization: Accelerating Deep Network Training by Reducing
          Internal Covariate Shift"
          Sergey Ioffe, Christian Szegedy
        Can be used as a normalizer function for conv2d and fully_connected. The
        normalization is over all but the last dimension if `data_format` is `NHWC`
        and all but the second dimension if `data_format` is `NCHW`.  In case of a 2D
        tensor this corresponds to the batch dimension, while in case of a 4D tensor
        this corresponds to the batch and space dimensions.
        Note: when training, the moving_mean and moving_variance need to be updated.
        By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
        need to be added as a dependency to the `train_op`. For example:
        ```python
          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
          with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss)
        ```
        One can set updates_collections=None to force the updates in place, but that
        can have a speed penalty, especially in distributed settings.
        Args:
          inputs: A tensor with 2 or more dimensions, where the first dimension has
            `batch_size`. The normalization is over all but the last dimension if
            `data_format` is `NHWC` and the second dimension if `data_format` is
            `NCHW`.
          decay: Decay for the moving average. Reasonable values for `decay` are close
            to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
            Lower `decay` value (recommend trying `decay`=0.9) if model experiences
            reasonably good training performance but poor validation and/or test
            performance. Try zero_debias_moving_mean=True for improved stability.
          center: If True, add offset of `beta` to normalized tensor. If False, `beta`
            is ignored.
          scale: If True, multiply by `gamma`. If False, `gamma` is
            not used. When the next layer is linear (also e.g. `nn.relu`), this can be
            disabled since the scaling can be done by the next layer.
          epsilon: Small float added to variance to avoid dividing by zero.
          activation_fn: Activation function, default set to None to skip it and
            maintain a linear activation.
          param_initializers: Optional initializers for beta, gamma, moving mean and
            moving variance.
          param_regularizers: Optional regularizer for beta and gamma.
          updates_collections: Collections to collect the update ops for computation.
            The updates_ops need to be executed with the train_op.
            If None, a control dependency would be added to make sure the updates are
            computed in place.
          is_training: Whether or not the layer is in training mode. In training mode
            it would accumulate the statistics of the moments into `moving_mean` and
            `moving_variance` using an exponential moving average with the given
            `decay`. When it is not in training mode then it would use the values of
            the `moving_mean` and the `moving_variance`.
          reuse: Whether or not the layer and its variables should be reused. To be
            able to reuse the layer scope must be given.
          variables_collections: Optional collections for the variables.
          outputs_collections: Collections to add the outputs.
          trainable: If `True` also add variables to the graph collection
            `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
          batch_weights: An optional tensor of shape `[batch_size]`,
            containing a frequency weight for each batch item. If present,
            then the batch normalization uses weighted mean and
            variance. (This can be used to correct for bias in training
            example selection.)
          fused: if `None` or `True`, use a faster, fused implementation if possible.
            If `False`, use the system recommended implementation.
          data_format: A string. `NHWC` (default) and `NCHW` are supported.
          zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
            pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
          scope: Optional scope for `variable_scope`.
          renorm: Whether to use Batch Renormalization
            (https://arxiv.org/abs/1702.03275). This adds extra variables during
            training. The inference is the same for either value of this parameter.
          renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
            scalar `Tensors` used to clip the renorm correction. The correction
            `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
            `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
            dmax are set to inf, 0, inf, respectively.
          renorm_decay: Momentum used to update the moving means and standard
            deviations with renorm. Unlike `momentum`, this affects training
            and should be neither too small (which would add noise) nor too large
            (which would give stale estimates). Note that `decay` is still applied
            to get the means and variances for inference.
          adjustment: A function taking the `Tensor` containing the (dynamic) shape of
            the input tensor and returning a pair (scale, bias) to apply to the
            normalized values (before gamma and beta), only during training. For
            example,
              `adjustment = lambda shape: (
                tf.random_uniform(shape[-1:], 0.93, 1.07),
                tf.random_uniform(shape[-1:], -0.1, 0.1))`
            will scale the normalized value by up to 7% up or down, then shift the
            result by up to 0.1 (with independent scaling and bias for each feature
            but shared across all examples), and finally apply gamma and/or beta. If
            `None`, no adjustment is applied.
        Returns:
          A `Tensor` representing the output of the operation.
        Raises:
          ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
          ValueError: If the rank of `inputs` is undefined.
          ValueError: If rank or channels dimension of `inputs` is undefined.
        """
        # if fused is None:
        #     fused = True

        # Only use _fused_batch_norm if all of the following three
        # conditions are true:
        # (1) fused is set True;
        # (2) it is possible to use (currently it doesn't support batch weights,
        #   renorm, and the case when rank is neither 2 nor 4);
        # (3) it is used with zero_debias_moving_mean, or an input shape of rank 2,
        #   or non-default updates_collections (not implemented in
        #   normalization_layers.BatchNormalization yet); otherwise use the fused
        #   implementation in normalization_layers.BatchNormalization.
        # inputs = ops.convert_to_tensor(inputs)
        # rank = inputs.get_shape().ndims
        # possible_to_fuse = (
        #     batch_weights is None and not renorm and rank in [2, 4] and
        #     adjustment is None)
        # if fused and possible_to_fuse and (
        #                 zero_debias_moving_mean or rank == 2 or
        #                 updates_collections is not ops.GraphKeys.UPDATE_OPS):
        #     return _fused_batch_norm(
        #         inputs,
        #         decay=decay,
        #         center=center,
        #         scale=scale,
        #         epsilon=epsilon,
        #         activation_fn=activation_fn,
        #         param_initializers=param_initializers,
        #         param_regularizers=param_regularizers,
        #         updates_collections=updates_collections,
        #         is_training=is_training,
        #         reuse=reuse,
        #         variables_collections=variables_collections,
        #         outputs_collections=outputs_collections,
        #         trainable=trainable,
        #         data_format=data_format,
        #         zero_debias_moving_mean=zero_debias_moving_mean,
        #         scope=scope)

        if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
            raise ValueError('data_format has to be either NCHW or NHWC.')

        layer_variable_getter = _build_variable_getter()
        with variable_scope.variable_scope(
                scope,
                'BatchNorm', [inputs],
                reuse=reuse,
                custom_getter=layer_variable_getter) as sc:
            inputs = ops.convert_to_tensor(inputs)

            # # Determine whether we can use the core layer class.
            # if (batch_weights is None and
            #             updates_collections is ops.GraphKeys.UPDATE_OPS and
            #         not zero_debias_moving_mean):
            #     print("F**K !!!!")
            #     # Use the core layer class.
            #     axis = 1 if data_format == DATA_FORMAT_NCHW else -1
            #     if not param_initializers:
            #         param_initializers = {}
            #     beta_initializer = param_initializers.get('beta',
            #                                               init_ops.zeros_initializer())
            #     gamma_initializer = param_initializers.get('gamma',
            #                                                init_ops.ones_initializer())
            #     moving_mean_initializer = param_initializers.get(
            #         'moving_mean', init_ops.zeros_initializer())
            #     moving_variance_initializer = param_initializers.get(
            #         'moving_variance', init_ops.ones_initializer())
            #     if not param_regularizers:
            #         param_regularizers = {}
            #     beta_regularizer = param_regularizers.get('beta')
            #     gamma_regularizer = param_regularizers.get('gamma')
            #     layer = normalization_layers.BatchNormalization(
            #         axis=axis,
            #         momentum=decay,
            #         epsilon=epsilon,
            #         center=center,
            #         scale=scale,
            #         beta_initializer=beta_initializer,
            #         gamma_initializer=gamma_initializer,
            #         moving_mean_initializer=moving_mean_initializer,
            #         moving_variance_initializer=moving_variance_initializer,
            #         beta_regularizer=beta_regularizer,
            #         gamma_regularizer=gamma_regularizer,
            #         trainable=trainable,
            #         renorm=renorm,
            #         renorm_clipping=renorm_clipping,
            #         renorm_momentum=renorm_decay,
            #         adjustment=adjustment,
            #         name=sc.name,
            #         _scope=sc,
            #         _reuse=reuse,
            #         fused=fused)
            #     outputs = layer.apply(inputs, training=is_training)
            #
            #     # Add variables to collections.
            #     _add_variable_to_collections(layer.moving_mean, variables_collections,
            #                                  'moving_mean')
            #     _add_variable_to_collections(layer.moving_variance, variables_collections,
            #                                  'moving_variance')
            #     if layer.beta is not None:
            #         _add_variable_to_collections(layer.beta, variables_collections, 'beta')
            #     if layer.gamma is not None:
            #         _add_variable_to_collections(layer.gamma, variables_collections,
            #                                      'gamma')
            #
            #     if activation_fn is not None:
            #         outputs = activation_fn(outputs)
            #     return utils.collect_named_outputs(outputs_collections, sc.name, outputs)

            # Not supported by layer class: batch_weights argument,
            # and custom updates_collections. In that case, use the legacy BN
            # implementation.
            # Custom updates collections are not supported because the update logic
            # is different in this case, in particular w.r.t. "forced updates" and
            # update op reuse.
            if renorm:
                raise ValueError('renorm is not supported with batch_weights, '
                                 'updates_collections or zero_debias_moving_mean')
            inputs_shape = inputs.get_shape()
            inputs_rank = inputs_shape.ndims
            if inputs_rank is None:
                raise ValueError('Inputs %s has undefined rank.' % inputs.name)
            dtype = inputs.dtype.base_dtype
            if batch_weights is not None:
                batch_weights = ops.convert_to_tensor(batch_weights)
                inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
                # Reshape batch weight values so they broadcast across inputs.
                nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
                batch_weights = array_ops.reshape(batch_weights, nshape)

            if data_format == DATA_FORMAT_NCHW:
                moments_axes = [0] + list(range(2, inputs_rank))
                params_shape = inputs_shape[1:2]
                # For NCHW format, rather than relying on implicit broadcasting, we
                # explicitly reshape the params to params_shape_broadcast when computing
                # the moments and the batch normalization.
                params_shape_broadcast = list(
                    [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
            else:
                moments_axes = list(range(inputs_rank - 1))
                params_shape = inputs_shape[-1:]
                params_shape_broadcast = None
            if not params_shape.is_fully_defined():
                raise ValueError('Inputs %s has undefined channels dimension %s.' %
                                 (inputs.name, params_shape))

            # Allocate parameters for the beta and gamma of the normalization.
            beta, gamma = None, None
            if not param_initializers:
                param_initializers = {}
            if center:
                beta_collections = utils.get_variable_collections(variables_collections,
                                                                  'beta')
                beta_initializer = param_initializers.get('beta',
                                                          init_ops.zeros_initializer())
                beta = variables.model_variable(
                    'beta',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=beta_initializer,
                    collections=beta_collections,
                    trainable=trainable)
            if scale:
                gamma_collections = utils.get_variable_collections(
                    variables_collections, 'gamma')
                gamma_initializer = param_initializers.get('gamma',
                                                           init_ops.ones_initializer())
                gamma = variables.model_variable(
                    'gamma',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=gamma_initializer,
                    collections=gamma_collections,
                    trainable=trainable)

            # Create moving_mean and moving_variance variables and add them to the
            # appropriate collections. We disable variable partitioning while creating
            # them, because assign_moving_average is not yet supported for partitioned
            # variables (this needs to be handled carefully, as it may break
            # the checkpoint backward compatibility).
            with variable_scope.variable_scope(
                    variable_scope.get_variable_scope()) as local_scope:
                local_scope.set_partitioner(None)
                moving_mean_collections = utils.get_variable_collections(
                    variables_collections, 'moving_mean')
                moving_mean_initializer = param_initializers.get(
                    'moving_mean', init_ops.zeros_initializer())
                moving_mean = variables.model_variable(
                    'moving_mean',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=moving_mean_initializer,
                    trainable=False,
                    collections=moving_mean_collections)
                moving_variance_collections = utils.get_variable_collections(
                    variables_collections, 'moving_variance')
                moving_variance_initializer = param_initializers.get(
                    'moving_variance', init_ops.ones_initializer())
                moving_variance = variables.model_variable(
                    'moving_variance',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=moving_variance_initializer,
                    trainable=False,
                    collections=moving_variance_collections)

            # If `is_training` doesn't have a constant value (because it is a `Tensor`,
            # a `Variable` or a `Placeholder`), then `is_training_value` will be None
            # and `need_moments` will be True.
            is_training_value = utils.constant_value(is_training)
            need_moments = is_training_value is None or is_training_value
            if need_moments:
                # Calculate the moments based on the individual batch.
                if batch_weights is None:
                    if data_format == DATA_FORMAT_NCHW:
                        mean, variance = moments(inputs, moments_axes, tower_config=tower_config, keep_dims=True)
                        mean = array_ops.reshape(mean, [-1])
                        variance = array_ops.reshape(variance, [-1])
                    else:
                        mean, variance = moments(inputs, moments_axes, tower_config=tower_config)
                else:
                    if data_format == DATA_FORMAT_NCHW:
                        mean, variance = weighted_moments(
                            inputs, moments_axes, batch_weights, tower_config, keep_dims=True)
                        mean = array_ops.reshape(mean, [-1])
                        variance = array_ops.reshape(variance, [-1])
                    else:
                        mean, variance = weighted_moments(
                            inputs, moments_axes, batch_weights,
                            tower_config=tower_config)

                moving_vars_fn = lambda: (moving_mean, moving_variance)
                if updates_collections is None:

                    def _force_updates():
                        """Internal function forces updates moving_vars if is_training."""
                        update_moving_mean = moving_averages.assign_moving_average(
                            moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
                        update_moving_variance = moving_averages.assign_moving_average(
                            moving_variance, variance, decay, zero_debias=False)
                        with ops.control_dependencies(
                                [update_moving_mean, update_moving_variance]):
                            return array_ops.identity(mean), array_ops.identity(variance)

                    mean, variance = utils.smart_cond(is_training, _force_updates,
                                                      moving_vars_fn)
                else:

                    def _delay_updates():
                        """Internal function that delay updates moving_vars if is_training."""
                        update_moving_mean = moving_averages.assign_moving_average(
                            moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
                        update_moving_variance = moving_averages.assign_moving_average(
                            moving_variance, variance, decay, zero_debias=False)
                        return update_moving_mean, update_moving_variance

                    update_mean, update_variance = utils.smart_cond(
                        is_training, _delay_updates, moving_vars_fn)
                    ops.add_to_collections(updates_collections, update_mean)
                    ops.add_to_collections(updates_collections, update_variance)
                    # Use computed moments during training and moving_vars otherwise.
                    vars_fn = lambda: (mean, variance)
                    mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
            else:
                mean, variance = moving_mean, moving_variance
            if data_format == DATA_FORMAT_NCHW:
                mean = array_ops.reshape(mean, params_shape_broadcast)
                variance = array_ops.reshape(variance, params_shape_broadcast)
                if beta is not None:
                    beta = array_ops.reshape(beta, params_shape_broadcast)
                if gamma is not None:
                    gamma = array_ops.reshape(gamma, params_shape_broadcast)

            # Compute batch_normalization.
            outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                             epsilon)
            outputs.set_shape(inputs_shape)
            if activation_fn is not None:
                outputs = activation_fn(outputs)
            return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
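For the `NCHW` path above, the per-channel statistics and parameters (shape `[C]`) are reshaped to `params_shape_broadcast` (`[1, C, 1, 1]`) so they broadcast against the `[N, C, H, W]` input. A small stand-alone sketch of that shape arithmetic:

import tensorflow as tf

x = tf.zeros([8, 64, 32, 32])                  # N, C, H, W
mean_c = tf.zeros([64])                        # per-channel statistic, shape [C]
mean_b = tf.reshape(mean_c, [1, 64, 1, 1])     # params_shape_broadcast for NCHW
centered = x - mean_b                          # broadcasts over N, H and W
print(centered.shape)                          # (8, 64, 32, 32)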
Example #37
  def _build_statistics_second_moment(self, input_batch,
                                      reduction_indices, use_batch_stats):
    """Builds the statistics part of the graph when using moving second moment.

    Args:
      input_batch: Input batch Tensor.
      reduction_indices: Indices of `input_batch` to reduce over.
      use_batch_stats: Boolean to indicate if batch statistics should be
        calculated, otherwise moving averages are returned.

    Returns:
      Tuple of (mean, variance, second_moment).
    """
    # Set up our moving statistics. When connecting in parallel, this is shared.
    self._moving_mean = tf.get_variable(
        "moving_mean",
        shape=self._mean_shape,
        collections=[tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
                     tf.GraphKeys.GLOBAL_VARIABLES],
        initializer=tf.zeros_initializer(),
        trainable=False)

    self._moving_second_moment = tf.get_variable(
        "moving_second_moment",
        shape=self._mean_shape,
        collections=[tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
                     tf.GraphKeys.GLOBAL_VARIABLES],
        initializer=tf.ones_initializer(),
        trainable=False)

    self._moving_variance = tf.subtract(self._moving_second_moment,
                                        tf.square(self._moving_mean),
                                        name="moving_variance")

    def build_batch_stats():
      """Builds the batch statistics calculation ops."""

      # Copy for better stability.
      # We use the moving mean as an estimate of the mean in order to perform
      # a more numerically stable calculation of the batch mean.
      shift = tf.add(self._moving_mean, 0)
      counts, shifted_sum_x, shifted_sum_x2, _ = tf.nn.sufficient_statistics(
          input_batch,
          reduction_indices,
          keep_dims=True,
          shift=shift,
          name="batch_norm_ss")

      mean, variance = tf.nn.normalize_moments(counts,
                                               shifted_sum_x,
                                               shifted_sum_x2,
                                               shift,
                                               name="normalize_moments")
      second_moment = variance + tf.square(mean)

      return mean, variance, second_moment

    def build_moving_stats():
      return (
          tf.identity(self._moving_mean),
          tf.identity(self._moving_variance),
          tf.identity(self._moving_second_moment),
      )

    mean, variance, second_moment = utils.smart_cond(
        use_batch_stats,
        build_batch_stats,
        build_moving_stats,
    )

    return mean, variance, second_moment
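The moving variance above is derived from the moving mean and moving second moment via Var[x] = E[x²] − (E[x])², which is what `tf.subtract(self._moving_second_moment, tf.square(self._moving_mean))` computes. A quick stand-alone check of that identity:

import numpy as np

x = np.random.randn(1000)
second_moment = np.mean(x ** 2)
mean = np.mean(x)
assert np.allclose(second_moment - mean ** 2, np.var(x))   # Var[x] = E[x^2] - E[x]^2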
Example #38
def fused_batch_norm(
        inputs,
        renorm=False,
        RMAX=None,
        DMAX=None,
        decay=0.999,
        center=True,
        scale=False,
        epsilon=0.001,
        activation_fn=None,
        param_initializers=None,
        is_training=True,
        reuse=None,
        variables_collections=None,
        outputs_collections=None,
        trainable=True,
        data_format=DATA_FORMAT_NHWC,
        zero_debias_moving_mean=False,
        scope=None):
    """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

        "Batch Normalization: Accelerating Deep Network Training by Reducing
        Internal Covariate Shift"

        Sergey Ioffe, Christian Szegedy

    Can be used as a normalizer function for conv2d and fully_connected.

    Note: When is_training is True the moving_mean and moving_variance need to be
    updated, by default the update_ops are placed in `tf.GraphKeys.UPDATE_OPS` so
    they need to be added as a dependency to the `train_op`, example:

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            updates = tf.group(*update_ops)
            total_loss = control_flow_ops.with_dependencies([updates], total_loss)

    Args:
        inputs: a tensor with 2 or more dimensions, where the first dimension has
        `batch_size`. The normalization is over all but the last dimension if
        `data_format` is `NHWC` and the second dimension if `data_format` is
        `NCHW`.
        decay: decay for the moving average. Reasonable values for `decay` are close
        to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
        Lower `decay` value (recommend trying `decay`=0.9) if model experiences
        reasonably good training performance but poor validation and/or test
        performance.
        center: If True, add offset of `beta` to normalized tensor.  If False,
        `beta` is ignored.
        scale: If True, multiply by `gamma`. If False, `gamma` is
        not used. When the next layer is linear (also e.g. `nn.relu`), this can be
        disabled since the scaling can be done by the next layer.
        epsilon: small float added to variance to avoid dividing by zero.
        activation_fn: activation function, default set to None to skip it and
        maintain a linear activation.
        param_initializers: optional initializers for beta, gamma, moving mean and
        moving variance.
        renorm: whether to apply Batch Renormalization to the training-mode
        outputs (https://arxiv.org/abs/1702.03275).
        RMAX: scalar `Tensor` bounding the renorm correction `r`, which is
        clipped to [1/RMAX, RMAX] during training.
        DMAX: scalar `Tensor` bounding the renorm correction `d`, which is
        clipped to [-DMAX, DMAX] during training.
        is_training: whether or not the layer is in training mode. In training mode
        it would accumulate the statistics of the moments into `moving_mean` and
        `moving_variance` using an exponential moving average with the given
        `decay`. When it is not in training mode then it would use the values of
        the `moving_mean` and the `moving_variance`.
        reuse: whether or not the layer and its variables should be reused. To be
        able to reuse the layer scope must be given.
        variables_collections: optional collections for the variables.
        outputs_collections: collections to add the outputs.
        trainable: If `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
        data_format: A string. `NHWC` (default) and `NCHW` are supported.
        zero_debias_moving_mean: Use zero_debias for moving_mean.
        scope: Optional scope for `variable_scope`.

    Returns:
        A `Tensor` representing the output of the operation.

    Raises:
        ValueError: if `data_format` is neither `NHWC` nor `NCHW`.
        ValueError: if the rank of `inputs` is undefined.
        ValueError: if the rank of `inputs` is neither 2 nor 4.
        ValueError: if rank or `C` dimension of `inputs` is undefined.
    """
    if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
        raise ValueError('data_format has to be either NCHW or NHWC.')
    with tf.variable_scope(
            scope, 'BatchNorm', [inputs], reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        original_shape = inputs.get_shape()
        original_rank = original_shape.ndims
        if original_rank is None:
            raise ValueError('Inputs %s has undefined rank' % inputs.name)
        elif original_rank not in [2, 4]:
            raise ValueError('Inputs %s has unsupported rank.'
                            ' Expected 2 or 4 but got %d' % (
                                inputs.name, original_rank))
        if original_rank == 2:
            channels = inputs.get_shape()[-1].value
            if channels is None:
                raise ValueError('`C` dimension must be known but is None')
            new_shape = [-1, 1, 1, channels]
            if data_format == DATA_FORMAT_NCHW:
                new_shape = [-1, channels, 1, 1]
            inputs = array_ops.reshape(inputs, new_shape)
        inputs_shape = inputs.get_shape()
        dtype = inputs.dtype.base_dtype
        if data_format == DATA_FORMAT_NHWC:
            params_shape = inputs_shape[-1:]
        else:
            params_shape = inputs_shape[1:2]
        if not params_shape.is_fully_defined():
            raise ValueError('Inputs %s has undefined `C` dimension %s.' %
                            (inputs.name, params_shape))

        if not param_initializers:
            param_initializers = {}
        # Allocate parameters for the beta and gamma of the normalization.
        trainable_beta = trainable and center
        if trainable_beta:
            beta_collections = utils.get_variable_collections(variables_collections,
                                                            'beta')
            beta_initializer = param_initializers.get('beta',
                                                    init_ops.zeros_initializer())
            real_beta = variables.model_variable(
                    'beta',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=beta_initializer,
                    collections=beta_collections,
                    trainable=trainable_beta)
            beta = tf.zeros(params_shape, name='fakebeta')
        else:
            real_beta = tf.zeros(params_shape, name='beta')
            beta = tf.zeros(params_shape, name='fakebeta')
        trainable_gamma = trainable and scale
        if trainable_gamma:
            gamma_collections = utils.get_variable_collections(variables_collections,
                                                            'gamma')
            gamma_initializer = param_initializers.get('gamma',
                                                    init_ops.ones_initializer())
            gamma = variables.model_variable(
                    'gamma',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=gamma_initializer,
                    collections=gamma_collections,
                    trainable=trainable_gamma)
        else:
            gamma = tf.ones(params_shape, name='gamma')

        # Create moving_mean and moving_variance variables and add them to the
        # appropriate collections.
        moving_mean_collections = utils.get_variable_collections(
                variables_collections, 'moving_mean')
        moving_mean_initializer = param_initializers.get(
                'moving_mean', init_ops.zeros_initializer())
        moving_mean = variables.model_variable(
                'moving_mean',
                shape=params_shape,
                dtype=dtype,
                initializer=moving_mean_initializer,
                trainable=False,
                collections=moving_mean_collections)
        moving_variance_collections = utils.get_variable_collections(
                variables_collections, 'moving_variance')
        moving_variance_initializer = param_initializers.get(
                'moving_variance', init_ops.ones_initializer())
        moving_variance = variables.model_variable(
                'moving_variance',
                shape=params_shape,
                dtype=dtype,
                initializer=moving_variance_initializer,
                trainable=False,
                collections=moving_variance_collections)

        def _fused_batch_norm_training():
            outputs, mean, variance = nn.fused_batch_norm(
                    inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
            if renorm:
                moving_inv = math_ops.rsqrt(moving_variance + epsilon)
                r = tf.stop_gradient(tf.clip_by_value(tf.sqrt(variance + epsilon) * moving_inv,
                                                        1/RMAX,
                                                        RMAX))
                d = tf.stop_gradient(tf.clip_by_value((mean - moving_mean) * moving_inv,
                                                        -DMAX,
                                                        DMAX))
                outputs = outputs * r + d
            return outputs, mean, variance
        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(
                    inputs,
                    gamma,
                    beta,
                    mean=moving_mean,
                    variance=moving_variance,
                    epsilon=epsilon,
                    is_training=False,
                    data_format=data_format)
        outputs, mean, variance = utils.smart_cond(is_training,
                                                _fused_batch_norm_training,
                                                _fused_batch_norm_inference)
        outputs = tf.nn.bias_add(outputs, real_beta)

        # If `is_training` doesn't have a constant value (because it is a `Tensor`,
        # a `Variable` or a `Placeholder`), then `is_training_value` will be None
        # and `need_updates` will be True.
        is_training_value = utils.constant_value(is_training)
        need_updates = is_training_value is None or is_training_value
        if need_updates:
            moving_vars_fn = lambda: (moving_mean, moving_variance)
            def _delay_updates():
                """Internal function that delay updates moving_vars if is_training."""
                update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
                update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, variance, decay, zero_debias=False)
                return update_moving_mean, update_moving_variance
            update_mean, update_variance = utils.smart_cond(is_training,
                                                            _delay_updates,
                                                            moving_vars_fn)
            ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_mean)
            ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_variance)

        outputs.set_shape(inputs_shape)
        if original_shape.ndims == 2:
            outputs = array_ops.reshape(outputs, original_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections,
                                        sc.original_name_scope, outputs)
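The renorm branch above clips `r` to `[1/RMAX, RMAX]` and `d` to `[-DMAX, DMAX]`. A hedged sketch of feeding those bounds with a warm-up ramp; the target values (rmax → 3, dmax → 5) follow the Batch Renormalization paper, and `ramp_steps`, `inputs` and `is_training_ph` are assumed names:

import tensorflow as tf

ramp_steps = 5000.0                                     # hypothetical warm-up length
global_step = tf.train.get_or_create_global_step()
t = tf.minimum(1.0, tf.cast(global_step, tf.float32) / ramp_steps)
rmax = 1.0 + 2.0 * t                                    # ramps 1 -> 3; r clipped to [1/rmax, rmax]
dmax = 5.0 * t                                          # ramps 0 -> 5
outputs = fused_batch_norm(inputs, renorm=True, RMAX=rmax, DMAX=dmax,
                           is_training=is_training_ph)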
Example #39
def spectral_normalization(weights,
                           num_iterations=1,
                           epsilon=1e-12,
                           u_initializer=tf.random_normal_initializer(),
                           updates_collections=tf.GraphKeys.UPDATE_OPS,
                           is_training=True,
                           reuse=None,
                           variables_collections=None,
                           outputs_collections=None,
                           scope=None):
    with tf.variable_scope(scope, 'SpectralNorm', [weights],
                           reuse=reuse) as sc:
        weights = tf.convert_to_tensor(weights)

        dtype = weights.dtype.base_dtype

        w_t = tf.reshape(weights, [-1, weights.shape.as_list()[-1]])
        w = tf.transpose(w_t)
        m, n = w.shape.as_list()

        u_collections = utils.get_variable_collections(variables_collections,
                                                       'u')
        u = tf.get_variable(
            "u",
            shape=[m, 1],
            dtype=dtype,
            initializer=u_initializer,
            trainable=False,
            collections=u_collections,
        )
        sigma_collections = utils.get_variable_collections(
            variables_collections, 'sigma')
        sigma = tf.get_variable('sigma',
                                shape=[],
                                dtype=dtype,
                                initializer=tf.zeros_initializer(),
                                trainable=False,
                                collections=sigma_collections)

        def _power_iteration(i, u, v):
            v_ = tf.nn.l2_normalize(tf.matmul(w_t, u), epsilon=epsilon)
            u_ = tf.nn.l2_normalize(tf.matmul(w, v_), epsilon=epsilon)
            return i + 1, u_, v_

        _, u_, v_ = tf.while_loop(cond=lambda i, _1, _2: i < num_iterations,
                                  body=_power_iteration,
                                  loop_vars=[
                                      tf.constant(0), u,
                                      tf.zeros(shape=[n, 1], dtype=tf.float32)
                                  ])
        u_ = tf.stop_gradient(u_)
        v_ = tf.stop_gradient(v_)
        sigma_ = tf.matmul(tf.transpose(u_), tf.matmul(w, v_))[0, 0]

        update_u = u.assign(u_)
        update_sigma = sigma.assign(sigma_)
        if updates_collections is None:

            def _force_update():
                with tf.control_dependencies([update_u, update_sigma]):
                    return tf.identity(sigma_)

            sigma_ = utils.smart_cond(is_training, _force_update,
                                      lambda: sigma)
            weights_sn = weights / sigma_
        else:
            sigma_ = utils.smart_cond(is_training, lambda: sigma_,
                                      lambda: sigma)
            weights_sn = weights / sigma_
            tf.add_to_collections(updates_collections, update_u)
            tf.add_to_collections(updates_collections, update_sigma)

        return utils.collect_named_outputs(outputs_collections, sc.name,
                                           weights_sn)
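A minimal usage sketch of the helper above, wrapping a convolution kernel in a discriminator; TF 1.x graph mode is assumed, and `x` (an NHWC feature map with 64 channels) plus an `is_training` bool placeholder are assumed to exist:

import tensorflow as tf

with tf.variable_scope('disc_conv1'):
    w = tf.get_variable('w', [3, 3, 64, 128],
                        initializer=tf.glorot_uniform_initializer())
    w_sn = spectral_normalization(w, num_iterations=1, is_training=is_training)
    y = tf.nn.conv2d(x, w_sn, strides=[1, 1, 1, 1], padding='SAME')
# With the default updates_collections, the power-iteration assignments end up in
# tf.GraphKeys.UPDATE_OPS and should be run together with the train op.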
Example #40
    def _build_statistics(self, input_batch, axis, use_batch_stats, dtype):
        """Builds the statistics part of the graph when using moving variance.

        Args:
          input_batch: Input batch Tensor.
          axis: Indices of `input_batch` to reduce over.
          use_batch_stats: Boolean to indicate if batch statistics should be
            calculated, otherwise moving averages are returned.
          dtype: TensorFlow datatype to use for the moving mean and variance.

        Returns:
          Tuple of (mean, variance).
        """
        # Set up our moving statistics. When connecting in parallel, this is shared.
        if self.MOVING_MEAN not in self._initializers:
            self._initializers[self.MOVING_MEAN] = create_mean_initializer()
        self._moving_mean = tf.get_variable(
            "moving_mean",
            dtype=dtype,
            shape=self._mean_shape,
            collections=[
                tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
                tf.GraphKeys.GLOBAL_VARIABLES,
            ],
            initializer=self._initializers[self.MOVING_MEAN],
            trainable=False)

        if self.MOVING_VARIANCE not in self._initializers:
            self._initializers[
                self.MOVING_VARIANCE] = create_variance_initializer()
        self._moving_variance = tf.get_variable(
            "moving_variance",
            dtype=dtype,
            shape=self._mean_shape,
            collections=[
                tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
                tf.GraphKeys.GLOBAL_VARIABLES,
            ],
            initializer=self._initializers[self.MOVING_VARIANCE],
            trainable=False)

        def build_batch_stats():
            """Builds the batch statistics calculation ops."""
            mean, variance = tf.nn.moments(input_batch,
                                           axis,
                                           keep_dims=True,
                                           name="normalize_moments")

            return mean, variance

        def build_moving_stats():
            return (
                tf.identity(self._moving_mean),
                tf.identity(self._moving_variance),
            )

        mean, variance = utils.smart_cond(
            use_batch_stats,
            build_batch_stats,
            build_moving_stats,
        )

        return mean, variance