Example #1
0
 def testConvShape(self, channels, filter_shape, padding, strides,
                   input_shape):
     """Build a `stax.Conv` layer and pass it to `_CheckShapeAgreement`."""
     conv_layer = stax.Conv(channels, filter_shape,
                            strides=strides, padding=padding)
     # A stax layer is an (init_fun, apply_fun) pair.
     init_fun, apply_fun = conv_layer
     _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
Example #2
0
  def test_conv(self):
    """Run `check_jet` on a VALID-padded `stax.Conv` layer.

    The layer's weights, bias and input are all treated as primals, and each
    one gets `num_terms` random series coefficients of matching shape.
    """
    num_terms = 3
    in_shape = (1, 5, 5, 1)
    # TODO(duvenaud): Check all types of padding
    init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
    _, (W, b) = init_fun(random.PRNGKey(0), in_shape)

    rng = self.rng()
    x = rng.randn(*in_shape)

    def random_series(like):
      # One random coefficient array per series term, shaped like `like`.
      return [rng.randn(*like.shape) for _ in range(num_terms)]

    # Draw order (W, then b, then x) matters for reproducibility of the rng.
    series_in = (random_series(W), random_series(b), random_series(x))

    def conv_apply(W, b, x):
      return apply_fun((W, b), x)

    self.check_jet(conv_apply, (W, b, x), series_in, check_dtypes=False)
Example #3
0
def cnn(conv_depth=300,
        kernel_size=5,
        n_conv_layers=2,
        across_batch=False,
        add_pos_encoding=False):
    """Assemble a stack of convolution + ReLU layers as one stax model.

    Args:
        conv_depth: first argument handed to every `stax.Conv` layer.
        kernel_size: second entry of the `(1, kernel_size)` filter shape.
        n_conv_layers: how many Conv/Relu pairs to stack.
        across_batch: when true the extra axis is inserted (and later
            squeezed) at position 0, otherwise at position 1.
        add_pos_encoding: when true, a `positional_encoding()` layer is
            placed right after the initial `ExpandDims`.

    Returns:
        Whatever `stax.serial` returns for the assembled layer list.
    """
    # Input shape: [batch x length x depth]
    extra_dim = 0 if across_batch else 1

    stack = [ExpandDims(axis=extra_dim)]
    if add_pos_encoding:
        stack.append(positional_encoding())

    for _ in range(n_conv_layers):
        conv = stax.Conv(conv_depth, (1, kernel_size),
                         padding="same",
                         strides=(1, 1))
        stack.extend([conv, stax.Relu])

    stack += [AssertNonZeroShape(), squeeze_layer(axis=extra_dim)]
    return stax.serial(*stack)
    'train with vanilla SGD.')
# Floating-point hyperparameters.
flags.DEFINE_float(
    'learning_rate', default=.15, help='Learning rate for training')
flags.DEFINE_float(
    'noise_multiplier', default=1.1,
    help='Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', default=1.0, help='Clipping norm')

# Integer settings.
flags.DEFINE_integer('batch_size', default=256, help='Batch size')
flags.DEFINE_integer('epochs', default=60, help='Number of epochs')
flags.DEFINE_integer('seed', default=0, help='Seed for jax PRNG')
flags.DEFINE_integer(
    'microbatches', default=None,
    help='Number of microbatches '
    '(must evenly divide batch_size)')

# No default is provided for the model directory.
flags.DEFINE_string('model_dir', default=None, help='Model directory')


# Two Conv -> Relu -> MaxPool stages, then Flatten and two Dense layers
# (32 units, then 10 units). `stax.serial` is unpacked into the
# (init_random_params, predict) pair.
init_random_params, predict = stax.serial(*[
    # Stage 1: 16 channels, 8x8 filters.
    stax.Conv(16, (8, 8), strides=(2, 2), padding='SAME'),
    stax.Relu,
    stax.MaxPool((2, 2), strides=(1, 1)),
    # Stage 2: 32 channels, 4x4 filters.
    stax.Conv(32, (4, 4), strides=(2, 2), padding='VALID'),
    stax.Relu,
    stax.MaxPool((2, 2), strides=(1, 1)),
    # Head.
    stax.Flatten,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(10),
])


def loss(params, batch):
  inputs, targets = batch
  logits = predict(params, inputs)
