Example #1
    def test_row_wise_sparse_adagrad_empty(self, inputs, lr, epsilon,
                                           data_strategy, gc, dc):
        param = inputs[0]
        lr = np.array([lr], dtype=np.float32)

        momentum = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0],
                        max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32)))
        momentum = np.abs(momentum)

        grad = np.empty(shape=(0, ) + param.shape[1:], dtype=np.float32)
        indices = np.empty(shape=(0, ), dtype=np.int64)

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        op = core.CreateOperator(
            "RowWiseSparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_row_wise_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            return (param_out, momentum_out)

        self.assertReferenceChecks(gc, op,
                                   [param, momentum, indices, grad, lr],
                                   ref_row_wise_sparse)
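
The min_len == max_len pattern above pins the draw to a vector whose length exactly matches param's row count. A minimal standalone sketch of that behavior (assuming Caffe2's hypothesis_test_util is importable as hu, as in these snippets; the test name is hypothetical):

import numpy as np
from hypothesis import given, strategies as st
import caffe2.python.hypothesis_test_util as hu

@given(data=st.data())
def check_fixed_len_draw(data):
    # min_len == max_len forces an exact-length float32 draw.
    vec = data.draw(hu.tensor1d(min_len=4, max_len=4,
                                elements=hu.elements_of_type(dtype=np.float32)))
    assert vec.shape == (4,)
    assert vec.dtype == np.float32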
Example #2
class TestUniqueOps(hu.HypothesisTestCase):
    @given(
        X=hu.tensor1d(
            # allow empty
            min_len=0,
            dtype=np.int32,
            # allow negatives
            elements=st.integers(min_value=-10, max_value=10)),
        return_remapping=st.booleans(),
        **hu.gcs)
    def test_unique_op(self, X, return_remapping, gc, dc):
        # The Unique op implementation does not guarantee output order; sort
        # the input so different implementations return the same outputs.
        X = np.sort(X)

        op = core.CreateOperator(
            "Unique",
            ['X'],
            ["U", "remap"] if return_remapping else ["U"],
        )
        self.assertDeviceChecks(
            device_options=dc,
            op=op,
            inputs=[X],
            outputs_to_check=[0, 1] if return_remapping else [0])
        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[X],
            reference=partial(_unique_ref, return_inverse=return_remapping),
        )
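
_unique_ref is used via functools.partial above but not shown in this snippet. A plausible sketch of such a reference (an assumption, not the actual helper) built on np.unique, whose sorted output matches the pre-sorted input:

import numpy as np

def _unique_ref(x, return_inverse=False):
    # np.unique returns the sorted unique values and, optionally, the
    # mapping from every input element back into the unique array.
    if return_inverse:
        unique, remap = np.unique(x, return_inverse=True)
        return (unique, remap)
    return (np.unique(x),)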
Example #3
    def test_row_wise_sparse_adagrad_empty(self, inputs, lr, epsilon,
                                           data_strategy, gc, dc):
        param = inputs[0]
        lr = np.array([lr], dtype=np.float32)

        momentum = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32))
        )
        momentum = np.abs(momentum)

        grad = np.empty(shape=(0,) + param.shape[1:], dtype=np.float32)
        indices = np.empty(shape=(0,), dtype=np.int64)

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        op = core.CreateOperator(
            "RowWiseSparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_row_wise_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            return (param_out, momentum_out)

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, indices, grad, lr],
            ref_row_wise_sparse)
Example #4
    def test_row_wise_sparse_adagrad(self, inputs, lr, epsilon, data_strategy,
                                     gc, dc):
        param, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        # Create a 1D row-wise average sum of squared gradients tensor.
        momentum = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0],
                        max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32)))
        momentum = np.abs(momentum)

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(dtype=np.int64,
                      elements=st.sampled_from(np.arange(grad.shape[0]))))

        # Note that unlike SparseAdagrad, RowWiseSparseAdagrad uses a moment
        # tensor that is strictly 1-dimensional and equal in length to the
        # first dimension of the parameters, so indices must also be
        # 1-dimensional.
        indices = indices.flatten()

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        # The indices must be unique
        hypothesis.assume(np.array_equal(np.unique(indices), np.sort(indices)))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "RowWiseSparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_row_wise_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            for i, index in enumerate(indices):
                param_out[index], momentum_out[index] = self.ref_row_wise_adagrad(
                    param[index], momentum[index], grad[i], lr, epsilon)
            return (param_out, momentum_out)

        self.assertReferenceChecks(gc, op,
                                   [param, momentum, indices, grad, lr],
                                   ref_row_wise_sparse)
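
self.ref_row_wise_adagrad is defined elsewhere in the test class. A plausible per-row sketch (an assumption based on the standard row-wise Adagrad update, which keeps one scalar moment per row fed by the mean squared gradient):

import numpy as np

def ref_row_wise_adagrad(param, momentum, grad, lr, epsilon):
    # Accumulate the mean of the squared gradient into the scalar moment,
    # then scale the whole row's update by the same adaptive rate.
    momentum_out = momentum + np.mean(np.square(grad))
    param_out = param + lr * grad / (np.sqrt(momentum_out) + epsilon)
    return param_out, momentum_out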
Example #5
    def test_row_wise_sparse_adagrad(self, inputs, lr, epsilon,
                                     data_strategy, gc, dc):
        param, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        # Create a 1D row-wise average sum of squared gradients tensor.
        momentum = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32))
        )
        momentum = np.abs(momentum)

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(dtype=np.int64,
                      elements=st.sampled_from(np.arange(grad.shape[0]))),
        )

        # Note that unlike SparseAdagrad, RowWiseSparseAdagrad uses a moment
        # tensor that is strictly 1-dimensional and equal in length to the
        # first dimension of the parameters, so indices must also be
        # 1-dimensional.
        indices = indices.flatten()

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        # The indices must be unique
        hypothesis.assume(np.array_equal(np.unique(indices), np.sort(indices)))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "RowWiseSparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_row_wise_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            for i, index in enumerate(indices):
                param_out[index], momentum_out[index] = self.ref_row_wise_adagrad(
                    param[index], momentum[index], grad[i], lr, epsilon)
            return (param_out, momentum_out)

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, indices, grad, lr],
            ref_row_wise_sparse)
Example #6
class TestBooleanMaskOp(serial.SerializedTestCase):
    @given(x=hu.tensor1d(min_len=1,
                         max_len=100,
                         elements=st.floats(min_value=0.5, max_value=1.0)),
           **hu.gcs_cpu_only)
    def test_boolean_mask_gradient(self, x, gc, dc):
        op = core.CreateOperator("BooleanMask", ["data", "mask"],
                                 "masked_data")
        mask = np.random.choice(a=[True, False], size=x.shape[0])
        expected_gradient = np.copy(mask).astype(int)
        self.assertDeviceChecks(dc, op, [x, mask], [0])
        self.assertGradientChecks(gc, op, [x, mask], 0, [0])

    @given(x=hu.tensor1d(min_len=1,
                         max_len=5,
                         elements=st.floats(min_value=0.5, max_value=1.0)),
           **hu.gcs)
    def test_boolean_mask(self, x, gc, dc):
        op = core.CreateOperator("BooleanMask", ["data", "mask"],
                                 "masked_data")
        mask = np.random.choice(a=[True, False], size=x.shape[0])

        def ref(x, mask):
            return (x[mask], )

        self.assertReferenceChecks(gc, op, [x, mask], ref)
        self.assertDeviceChecks(dc, op, [x, mask], [0])

    @given(x=hu.tensor1d(min_len=1,
                         max_len=5,
                         elements=st.floats(min_value=0.5, max_value=1.0)),
           **hu.gcs)
    def test_boolean_mask_indices(self, x, gc, dc):
        op = core.CreateOperator("BooleanMask", ["data", "mask"],
                                 ["masked_data", "masked_indices"])
        mask = np.random.choice(a=[True, False], size=x.shape[0])

        def ref(x, mask):
            return (x[mask], np.where(mask)[0])

        self.assertReferenceChecks(gc, op, [x, mask], ref)
        self.assertDeviceChecks(dc, op, [x, mask], [0])

    @staticmethod
    def _dtype_conversion(x, dtype, gc, dc):
        """SequenceMask only supports fp16 with CUDA/ROCm."""
        if dtype == np.float16:
            assume(core.IsGPUDeviceType(gc.device_type))
            dc = [d for d in dc if core.IsGPUDeviceType(d.device_type)]
            x = x.astype(dtype)
        return x, dc

    @given(x=hu.tensor(min_dim=2,
                       max_dim=5,
                       elements=st.floats(min_value=0.5, max_value=1.0)),
           dtype=st.sampled_from([np.float32, np.float16]),
           **hu.gcs)
    def test_sequence_mask_with_lengths(self, x, dtype, gc, dc):
        x, dc = self._dtype_conversion(x, dtype, gc, dc)
        # finite fill value needed for gradient check
        fill_val = 1e-3 if dtype == np.float16 else 1e-9
        op = core.CreateOperator("SequenceMask", ["data", "lengths"],
                                 ["masked_data"],
                                 mode="sequence",
                                 axis=len(x.shape) - 1,
                                 fill_val=fill_val)
        elem_dim = x.shape[-1]
        leading_dim = 1
        for dim in x.shape[:-1]:
            leading_dim *= dim
        lengths = np.random.randint(0, elem_dim, [leading_dim]).astype(np.int32)

        def ref(x, lengths):
            ref = np.reshape(x, [leading_dim, elem_dim])
            for i in range(leading_dim):
                for j in range(elem_dim):
                    if j >= lengths[i]:
                        ref[i, j] = fill_val
            return [ref.reshape(x.shape)]

        self.assertReferenceChecks(gc, op, [x, lengths], ref)
        self.assertDeviceChecks(dc, op, [x, lengths], [0])
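
To make the sequence-mode semantics concrete, a standalone numpy sketch with hypothetical values: row i keeps its first lengths[i] entries, and every position j >= lengths[i] is replaced by fill_val, exactly as in the reference above.

import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)
lengths = np.array([1, 3, 0], dtype=np.int32)
fill_val = 1e-9

masked = x.copy()
for i in range(x.shape[0]):
    masked[i, lengths[i]:] = fill_val  # fills all positions j >= lengths[i]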

    @given(x=hu.tensor(min_dim=2,
                       max_dim=5,
                       elements=st.floats(min_value=0.5, max_value=1.0)),
           dtype=st.sampled_from([np.float32, np.float16]),
           **hu.gcs)
    def test_sequence_mask_with_window(self, x, dtype, gc, dc):
        x, dc = self._dtype_conversion(x, dtype, gc, dc)
        # finite fill value needed for gradient check
        fill_val = 1e-3 if dtype == np.float16 else 1e-9
        radius = 2
        op = core.CreateOperator("SequenceMask", ["data", "centers"],
                                 ["masked_data"],
                                 mode="window",
                                 radius=radius,
                                 axis=len(x.shape) - 1,
                                 fill_val=fill_val)
        elem_dim = x.shape[-1]
        leading_dim = 1
        for dim in x.shape[:-1]:
            leading_dim *= dim
        centers = np.random.randint(0, elem_dim, [leading_dim]).astype(np.int32)

        def ref(x, centers):
            ref = np.reshape(x, [leading_dim, elem_dim])
            for i in range(leading_dim):
                for j in range(elem_dim):
                    if j > centers[i] + radius or j < centers[i] - radius:
                        ref[i, j] = fill_val
            return [ref.reshape(x.shape)]

        self.assertReferenceChecks(gc, op, [x, centers], ref)
        self.assertDeviceChecks(dc, op, [x, centers], [0])

        # The gradient check with np.float16 has proven flaky, so effectively
        # disable it for now with a high threshold (to repro, set threshold
        # to 0.4).
        threshold = 1.0 if dtype == np.float16 else 0.005
        self.assertGradientChecks(gc,
                                  op, [x, centers],
                                  0, [0],
                                  threshold=threshold)

    @given(x=hu.tensor(min_dim=2,
                       max_dim=5,
                       elements=st.floats(min_value=0.5, max_value=1.0)),
           mode=st.sampled_from(['upper', 'lower', 'upperdiag', 'lowerdiag']),
           dtype=st.sampled_from([np.float32, np.float16]),
           **hu.gcs)
    def test_sequence_mask_triangle(self, x, mode, dtype, gc, dc):
        x, dc = self._dtype_conversion(x, dtype, gc, dc)
        # finite fill value needed for gradient check
        fill_val = 1e-3 if dtype == np.float16 else 1e-9
        op = core.CreateOperator("SequenceMask", ["data"], ["masked_data"],
                                 mode=mode,
                                 axis=len(x.shape) - 1,
                                 fill_val=fill_val)
        elem_dim = x.shape[-1]
        leading_dim = 1
        for dim in x.shape[:-1]:
            leading_dim *= dim

        if mode == 'upper':
            def compare(i, j):
                return j > i
        elif mode == 'lower':
            def compare(i, j):
                return j < i
        elif mode == 'upperdiag':
            def compare(i, j):
                return j >= i
        elif mode == 'lowerdiag':
            def compare(i, j):
                return j <= i

        def ref(x):
            ref = np.reshape(x, [leading_dim, elem_dim])
            for i in range(leading_dim):
                for j in range(elem_dim):
                    if compare(i, j):
                        ref[i, j] = fill_val
            return [ref.reshape(x.shape)]

        self.assertReferenceChecks(gc, op, [x], ref)
        self.assertDeviceChecks(dc, op, [x], [0])

        # The gradient check with np.float16 has proven flaky, so effectively
        # disable it for now with a high threshold (to repro, set threshold
        # to 0.4).
        threshold = 1.0 if dtype == np.float16 else 0.005
        stepsize = 0.1 if dtype == np.float16 else 0.05
        self.assertGradientChecks(gc,
                                  op, [x],
                                  0, [0],
                                  threshold=threshold,
                                  stepsize=stepsize)
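
The four triangle modes map directly onto numpy's triangular helpers; for example, mode 'upper' fills every position with j > i, which is the same as keeping the lower triangle. A small sketch under that reading (hypothetical values):

import numpy as np

m = np.ones((3, 4), dtype=np.float32)
fill_val = 1e-9
i, j = np.indices(m.shape)
upper = np.where(j > i, fill_val, m)  # mode='upper' fills above the diagonal
keep_lower = np.where(np.tril(np.ones_like(m, dtype=bool)), m, fill_val)
assert np.array_equal(upper, keep_lower)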

    @given(x=hu.tensor(min_dim=2,
                       max_dim=5,
                       elements=st.floats(min_value=0.5, max_value=1.0)),
           dtype=st.sampled_from([np.float32, np.float16]),
           **hu.gcs)
    def test_sequence_mask_batching_lengths(self, x, dtype, gc, dc):
        x, dc = self._dtype_conversion(x, dtype, gc, dc)
        # finite fill value needed for gradient check
        fill_val = 1e-3 if dtype == np.float16 else 1e-9
        # Choose distinct batch and axis dimensions, with axis != 0 and
        # batch < axis.
        axis = 0
        batch = 0
        while axis == 0 or axis < batch:
            inds = np.arange(len(x.shape))
            np.random.shuffle(inds)
            batch = inds[0]
            axis = inds[1]
        op = core.CreateOperator("SequenceMask", ["data", "lengths"],
                                 ["masked_data"],
                                 mode='sequence',
                                 axis=axis,
                                 fill_val=fill_val,
                                 batch=batch)

        before = int(np.prod(x.shape[:batch + 1]))
        between = int(np.prod(x.shape[batch + 1:axis]))
        after = int(np.prod(x.shape[axis:]))

        lengths = np.random.randint(0, after, [between]).astype(np.int32)

        def ref(z, lengths):
            w = np.reshape(z, [before, between, after])

            for b in range(before):
                r = w[b, :, :]
                for i in range(between):
                    for j in range(after):
                        if j >= lengths[i]:
                            r[i, j] = fill_val
            return [w.reshape(z.shape)]

        self.assertReferenceChecks(gc, op, [x, lengths], ref)
        self.assertDeviceChecks(dc, op, [x, lengths], [0])

        # The gradient check with np.float16 has proven flaky, so effectively
        # disable it for now with a high threshold (to repro, set threshold
        # to 0.4).
        threshold = 1.0 if dtype == np.float16 else 0.005
        self.assertGradientChecks(gc,
                                  op, [x, lengths],
                                  0, [0],
                                  threshold=threshold)
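
The before/between/after factorization above flattens x into a rank-3 view around the batch and axis dimensions. Concrete arithmetic for a hypothetical shape:

import numpy as np

shape = (2, 3, 4, 5)  # hypothetical x.shape
batch, axis = 1, 3    # axis != 0 and batch < axis, as chosen above
before = int(np.prod(shape[:batch + 1]))       # 2 * 3 = 6, up to and including batch
between = int(np.prod(shape[batch + 1:axis]))  # 4, strictly between batch and axis
after = int(np.prod(shape[axis:]))             # 5, from axis onward
assert before * between * after == np.prod(shape)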

    @given(x=hu.tensor(min_dim=4,
                       max_dim=4,
                       elements=st.floats(min_value=0.5, max_value=1.0)),
           dtype=st.sampled_from([np.float32, np.float16]),
           **hu.gcs)
    def test_sequence_mask_batching_window(self, x, dtype, gc, dc):
        x, dc = self._dtype_conversion(x, dtype, gc, dc)
        # finite fill value needed for gradient check
        fill_val = 1e-3 if dtype == np.float16 else 1e-9
        radius = 1
        # Choose distinct batch and axis dimensions, with axis != 0 and
        # batch < axis.
        axis = 0
        batch = 0
        while axis == 0 or axis < batch:
            inds = np.arange(len(x.shape))
            np.random.shuffle(inds)
            batch = inds[0]
            axis = inds[1]
        op = core.CreateOperator("SequenceMask", ["data", "centers"],
                                 ["masked_data"],
                                 mode='window',
                                 radius=radius,
                                 axis=axis,
                                 fill_val=fill_val,
                                 batch=batch)

        before = int(np.prod(x.shape[:batch + 1]))
        between = int(np.prod(x.shape[batch + 1:axis]))
        after = int(np.prod(x.shape[axis:]))

        centers = np.random.randint(0, after, [between]).astype(np.int32)

        def ref(z, c):
            w = np.reshape(z, [before, between, after])

            for b in range(before):
                r = w[b, :, :]
                for i in range(between):
                    for j in range(after):
                        if j > c[i] + radius or j < c[i] - radius:
                            r[i, j] = fill_val
            return [w.reshape(z.shape)]

        self.assertReferenceChecks(gc, op, [x, centers], ref)
        self.assertDeviceChecks(dc, op, [x, centers], [0])

        # The gradient check with np.float16 has proven flaky, so effectively
        # disable it for now with a high threshold (to repro, set threshold
        # to 0.4).
        threshold = 1.0 if dtype == np.float16 else 0.005
        self.assertGradientChecks(gc,
                                  op, [x, centers],
                                  0, [0],
                                  threshold=threshold)

    @given(x=hu.tensor(min_dim=3,
                       max_dim=5,
                       elements=st.floats(min_value=0.5, max_value=1.0)),
           mode=st.sampled_from(['upper', 'lower', 'upperdiag', 'lowerdiag']),
           dtype=st.sampled_from([np.float32, np.float16]),
           **hu.gcs)
    def test_sequence_mask_batching_triangle(self, x, mode, dtype, gc, dc):
        x, dc = self._dtype_conversion(x, dtype, gc, dc)
        # finite fill value needed for gradient check
        fill_val = 1e-3 if dtype == np.float16 else 1e-9
        # Choose distinct batch and axis dimensions, with axis != 0 and
        # batch < axis.
        axis = 0
        batch = 0
        while axis == 0 or axis < batch:
            inds = np.arange(len(x.shape))
            np.random.shuffle(inds)
            batch = inds[0]
            axis = inds[1]
        op = core.CreateOperator("SequenceMask", ["data"], ["masked_data"],
                                 mode=mode,
                                 axis=axis,
                                 fill_val=fill_val,
                                 batch=batch)

        if mode == 'upper':
            def compare(i, j):
                return j > i
        elif mode == 'lower':
            def compare(i, j):
                return j < i
        elif mode == 'upperdiag':
            def compare(i, j):
                return j >= i
        elif mode == 'lowerdiag':
            def compare(i, j):
                return j <= i

        def ref(z):
            before = int(np.prod(z.shape[:batch + 1]))
            between = int(np.prod(z.shape[batch + 1:axis]))
            after = int(np.prod(z.shape[axis:]))

            w = np.reshape(z, [before, between, after])

            for b in range(before):
                r = w[b, :, :]
                for i in range(between):
                    for j in range(after):
                        if compare(i, j):
                            r[i, j] = fill_val
            return [w.reshape(z.shape)]

        self.assertReferenceChecks(gc, op, [x], ref)
        self.assertDeviceChecks(dc, op, [x], [0])

        # The gradient check with np.float16 has proven flaky, so effectively
        # disable it for now with a high threshold (to repro, set threshold
        # to 0.4).
        threshold = 1.0 if dtype == np.float16 else 0.005
        stepsize = 0.1 if dtype == np.float16 else 0.05
        self.assertGradientChecks(gc,
                                  op, [x],
                                  0, [0],
                                  threshold=threshold,
                                  stepsize=stepsize)

    @given(x=hu.tensor(min_dim=3,
                       max_dim=5,
                       elements=st.floats(min_value=0.5, max_value=1.0)),
           dtype=st.sampled_from([np.float32, np.float16]),
           **hu.gcs)
    def test_sequence_mask_repeated(self, x, dtype, gc, dc):
        x, dc = self._dtype_conversion(x, dtype, gc, dc)
        # finite fill value needed for gradient check
        fill_val = 1e-3 if dtype == np.float16 else 1e-9
        op = core.CreateOperator("SequenceMask", ["data", "lengths"],
                                 ["masked_data"],
                                 mode="sequence",
                                 axis=len(x.shape) - 2,
                                 repeat_from_axis=-1,
                                 fill_val=fill_val)

        elem_dim = x.shape[-2]
        leading_dim = 1
        for dim in x.shape[:-2]:
            leading_dim *= dim
        lengths = np.random.randint(0, elem_dim, [leading_dim]).astype(np.int32)

        def ref(x, lengths):
            ref = np.reshape(x, [leading_dim, elem_dim, -1])
            for i in range(leading_dim):
                for j in range(elem_dim):
                    if j >= lengths[i]:
                        ref[i, j, :] = fill_val
            return [ref.reshape(x.shape)]

        self.assertReferenceChecks(gc, op, [x, lengths], ref)
        self.assertDeviceChecks(dc, op, [x, lengths], [0])
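
With repeat_from_axis=-1 the mask computed along the chosen axis is repeated over the trailing dimension, which the reference expresses by filling whole [i, j, :] slices. A standalone numpy sketch with hypothetical values:

import numpy as np

x = np.ones((2, 3, 4), dtype=np.float32)  # mask along axis 1, repeat over axis 2
lengths = np.array([1, 2], dtype=np.int32)
fill_val = 1e-9

masked = x.copy()
for i in range(x.shape[0]):
    masked[i, lengths[i]:, :] = fill_val  # the entire trailing slice is filled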
Example #7
    def test_row_wise_sparse_adam(self, inputs, ITER, LR, beta1, beta2,
                                  epsilon, data_strategy, gc, dc):
        param, mom1, grad = inputs
        ITER = np.array([ITER], dtype=np.int64)
        LR = np.array([LR], dtype=np.float32)

        # Create a 1D row-wise average 2nd moment tensor.
        mom2 = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0],
                        max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32)))
        mom2 = np.absolute(mom2)

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(
                max_dim=1,
                min_value=1,
                max_value=grad.shape[0],
                dtype=np.int64,
                elements=st.sampled_from(np.arange(grad.shape[0])),
            ),
        )

        # Note that unlike SparseAdam, RowWiseSparseAdam uses a moment
        # tensor that is strictly 1-dimensional and equal in length to the
        # first dimension of the parameters, so indices must also be
        # 1-dimensional.
        indices = indices.flatten()

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        # Verify that the generated indices are unique
        hypothesis.assume(np.array_equal(np.unique(indices), np.sort(indices)))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "RowWiseSparseAdam",
            ["param", "mom1", "mom2", "indices", "grad", "lr", "iter"],
            ["param", "mom1", "mom2"],
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon)

        def ref_row_wise_sparse(param, mom1, mom2, indices, grad, LR, ITER):
            param_out = np.copy(param)
            mom1_out = np.copy(mom1)
            mom2_out = np.copy(mom2)
            for i, index in enumerate(indices):
                param_out[index], mom1_out[index], mom2_out[index] = \
                    self.ref_row_wise_adam(param[index], mom1[index], mom2[index],
                                           grad[i], LR, ITER,
                                           beta1, beta2, epsilon)
            return (param_out, mom1_out, mom2_out)

        # Iter lives on the CPU
        input_device_options = {'iter': hu.cpu_do}

        self.assertReferenceChecks(
            gc,
            op, [param, mom1, mom2, indices, grad, LR, ITER],
            ref_row_wise_sparse,
            input_device_options=input_device_options)
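
self.ref_row_wise_adam is likewise defined elsewhere in the test class. A plausible per-row sketch (an assumption based on the standard bias-corrected Adam update, with a single scalar second moment per row):

import numpy as np

def ref_row_wise_adam(param, mom1, mom2, grad, lr, iters, beta1, beta2, epsilon):
    t = iters + 1
    # Fold the bias correction of both moments into the step size.
    corrected_lr = lr * np.sqrt(1.0 - np.power(beta2, t)) / (1.0 - np.power(beta1, t))
    mom1_out = beta1 * mom1 + (1.0 - beta1) * grad
    # Row-wise: one scalar second moment, fed by the mean squared gradient.
    mom2_out = beta2 * mom2 + (1.0 - beta2) * np.mean(np.square(grad))
    param_out = param + corrected_lr * mom1_out / (np.sqrt(mom2_out) + epsilon)
    return param_out, mom1_out, mom2_out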
Example #8
    def test_row_wise_sparse_adam(self, inputs, ITER, LR, beta1, beta2, epsilon,
                                  data_strategy, gc, dc):
        param, mom1, grad = inputs
        ITER = np.array([ITER], dtype=np.int64)
        LR = np.array([LR], dtype=np.float32)

        # Create a 1D row-wise average 2nd moment tensor.
        mom2 = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32))
        )
        mom2 = np.absolute(mom2)

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(
                max_dim=1,
                min_value=1,
                max_value=grad.shape[0],
                dtype=np.int64,
                elements=st.sampled_from(np.arange(grad.shape[0])),
            ),
        )

        # Note that unlike SparseAdam, RowWiseSparseAdam uses a moment
        # tensor that is strictly 1-dimensional and equal in length to the
        # first dimension of the parameters, so indices must also be
        # 1-dimensional.
        indices = indices.flatten()

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        # Verify that the generated indices are unique
        hypothesis.assume(np.array_equal(np.unique(indices), np.sort(indices)))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "RowWiseSparseAdam",
            ["param", "mom1", "mom2", "indices", "grad", "lr", "iter"],
            ["param", "mom1", "mom2"],
            beta1=beta1, beta2=beta2, epsilon=epsilon)

        def ref_row_wise_sparse(param, mom1, mom2, indices, grad, LR, ITER):
            param_out = np.copy(param)
            mom1_out = np.copy(mom1)
            mom2_out = np.copy(mom2)
            for i, index in enumerate(indices):
                param_out[index], mom1_out[index], mom2_out[index] = \
                    self.ref_row_wise_adam(param[index], mom1[index], mom2[index],
                                           grad[i], LR, ITER,
                                           beta1, beta2, epsilon)
            return (param_out, mom1_out, mom2_out)

        # Iter lives on the CPU
        input_device_options = {'iter': hu.cpu_do}

        self.assertReferenceChecks(
            gc, op,
            [param, mom1, mom2, indices, grad, LR, ITER],
            ref_row_wise_sparse,
            input_device_options=input_device_options)