def testShakeShake(self):
    """Checks that shake-shake over identical branches returns that branch."""
    data = np.random.rand(5, 7)
    with self.test_session() as session:
        branch = tf.constant(data, dtype=tf.float32)
        # Five references to the same tensor: any per-branch mixing of
        # identical inputs must yield the input itself.
        mixed = common_layers.shakeshake([branch] * 5)
        session.run(tf.global_variables_initializer())
        branch_val, mixed_val = session.run([branch, mixed])
    self.assertAllClose(mixed_val, branch_val)
    def model_fn_body_sharded(self, sharded_features, train):
        """Builds the shake-shake model body over all shards at once.

        The per-shard inputs are concatenated along axis 0. The result goes
        through `hparams.num_hidden_layers` layers. Each layer runs every module
        in `blocks` via `run_modules` and combines their outputs with
        `common_layers.shakeshake`. The result is then split back into
        per-shard outputs.

        Args:
          sharded_features: dict whose "inputs" entry is a list of per-shard
            tensors. Assumed to agree on every axis except axis 0 (required
            by `tf.concat`) -- TODO confirm against caller.
          train: flag forwarded unchanged to `run_modules`.

        Returns:
          A tuple `(outputs, 0.0)`. `outputs` is a list with one tensor per
          input shard. The `0.0` is presumably an extra-loss term expected by
          the caller -- confirm.
        """
        dp = self._data_parallelism
        # NOTE(review): resets the data-parallelism helper's variable-reuse
        # flag before building; exact semantics live in that class -- confirm.
        dp._reuse = False  # pylint:disable=protected-access
        hparams = self._hparams
        blocks = [
            identity_module, norm_module, residual_module1,
            residual_module1_sep, residual_module2, residual_module2_sep,
            residual_module3, residual_module3_sep
        ]
        inputs = sharded_features["inputs"]

        cur = tf.concat(inputs, axis=0)
        # Remember the static shape of the concatenated input so it can be
        # re-asserted after each layer (presumably because the shakeshake
        # output does not carry a fully static shape).
        cur_shape = cur.get_shape()
        for i in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % i):
                processed = run_modules(blocks, cur, hparams, train, dp)
                cur = common_layers.shakeshake(processed)
                cur.set_shape(cur_shape)

        # Split back using each shard's actual batch size rather than
        # len(inputs) equal pieces, so shards of different sizes get their
        # own rows back. When all shards are the same size this is identical
        # to an even split.
        shard_sizes = tf.stack([tf.shape(shard)[0] for shard in inputs])
        return list(tf.split(cur, shard_sizes, axis=0)), 0.0