Example #5
0
class StaxTest(jtu.JaxTestCase):
    """Shape and sanity tests for `stax` initializers and layers.

    Most tests construct a layer's `(init_fun, apply_fun)` pair and hand it,
    together with an input shape, to the module-level helper
    `_CheckShapeAgreement`. The remaining tests check specific numeric or
    shape properties directly.
    """

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_shape={shape}",
            "shape": shape
        } for shape in [(2, 3), (5, )]))
    def testRandnInitShape(self, shape):
        """`stax.randn()` yields an array of the requested shape."""
        key = random.PRNGKey(0)
        out = stax.randn()(key, shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_shape={shape}",
            "shape": shape
        } for shape in [(2, 3), (2, 3, 4)]))
    def testGlorotInitShape(self, shape):
        """`stax.glorot()` yields an array of the requested shape."""
        key = random.PRNGKey(0)
        out = stax.glorot()(key, shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_channels={channels}_filter_shape={filter_shape}"
            f"_padding={padding}_strides={strides}"
            f"_input_shape={input_shape}",
            "channels":
            channels,
            "filter_shape":
            filter_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape
        } for channels in [2, 3] for filter_shape in [(1, 1), (2, 3)]
                            for padding in ["SAME", "VALID"]
                            for strides in [None, (2, 1)]
                            for input_shape in [(2, 10, 11, 1)]))
    def testConvShape(self, channels, filter_shape, padding, strides,
                      input_shape):
        """Run `_CheckShapeAgreement` on a `stax.Conv` layer."""
        init_fun, apply_fun = stax.Conv(channels,
                                        filter_shape,
                                        strides=strides,
                                        padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_channels={channels}_filter_shape={filter_shape}"
            f"_padding={padding}_strides={strides}"
            f"_input_shape={input_shape}",
            "channels":
            channels,
            "filter_shape":
            filter_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape
        } for channels in [2, 3] for filter_shape in [(1, 1), (2, 3), (3, 3)]
                            for padding in ["SAME", "VALID"]
                            for strides in [None, (2, 1), (2, 2)]
                            for input_shape in [(2, 10, 11, 1)]))
    def testConvTransposeShape(self, channels, filter_shape, padding, strides,
                               input_shape):
        """Run `_CheckShapeAgreement` on a `stax.ConvTranspose` layer."""
        init_fun, apply_fun = stax.ConvTranspose(
            channels,
            filter_shape,  # 2D
            strides=strides,
            padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_channels={channels}_filter_shape={filter_shape}"
            f"_padding={padding}_strides={strides}"
            f"_input_shape={input_shape}",
            "channels":
            channels,
            "filter_shape":
            filter_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape
        } for channels in [2, 3] for filter_shape in [(1, ), (2, ), (3, )]
                            for padding in ["SAME", "VALID"]
                            for strides in [None, (1, ), (2, )]
                            for input_shape in [(2, 10, 1)]))
    def testConv1DTransposeShape(self, channels, filter_shape, padding,
                                 strides, input_shape):
        """Run `_CheckShapeAgreement` on a `stax.Conv1DTranspose` layer."""
        init_fun, apply_fun = stax.Conv1DTranspose(channels,
                                                   filter_shape,
                                                   strides=strides,
                                                   padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_out_dim={out_dim}_input_shape={input_shape}",
            "out_dim":
            out_dim,
            "input_shape":
            input_shape
        } for out_dim in [3, 4] for input_shape in [(2, 3), (3, 4)]))
    def testDenseShape(self, out_dim, input_shape):
        """Run `_CheckShapeAgreement` on a `stax.Dense` layer."""
        init_fun, apply_fun = stax.Dense(out_dim)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_input_shape={input_shape}_nonlinear={nonlinear}",
                "input_shape":
                input_shape,
                "nonlinear":
                nonlinear
            } for input_shape in [(2, 3), (2, 3, 4)]
            for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
    def testNonlinearShape(self, input_shape, nonlinear):
        """Run `_CheckShapeAgreement` on the stax layer named `nonlinear`."""
        # The layer is looked up by attribute name on the stax module.
        init_fun, apply_fun = getattr(stax, nonlinear)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_window_shape={window_shape}_padding={padding}"
            f"_strides={strides}_input_shape={input_shape}"
            f"_maxpool={max_pool}_spec={spec}",
            "window_shape":
            window_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape,
            "max_pool":
            max_pool,
            "spec":
            spec
        } for window_shape in [(1, 1), (2, 3)] for padding in ["VALID"]
                            for strides in [None, (2, 1)]
                            for input_shape in [(2, 5, 6, 4)]
                            for max_pool in [False, True]
                            for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
    def testPoolingShape(self, window_shape, padding, strides, input_shape,
                         max_pool, spec):
        """Run `_CheckShapeAgreement` on `stax.MaxPool` or `stax.AvgPool`."""
        layer = stax.MaxPool if max_pool else stax.AvgPool
        init_fun, apply_fun = layer(window_shape,
                                    padding=padding,
                                    strides=strides,
                                    spec=spec)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_shape={input_shape}",
            "input_shape": input_shape
        } for input_shape in [(2, 3), (2, 3, 4)]))
    def testFlattenShape(self, input_shape):
        """Run `_CheckShapeAgreement` on `stax.Flatten`."""
        init_fun, apply_fun = stax.Flatten
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                # The spec's index (not the spec itself) goes into the name.
                "testcase_name": f"_input_shape={input_shape}_spec={i}",
                "input_shape": input_shape,
                "spec": spec
            } for input_shape in [(2, 5, 6, 1)]
            for i, spec in enumerate([
                [stax.Conv(3, (2, 2))],
                [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)],
            ])))
    def testSerialComposeLayersShape(self, input_shape, spec):
        """Run `_CheckShapeAgreement` on `stax.serial(*spec)`."""
        init_fun, apply_fun = stax.serial(*spec)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_input_shape={input_shape}",
            "input_shape": input_shape
        } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testDropoutShape(self, input_shape):
        """Run `_CheckShapeAgreement` on `stax.Dropout(0.9)`."""
        init_fun, apply_fun = stax.Dropout(0.9)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_input_shape={input_shape}",
            "input_shape": input_shape
        } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testFanInSum(self, input_shape):
        """Run `_CheckShapeAgreement` on `stax.FanInSum` with two inputs."""
        init_fun, apply_fun = stax.FanInSum
        # The same shape is supplied twice, once per fan-in branch.
        _CheckShapeAgreement(self, init_fun, apply_fun,
                             [input_shape, input_shape])

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_inshapes={input_shapes}_axis={axis}",
            "input_shapes": input_shapes,
            "axis": axis
        } for input_shapes, axis in [
            ([(2, 3), (2, 1)], 1),
            ([(2, 3), (2, 1)], -1),
            ([(1, 2, 4), (1, 1, 4)], 1),
        ]))
    def testFanInConcat(self, input_shapes, axis):
        """Run `_CheckShapeAgreement` on `stax.FanInConcat(axis)`."""
        init_fun, apply_fun = stax.FanInConcat(axis)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

    def testIssue182(self):
        """`stax.Softmax` output matches the init shape and sums to 1."""
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.Softmax
        input_shape = (10, 3)
        inputs = np.arange(30.).astype("float32").reshape(input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        # unittest assertions (not bare `assert`) so the checks survive -O.
        self.assertEqual(out_shape, out.shape)
        self.assertTrue(np.allclose(np.sum(np.asarray(out), -1), 1.))

    def testBatchNormNoScaleOrCenter(self):
        """BatchNorm without center/scale gives ~0 mean and ~1 std."""
        key = random.PRNGKey(0)
        axes = (0, 1, 2)
        init_fun, apply_fun = stax.BatchNorm(axis=axes,
                                             center=False,
                                             scale=False)
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(self.rng(), input_shape)

        # Only the params are needed here; the reported shape is unused.
        _, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        # Statistics are taken over the same axes that were normalized.
        means = np.mean(out, axis=axes)
        std_devs = np.std(out, axis=axes)
        self.assertTrue(
            np.allclose(means, np.zeros_like(means), atol=1e-4))
        self.assertTrue(
            np.allclose(std_devs, np.ones_like(std_devs), atol=1e-4))

    def testBatchNormShapeNHWC(self):
        """BatchNorm over axes (0, 1, 2): params have shape (7,)."""
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(self.rng(), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (7, ))
        self.assertEqual(gamma.shape, (7, ))
        self.assertEqual(out_shape, out.shape)

    def testBatchNormShapeNCHW(self):
        """BatchNorm over axes (0, 2, 3): params have shape (5,)."""
        key = random.PRNGKey(0)
        # Regression test for https://github.com/google/jax/issues/461
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(self.rng(), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (5, ))
        self.assertEqual(gamma.shape, (5, ))
        self.assertEqual(out_shape, out.shape)