class AnalyzeUserVariablesNodeTransformer(ast.NodeTransformer):
    """AST transformer that records the objax variables referenced by user code.

    Each loaded ``Name`` found in ``closure_vars`` or ``global_vars`` (checked in
    that order) is annotated with a ``'name'`` (the identifier) and a ``'value'``
    (the resolved object). Each loaded ``Attribute`` whose base carries a
    non-``None`` ``'value'`` is annotated the same way, using a dotted name.
    Annotated values are passed to ``check_objax_var_module``:

    * a ``BaseVar`` is stored in ``self.vc`` under its dotted name;
    * a ``Module`` contributes all of its variables through ``Module.vars``.

    In either case the node's ``'value'`` annotation is then reset to ``None``,
    so enclosing attribute accesses are not resolved further.
    """

    # Sentinel that distinguishes "attribute missing" from an attribute whose value is None.
    _MISSING = object()

    def __init__(self, closure_vars, global_vars):
        """Create the transformer.

        Args:
            closure_vars: mapping consulted first when resolving a loaded bare name
                (presumably the analyzed function's closure variables; confirm with caller).
            global_vars: mapping consulted when the name is not in ``closure_vars``
                (presumably the analyzed function's globals; confirm with caller).
        """
        self.closure_vars = closure_vars
        self.global_vars = global_vars
        # Variables discovered so far, keyed by their dotted (scope-prefixed) name.
        self.vc = VarCollection()

    def check_objax_var_module(self, node):
        """Record the objax variable(s) referenced by an annotated node.

        If ``node`` has a ``'value'`` annotation that is a ``Module``, all of its
        variables are added to ``self.vc``, prefixed by the node's ``'name'`` plus ``'.'``.
        If it is a ``BaseVar``, it is added to ``self.vc`` under the node's ``'name'``.
        In both cases the ``'value'`` annotation is cleared to ``None`` afterwards.
        Nodes without a ``'value'`` annotation, or whose value is ``None``, are ignored.

        Raises:
            ValueError: if a different ``BaseVar`` object is already recorded under
                the same name.
        """
        if not hasanno(node, 'value'):
            return
        v = getanno(node, 'value')
        v_name = getanno(node, 'name')
        if v is None:
            return
        if isinstance(v, Module):
            self.vc.update(v.vars(scope=v_name + '.'))
            # Clear the value so attribute accesses on this module are not traced again.
            setanno(node, 'value', None)
        if isinstance(v, BaseVar):
            if v_name in self.vc and self.vc[v_name] is not v:
                # This generally should not happen and probably indication of a bug somewhere.
                raise ValueError(
                    f'Variable tracing failed because two variables were found with the same name {v_name}'
                )
            self.vc[v_name] = v
            setanno(node, 'value', None)

    def visit_Name(self, node):
        """Annotate a loaded name that resolves through the closure or global mapping.

        Names not present in either mapping are left unannotated.
        """
        node = self.generic_visit(node)
        if isinstance(node.ctx, ast.Load):
            # Closure bindings shadow globals, matching Python's enclosing-before-global lookup.
            for scope_vars in (self.closure_vars, self.global_vars):
                if node.id in scope_vars:
                    setanno(node, 'name', node.id)
                    setanno(node, 'value', scope_vars[node.id])
                    self.check_objax_var_module(node)
                    break
        return node

    def visit_Attribute(self, node):
        """Annotate a loaded attribute access whose base object has already been resolved.

        Children are visited first, so ``node.value`` is already annotated when it
        resolves to a known object.
        """
        node = self.generic_visit(node)
        if isinstance(node.ctx, ast.Load) and hasanno(node.value, 'value'):
            parent_value = getanno(node.value, 'value')
            if parent_value is not None:
                # Look the attribute up once. hasattr() followed by getattr() would run
                # properties / __getattr__ on user objects twice. Like hasattr(), a
                # defaulted getattr() only swallows AttributeError.
                attr_value = getattr(parent_value, node.attr, self._MISSING)
                if attr_value is not self._MISSING:
                    setanno(node, 'name', getanno(node.value, 'name') + '.' + node.attr)
                    setanno(node, 'value', attr_value)
                    self.check_objax_var_module(node)
        return node
def vars(self, scope: str = '') -> VarCollection:
    """Return every variable held in this list, recursing into contained modules.

    A ``BaseVar`` element at index ``i`` is stored under the key
    ``scope + '(ClassName)[i]'``. A ``Module`` element at index ``i`` is asked
    for its own variables, with that same string passed as its ``scope``.
    Elements that are neither are skipped.

    Args:
        scope: string to prefix to the variable names.

    Returns:
        A VarCollection of all the variables.
    """
    prefix = scope + f'({self.__class__.__name__})'
    collected = VarCollection()
    for index, item in enumerate(self):
        if isinstance(item, BaseVar):
            collected[f'{prefix}[{index}]'] = item
            continue
        if isinstance(item, Module):
            collected.update(item.vars(scope=f'{prefix}[{index}]'))
    return collected
def vars(self, scope: str = '') -> VarCollection:
    """Return every variable held by this module's attributes, recursing into submodules.

    Only instance attributes in ``self.__dict__`` are inspected:

    * a ``BaseVar`` attribute ``k`` is stored under ``scope + '(ClassName).' + k``;
    * a ``Module`` attribute ``k`` contributes its own variables, with that same
      string as its ``scope``.

    Important: variables and modules stored in Python structures such as dict or
    list are not collected. See ModuleList if you need such a feature.

    Args:
        scope: string to prefix to the variable names.

    Returns:
        A VarCollection of all the variables.
    """
    prefix = scope + f'({self.__class__.__name__}).'
    collected = VarCollection()
    for attr_name, attr_value in self.__dict__.items():
        if isinstance(attr_value, BaseVar):
            collected[prefix + attr_name] = attr_value
            continue
        if isinstance(attr_value, Module):
            collected.update(attr_value.vars(scope=prefix + attr_name))
    return collected