def create_convert_shape_node(var_shape_node,
                              slice_node=None,
                              in_control_flow=False):
    """Build an expression node that reads a shape via ``convert_var_shape``.

    Args:
        var_shape_node (gast.Attribute | gast.Subscript): the shape access
            being rewritten (e.g. an attribute access, optionally indexed).
        slice_node (gast.Subscript, optional): the subscript that wraps an
            Attribute ``var_shape_node``. It is filled in by the recursive
            call when ``var_shape_node`` is itself a Subscript; a value passed
            alongside a Subscript ``var_shape_node`` is ignored.
        in_control_flow (bool): forwarded verbatim as the
            ``in_control_flow=`` keyword of the generated call.

    Returns:
        gast.AST: an expression for
        ``paddle.jit.dy2static.convert_var_shape(<obj>[,<idx>], in_control_flow=...)``.
        A slice that ``slice_is_num`` accepts (e.g. ``1`` or ``-2``) is passed
        as the extra positional argument. Any other slice (e.g. ``2:-1``) is
        applied by subscripting the call's result instead.
    """
    assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))

    if isinstance(var_shape_node, gast.Subscript):
        # Work on a copy so the returned tree never shares the slice node
        # with the caller's AST. The copied Subscript doubles as the slice
        # carrier for the Attribute case below.
        copied = copy.deepcopy(var_shape_node)
        return create_convert_shape_node(copied.value, copied,
                                         in_control_flow)

    if isinstance(var_shape_node, gast.Attribute):
        call_args = [ast_to_source_code(var_shape_node.value).strip()]
        numeric_index = slice_node is not None and slice_is_num(slice_node)
        if numeric_index:
            call_args.append(ast_to_source_code(slice_node.slice).strip())

        call_code = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
            ",".join(call_args), in_control_flow)
        call_node = gast.parse(call_code).body[0].value

        if slice_node is None or numeric_index:
            return call_node
        # Range-style slice: apply it to the result of the call.
        return gast.Subscript(
            value=call_node, slice=slice_node.slice, ctx=gast.Load())


def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
    """Build an expression node calling ``choose_shape_attr_or_api``.

    Args:
        attr_shape_name (str): source text inserted verbatim as the first
            argument of the generated call.
        api_shape_name (str): name embedded as a quoted string literal in
            ``eval_if_exist_else_none('<name>', globals())``, which becomes
            the second argument.
        slice_node (gast.Subscript, optional): subscript whose ``.slice`` is
            applied to the shape. A slice accepted by ``slice_is_num`` is
            passed as a third positional argument. Any other slice subscripts
            the call's result.

    Returns:
        gast.AST: the generated call expression, possibly wrapped in a
        ``gast.Subscript``.
    """
    lookup_code = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
        api_shape_name)
    numeric_index = slice_node is not None and slice_is_num(slice_node)

    call_args = [attr_shape_name, lookup_code]
    if numeric_index:
        call_args.append(ast_to_source_code(slice_node.slice).strip())

    call_code = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
        ",".join(call_args))
    call_node = gast.parse(call_code).body[0].value

    if slice_node is None or numeric_index:
        return call_node
    # Range-style slice: subscript the result of the call.
    return gast.Subscript(
        value=call_node, slice=slice_node.slice, ctx=gast.Load())
Ejemplo n.º 3
0
    def _transform_slice_to_tensor_write(self, node):
        """Rewrite ``name[idx] = value`` into a ``paddle.tensor.array_write`` assignment.

        Args:
            node (gast.Assign): an assignment whose first target is a
                subscript of a plain name (``target_node.value.id`` is read).

        Returns:
            gast.Assign: when ``slice_is_num`` accepts the target subscript,
            the statement
            ``name = paddle.tensor.array_write(x=value, i=paddle.cast(...), array=name)``,
            with the index converted by ``to_static_variable`` and cast to
            int64. For a range slice (``gast.Slice``) or any other index, the
            original ``node`` is returned unchanged.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]

        target_name = target_node.value.id
        slice_node = target_node.slice

        # Only single numeric indices are converted. Previously, range slices
        # (`name[a:b] = ...`) and non-numeric indices fell through to
        # `return assign_node` with the name never bound, which raised
        # UnboundLocalError. They are now left as-is.
        # NOTE(review): confirm callers are fine receiving an unconverted Assign.
        if isinstance(slice_node, gast.Slice) or not slice_is_num(target_node):
            return node

        value_code = ast_to_source_code(node.value).strip()
        index_code = "paddle.cast(" \
            "x=paddle.jit.dy2static.to_static_variable({})," \
            "dtype='int64')".format(ast_to_source_code(slice_node).strip())
        # Re-bind the list name to the array returned by array_write.
        assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
            .format(target_name, value_code, index_code, target_name)
        return gast.parse(assign_code).body[0]