def test_take_grad(executor_kind):
    """Gradient-check ``relay.take`` along axis 1 and on the flattened input.

    ``indices`` is declared with a dynamic (``relay.Any()``) length and is fed
    the concrete int32 array ``[1, 2]``. Only the float64 ``data`` tensor is
    passed as ``test_inputs``, presumably because integer indices have no
    meaningful gradient (assumption about ``check_grad`` semantics).

    Args:
        executor_kind: forwarded unchanged to ``check_grad``.
    """
    data_ty = relay.TensorType((3, 4, 5), "float64")
    data_var = relay.var("data", data_ty)
    idx_var = relay.var("indices", relay.TensorType((relay.Any(),), "int32"))

    data_np = _np_randn_from_type(data_ty, scale=1e-5)
    idx_np = np.array([1, 2], dtype="int32")
    all_inputs = [data_np, idx_np]
    grad_inputs = [data_np]

    # axis=1 takes along a single dimension; axis=None takes from the
    # flattened tensor. Both variants reuse the same inputs.
    for take_axis in (1, None):
        take_func = relay.Function(
            [data_var, idx_var], relay.take(data_var, idx_var, axis=take_axis)
        )
        check_grad(
            take_func,
            inputs=all_inputs,
            test_inputs=grad_inputs,
            executor_kind=executor_kind,
        )
def test_where_grad():
    """Gradient-check ``relay.where`` with broadcasting value operands.

    ``cond`` is an int32 (2, 3, 4) tensor of random 0/1 values. ``lhs``
    (1, 3, 4) and ``rhs`` (2, 1, 4) are small float32 tensors that broadcast
    against it. Only ``lhs`` and ``rhs`` are passed as ``test_inputs``; the
    integer condition is not included.
    """
    cond_ty = relay.TensorType((2, 3, 4), "int32")
    lhs_ty = relay.TensorType((1, 3, 4), "float32")
    rhs_ty = relay.TensorType((2, 1, 4), "float32")

    cond_var = relay.var("cond", type_annotation=cond_ty)
    lhs_var = relay.var("lhs", type_annotation=lhs_ty)
    rhs_var = relay.var("rhs", type_annotation=rhs_ty)
    where_func = relay.Function(
        [cond_var, lhs_var, rhs_var], relay.where(cond_var, lhs_var, rhs_var)
    )

    # Random draws happen in the same order as before: condition first, then
    # lhs, then rhs.
    cond_np = np.random.randint(2, size=cond_ty.concrete_shape, dtype=cond_ty.dtype)
    lhs_np = _np_randn_from_type(lhs_ty, scale=1e-5)
    rhs_np = _np_randn_from_type(rhs_ty, scale=1e-5)

    value_inputs = [lhs_np, rhs_np]
    check_grad(where_func, inputs=[cond_np] + value_inputs, test_inputs=value_inputs)