def create_convert_shape_node(var_shape_node,
                              slice_node=None,
                              in_control_flow=False):
    """Build an AST expression that calls ``convert_var_shape`` for a shape access.

    Args:
        var_shape_node: A ``gast.Attribute`` (e.g. the node for ``x.shape``)
            or a ``gast.Subscript`` wrapping one (e.g. ``x.shape[0]``).
        slice_node: Optional subscript node whose ``.slice`` is applied to the
            shape. Ignored when ``var_shape_node`` is itself a
            ``gast.Subscript``; in that case a copy of ``var_shape_node`` is
            used as the slice node instead.
        in_control_flow: Value rendered verbatim into the generated call as
            the ``in_control_flow`` keyword argument.

    Returns:
        The parsed call expression, or a ``gast.Subscript`` around it when the
        index is not a plain number.
    """
    assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))

    if isinstance(var_shape_node, gast.Subscript):
        # Recurse on the subscripted value, handing a deep copy of the whole
        # subscript down as the slice node (presumably so the generated tree
        # does not share nodes with the input tree).
        subscript_copy = copy.deepcopy(var_shape_node)
        return create_convert_shape_node(subscript_copy.value, subscript_copy,
                                         in_control_flow)

    if isinstance(var_shape_node, gast.Attribute):
        call_args = [ast_to_source_code(var_shape_node.value).strip()]

        # Two kinds of index are distinguished via slice_is_num():
        #   * a plain number such as 1 or -2: it is passed as an extra
        #     positional argument of convert_var_shape;
        #   * anything else, such as bounds 2:-1: the call is emitted without
        #     it and the result is subscripted, i.e. convert_var_shape(x)[slice].
        has_slice = slice_node is not None
        index_is_number = has_slice and slice_is_num(slice_node)
        if index_is_number:
            call_args.append(ast_to_source_code(slice_node.slice).strip())

        call_source = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
            ",".join(call_args), in_control_flow)
        call_node = gast.parse(call_source).body[0].value

        if has_slice and not index_is_number:
            return gast.Subscript(
                value=call_node, slice=slice_node.slice, ctx=gast.Load())
        return call_node
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
    """Build an AST expression that calls ``choose_shape_attr_or_api``.

    The generated call receives ``attr_shape_name`` and an
    ``eval_if_exist_else_none('<api_shape_name>', globals())`` expression as
    its first two arguments.

    Args:
        attr_shape_name: Source text used verbatim as the first argument.
        api_shape_name: Name embedded (quoted) in the
            ``eval_if_exist_else_none`` lookup.
        slice_node: Optional subscript node. When ``slice_is_num`` accepts it,
            its ``.slice`` source is appended as a third argument; otherwise
            the generated call is wrapped in a subscript using that slice.

    Returns:
        The parsed call expression, or a ``gast.Subscript`` around it when a
        non-numeric slice was supplied.
    """
    lookup_expr = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
        api_shape_name)
    call_args = [attr_shape_name, lookup_expr]

    # Classify the index once; it decides both whether the index becomes an
    # argument and whether the call needs to be subscripted afterwards.
    has_slice = slice_node is not None
    index_is_number = has_slice and slice_is_num(slice_node)
    if index_is_number:
        call_args.append(ast_to_source_code(slice_node.slice).strip())

    call_source = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
        ",".join(call_args))
    call_node = gast.parse(call_source).body[0].value

    if has_slice and not index_is_number:
        return gast.Subscript(
            value=call_node, slice=slice_node.slice, ctx=gast.Load())
    return call_node
def _transform_slice_to_tensor_write(self, node):
    """Rewrite ``name[idx] = value`` into a ``paddle.tensor.array_write`` call.

    Only a numeric index (as judged by ``slice_is_num``) is rewritten, to
    ``name = paddle.tensor.array_write(x=value, i=<idx cast to int64>,
    array=name)``.

    Args:
        node: A ``gast.Assign`` whose first target is a subscript of a plain
            name (``target.value.id`` is read).

    Returns:
        The new assignment statement node when the index is numeric;
        otherwise the original ``node`` unchanged.
    """
    assert isinstance(node, gast.Assign)
    target_node = node.targets[0]
    target_name = target_node.value.id
    slice_node = target_node.slice

    # A range slice (a[1:3] = ...) or a non-numeric index is not rewritten.
    # Return the statement as-is: previously no result was bound on these
    # paths, so the caller got an UnboundLocalError or a missing node.
    if isinstance(slice_node, gast.Slice) or not slice_is_num(target_node):
        return node

    value_code = ast_to_source_code(node.value)
    # The index is converted to a static variable and cast to int64 before
    # being passed as the `i` argument of array_write.
    index_code = "paddle.cast(" \
        "x=paddle.jit.dy2static.to_static_variable({})," \
        "dtype='int64')".format(ast_to_source_code(slice_node))
    assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
        .format(target_name, value_code, index_code, target_name)
    return gast.parse(assign_code).body[0]