def testShakeShake(self):
    """Check shakeshake on five references to one tensor returns that tensor.

    A random 5x7 array is wrapped in a float32 constant, passed five times
    to `common_layers.shakeshake`, and the evaluated result is compared to
    the evaluated input with `assertAllClose`.
    """
    values = np.random.rand(5, 7)
    with self.test_session() as session:
        tensor = tf.constant(values, dtype=tf.float32)
        # The same tensor object is supplied for every branch.
        mixed = common_layers.shakeshake([tensor] * 5)
        session.run(tf.global_variables_initializer())
        inp, res = session.run([tensor, mixed])
    self.assertAllClose(res, inp)
def model_fn_body_sharded(self, sharded_features, train):
    """Run the stacked module layers over the concatenated input shards.

    Args:
        sharded_features: mapping whose "inputs" entry is a list of tensors
            (one per shard); they are concatenated along axis 0.
        train: forwarded unchanged to `run_modules`.

    Returns:
        A 2-tuple: the final tensor split back along axis 0 into as many
        pieces as there were input shards (as a list), and the constant 0.0
        (presumably an extra-loss placeholder -- confirm against the caller).
    """
    data_parallelism = self._data_parallelism
    data_parallelism._reuse = False  # pylint:disable=protected-access
    hparams = self._hparams

    # Candidate modules applied at every layer; their outputs are combined
    # by shakeshake below.
    module_fns = [
        identity_module,
        norm_module,
        residual_module1,
        residual_module1_sep,
        residual_module2,
        residual_module2_sep,
        residual_module3,
        residual_module3_sep,
    ]

    inputs = sharded_features["inputs"]
    num_shards = len(inputs)

    merged = tf.concat(inputs, axis=0)
    static_shape = merged.get_shape()

    for layer_idx in xrange(hparams.num_hidden_layers):
        with tf.variable_scope("layer_%d" % layer_idx):
            branch_outputs = run_modules(
                module_fns, merged, hparams, train, data_parallelism)
            merged = common_layers.shakeshake(branch_outputs)
            # Re-attach the static shape captured before the loop.
            merged.set_shape(static_shape)

    shard_outputs = list(tf.split(merged, num_shards, axis=0))
    return shard_outputs, 0.0