def testConvShape(self, channels, filter_shape, padding, strides,
                  input_shape):
  init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
                                  padding=padding)
  _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
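# `_CheckShapeAgreement` is referenced but not defined in this excerpt. A
# minimal sketch of such a helper, assuming init_fun returns
# (output_shape, params) as in jax.example_libraries.stax; it handles the
# single-shape case only (fan-in layers would need list handling):
import numpy as np
from jax import random

def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
  key = random.PRNGKey(0)
  # init_fun reports the output shape the layer claims it will produce.
  result_shape, params = init_fun(key, input_shape)
  # Apply the layer to random inputs and check the actual shape matches.
  inputs = np.random.RandomState(0).randn(*input_shape).astype(np.float32)
  result = apply_fun(params, inputs, rng=key)
  test_case.assertEqual(result.shape, result_shape)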
def test_conv(self):
  order = 3
  input_shape = (1, 5, 5, 1)
  key = random.PRNGKey(0)
  # TODO(duvenaud): Check all types of padding
  init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
  _, (W, b) = init_fun(key, input_shape)

  rng = self.rng()
  x = rng.randn(*input_shape)
  primals = (W, b, x)

  series_in1 = [rng.randn(*W.shape) for _ in range(order)]
  series_in2 = [rng.randn(*b.shape) for _ in range(order)]
  series_in3 = [rng.randn(*x.shape) for _ in range(order)]
  series_in = (series_in1, series_in2, series_in3)

  def f(W, b, x):
    return apply_fun((W, b), x)

  self.check_jet(f, primals, series_in, check_dtypes=False)
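# `check_jet` (defined in the test harness) compares Taylor-mode derivatives
# from jax.experimental.jet against derivatives computed by repeated
# differentiation. A minimal sketch of calling `jet` directly, using a
# scalar function for simplicity; the expansion point and coefficients here
# are illustrative:
import jax.numpy as jnp
from jax.experimental import jet

# Expand sin around 1.0 with an order-2 input series whose coefficients are
# (1.0, 0.0); jet returns the primal output and the higher-order Taylor
# coefficients of the output series.
primals_out, series_out = jet.jet(jnp.sin, (1.0,), ((1.0, 0.0),))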
def cnn(conv_depth=300, kernel_size=5, n_conv_layers=2, across_batch=False,
        add_pos_encoding=False):
  """Build convolutional neural net."""
  # Input shape: [batch x length x depth]
  if across_batch:
    extra_dim = 0
  else:
    extra_dim = 1
  layers = [ExpandDims(axis=extra_dim)]
  if add_pos_encoding:
    layers.append(positional_encoding())
  for _ in range(n_conv_layers):
    layers.append(
        stax.Conv(conv_depth, (1, kernel_size), padding="same",
                  strides=(1, 1)))
    layers.append(stax.Relu)
  layers.append(AssertNonZeroShape())
  layers.append(squeeze_layer(axis=extra_dim))
  return stax.serial(*layers)
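# A minimal usage sketch for `cnn`, assuming the helper layers used above
# (ExpandDims, positional_encoding, AssertNonZeroShape, squeeze_layer) are
# defined in this module and inputs follow the documented
# [batch x length x depth] layout. The shapes below are illustrative:
import jax.numpy as jnp
from jax import random

init_fun, apply_fun = cnn(conv_depth=32, kernel_size=3, n_conv_layers=1)
out_shape, params = init_fun(random.PRNGKey(0), (8, 100, 16))
out = apply_fun(params, jnp.zeros((8, 100, 16)))
# Length is preserved by the 1 x kernel_size "same" convolution; depth
# becomes conv_depth, so out.shape should be (8, 100, 32).
assert out.shape == out_shape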
    'train with vanilla SGD.')
flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1,
                   'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer('microbatches', None,
                     'Number of microbatches (must evenly divide batch_size)')
flags.DEFINE_string('model_dir', None, 'Model directory')

init_random_params, predict = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(10),
)

def loss(params, batch):
  inputs, targets = batch
  logits = predict(params, inputs)
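# The body of `loss` is truncated above. A plausible completion for a
# softmax cross-entropy loss over one-hot `targets` (a sketch, not
# necessarily the original example's exact code):
import jax.numpy as jnp
from jax import nn

def loss(params, batch):
  inputs, targets = batch
  logits = predict(params, inputs)
  logits = nn.log_softmax(logits)                      # log-normalize
  return -jnp.mean(jnp.sum(logits * targets, axis=1))  # cross entropy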
class StaxTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_shape={shape}",
          "shape": shape
      } for shape in [(2, 3), (5,)]))
  def testRandnInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.randn()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_shape={shape}",
          "shape": shape
      } for shape in [(2, 3), (2, 3, 4)]))
  def testGlorotInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.glorot()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
              .format(channels, filter_shape, padding, strides, input_shape),
          "channels": channels,
          "filter_shape": filter_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape
      } for channels in [2, 3]
        for filter_shape in [(1, 1), (2, 3)]
        for padding in ["SAME", "VALID"]
        for strides in [None, (2, 1)]
        for input_shape in [(2, 10, 11, 1)]))
  def testConvShape(self, channels, filter_shape, padding, strides,
                    input_shape):
    init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
                                    padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
              .format(channels, filter_shape, padding, strides, input_shape),
          "channels": channels,
          "filter_shape": filter_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape
      } for channels in [2, 3]
        for filter_shape in [(1, 1), (2, 3), (3, 3)]
        for padding in ["SAME", "VALID"]
        for strides in [None, (2, 1), (2, 2)]
        for input_shape in [(2, 10, 11, 1)]))
  def testConvTransposeShape(self, channels, filter_shape, padding, strides,
                             input_shape):
    init_fun, apply_fun = stax.ConvTranspose(channels, filter_shape,  # 2D
                                             strides=strides, padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
              .format(channels, filter_shape, padding, strides, input_shape),
          "channels": channels,
          "filter_shape": filter_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape
      } for channels in [2, 3]
        for filter_shape in [(1,), (2,), (3,)]
        for padding in ["SAME", "VALID"]
        for strides in [None, (1,), (2,)]
        for input_shape in [(2, 10, 1)]))
  def testConv1DTransposeShape(self, channels, filter_shape, padding, strides,
                               input_shape):
    init_fun, apply_fun = stax.Conv1DTranspose(channels, filter_shape,
                                               strides=strides,
                                               padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_out_dim={}_input_shape={}".format(
              out_dim, input_shape),
          "out_dim": out_dim,
          "input_shape": input_shape
      } for out_dim in [3, 4]
        for input_shape in [(2, 3), (3, 4)]))
  def testDenseShape(self, out_dim, input_shape):
    init_fun, apply_fun = stax.Dense(out_dim)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_input_shape={}_nonlinear={}".format(
              input_shape, nonlinear),
          "input_shape": input_shape,
          "nonlinear": nonlinear
      } for input_shape in [(2, 3), (2, 3, 4)]
        for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
  def testNonlinearShape(self, input_shape, nonlinear):
    init_fun, apply_fun = getattr(stax, nonlinear)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_window_shape={}_padding={}_strides={}_input_shape={}"
              "_maxpool={}_spec={}".format(window_shape, padding, strides,
                                           input_shape, max_pool, spec),
          "window_shape": window_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape,
          "max_pool": max_pool,
          "spec": spec
      } for window_shape in [(1, 1), (2, 3)]
        for padding in ["VALID"]
        for strides in [None, (2, 1)]
        for input_shape in [(2, 5, 6, 4)]
        for max_pool in [False, True]
        for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
  def testPoolingShape(self, window_shape, padding, strides, input_shape,
                       max_pool, spec):
    layer = stax.MaxPool if max_pool else stax.AvgPool
    init_fun, apply_fun = layer(window_shape, padding=padding,
                                strides=strides, spec=spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_shape={input_shape}",
          "input_shape": input_shape
      } for input_shape in [(2, 3), (2, 3, 4)]))
  def testFlattenShape(self, input_shape):
    init_fun, apply_fun = stax.Flatten
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_input_shape={input_shape}_spec={i}",
          "input_shape": input_shape,
          "spec": spec
      } for input_shape in [(2, 5, 6, 1)]
        for i, spec in enumerate([
            [stax.Conv(3, (2, 2))],
            [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]
        ])))
  def testSerialComposeLayersShape(self, input_shape, spec):
    init_fun, apply_fun = stax.serial(*spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_input_shape={input_shape}",
          "input_shape": input_shape
      } for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testDropoutShape(self, input_shape):
    init_fun, apply_fun = stax.Dropout(0.9)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_input_shape={input_shape}",
          "input_shape": input_shape
      } for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testFanInSum(self, input_shape):
    init_fun, apply_fun = stax.FanInSum
    _CheckShapeAgreement(self, init_fun, apply_fun,
                         [input_shape, input_shape])

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_inshapes={input_shapes}_axis={axis}",
          "input_shapes": input_shapes,
          "axis": axis
      } for input_shapes, axis in [
          ([(2, 3), (2, 1)], 1),
          ([(2, 3), (2, 1)], -1),
          ([(1, 2, 4), (1, 1, 4)], 1),
      ]))
  def testFanInConcat(self, input_shapes, axis):
    init_fun, apply_fun = stax.FanInConcat(axis)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

  def testIssue182(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.Softmax
    input_shape = (10, 3)
    inputs = np.arange(30.).astype("float32").reshape(input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    assert out_shape == out.shape
    assert np.allclose(np.sum(np.asarray(out), -1), 1.)
  def testBatchNormNoScaleOrCenter(self):
    key = random.PRNGKey(0)
    axes = (0, 1, 2)
    init_fun, apply_fun = stax.BatchNorm(axis=axes, center=False, scale=False)
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    means = np.mean(out, axis=(0, 1, 2))
    std_devs = np.std(out, axis=(0, 1, 2))
    assert np.allclose(means, np.zeros_like(means), atol=1e-4)
    assert np.allclose(std_devs, np.ones_like(std_devs), atol=1e-4)

  def testBatchNormShapeNHWC(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (7,))
    self.assertEqual(gamma.shape, (7,))
    self.assertEqual(out_shape, out.shape)

  def testBatchNormShapeNCHW(self):
    key = random.PRNGKey(0)
    # Regression test for https://github.com/google/jax/issues/461
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (5,))
    self.assertEqual(gamma.shape, (5,))
    self.assertEqual(out_shape, out.shape)
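# `random_inputs` is used throughout but not defined in this excerpt. A
# minimal sketch consistent with its use here: a NumPy RandomState in, a
# float32 array of the given shape out, with lists of shapes supported for
# fan-in layers:
import numpy as np

def random_inputs(rng, input_shape):
  if isinstance(input_shape, tuple):
    return rng.randn(*input_shape).astype(np.float32)
  elif isinstance(input_shape, list):
    return [random_inputs(rng, shape) for shape in input_shape]
  else:
    raise TypeError(f"cannot generate random inputs for {type(input_shape)}")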