def test_take_grad(executor_kind):
    """Gradient-check relay.take, both along an axis and over the flattened input."""
    data_ty = relay.TensorType((3, 4, 5), "float64")
    data = relay.var("data", data_ty)
    indices = relay.var("indices", relay.TensorType((relay.Any(),), "int32"))
    inputs = [
        _np_randn_from_type(data_ty, scale=1e-5),
        np.array([1, 2], dtype="int32"),
    ]
    # Only the float data input is gradient-checked; the integer indices
    # input is passed through but excluded from test_inputs.
    grad_inputs = inputs[:1]

    # Same check for take on a specific axis and on the flattened tensor.
    for axis in (1, None):
        fwd_func = relay.Function(
            [data, indices], relay.take(data, indices, axis=axis)
        )
        check_grad(
            fwd_func,
            inputs=inputs,
            test_inputs=grad_inputs,
            executor_kind=executor_kind,
        )
def test_where_grad():
    """Gradient-check relay.where with broadcasting lhs/rhs shapes."""
    cond_ty = relay.TensorType((2, 3, 4), "int32")
    lhs_ty = relay.TensorType((1, 3, 4), "float32")
    rhs_ty = relay.TensorType((2, 1, 4), "float32")

    cond_np = np.random.randint(2, size=cond_ty.concrete_shape, dtype=cond_ty.dtype)
    lhs_np = _np_randn_from_type(lhs_ty, scale=1e-5)
    rhs_np = _np_randn_from_type(rhs_ty, scale=1e-5)

    cond = relay.var("cond", type_annotation=cond_ty)
    lhs = relay.var("lhs", type_annotation=lhs_ty)
    rhs = relay.var("rhs", type_annotation=rhs_ty)
    fwd_func = relay.Function([cond, lhs, rhs], relay.where(cond, lhs, rhs))
    # cond is integer-valued, so only lhs and rhs are gradient-checked.
    check_grad(fwd_func, inputs=[cond_np, lhs_np, rhs_np], test_inputs=[lhs_np, rhs_np])