_make_harness("conv_general_dilated", "",
                  # 2-D convolution with NCHW/OIHW/NCHW layouts, no padding,
                  # window strides (2, 3) and kernel (rhs) dilation (1, 2).
                  lambda lhs, rhs: lax.conv_general_dilated(lhs, rhs,
                                                            window_strides=(2, 3),
                                                            padding=((0, 0), (0, 0)),
                                                            lhs_dilation=(1, 1),
                                                            rhs_dilation=(1, 2),
                                                            dimension_numbers=("NCHW", "OIHW", "NCHW"),
                                                            feature_group_count=1,
                                                            batch_group_count=1,
                                                            precision=None),
                  [RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)],
                  # NOTE(review): poly_axes presumably lists, per argument, the
                  # axis treated as shape-polymorphic (None = fully static).
                  poly_axes=[0, None]),

    # Running maximum along axis 1 of a (3, 4, 5) input, scanning forward.
    _make_harness("cummax", "",
                  lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
                  [RandArg((3, 4, 5), _f32)],
                  poly_axes=[0]),

    # Batched contraction: dim 0 is a batch dim on both operands, and lhs
    # dim 2 is contracted against rhs dim 1; for these inputs the result
    # has shape (3, 4).
    _make_harness("dot_general", "",
                  lambda lhs, rhs: lax.dot_general(lhs, rhs,
                                                   dimension_numbers=(((2,), (1,)), ((0,), (0,)))),
                  [RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)],
                  poly_axes=[0, 0]),

    _make_harness("dynamic_slice", "",
                  # x has shape (b, 4): keep all b rows and the 2 columns
                  # starting at column 1 (start indices (0, 1)).
                  lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),
Example #2
0
 def f_jax(x):
     """Return the running maximum of `x` along axis 0, scanning forward."""
     running_max = lax_control_flow.cummax(x, reverse=False, axis=0)
     return running_max