def _get_loss(self, loss_name: str):
    """Build a toy loss together with the trainable variables it depends on.

    Args:
        loss_name: one of 'logistic', 'square' or 'rastrigin'.

    Returns:
        A ``(model_vars, loss)`` pair: ``model_vars`` is a VarCollection holding
        the TrainVars read by ``loss``, a zero-argument callable returning the
        loss value.

    Raises:
        ValueError: if ``loss_name`` is not one of the supported names.
    """
    if loss_name == 'logistic':
        # loss = log(exp(-sum(x)) + 1), with x initialized to zeros.
        x = TrainVar(jn.zeros(2))
        model_vars = VarCollection({'x': x})

        def loss():
            return jn.log(jn.exp(-jn.sum(x.value)) + 1)

        return model_vars, loss
    if loss_name == 'square':
        # loss = x*x + y*y.
        x = TrainVar(jn.ones(2))
        y = TrainVar(jn.ones(3))
        model_vars = VarCollection({'x': x, 'y': y})

        def loss():
            return jn.dot(x.value, x.value) + jn.dot(y.value, y.value)

        return model_vars, loss
    if loss_name == 'rastrigin':
        # Rastrigin function in d dimensions: 10*d + sum(x^2 - 10*cos(2*pi*x)).
        d = 2
        x = TrainVar(jn.ones(d))
        model_vars = VarCollection({'x': x})

        def loss():
            return 10 * d + jn.dot(x.value, x.value) - 10 * jn.sum(
                jn.cos(2 * math.pi * x.value))

        return model_vars, loss
    # Name the offending value so a typo in a test parameter is easy to spot.
    raise ValueError(f'Unknown loss name: {loss_name!r}')
def __init__(self, f: Callable, vc: VarCollection):
    """Function constructor.

    Wraps ``f`` and stores the variables it uses. When ``f`` has a
    ``__name__``, each variable name is prefixed with ``{<f.__name__>}``.

    Args:
        f: the function or the module to represent.
        vc: the VarCollection of variables used by the function.
    """
    if not hasattr(f, '__name__'):
        self.vc = VarCollection(vc)
    else:
        owner = f'{{{f.__name__}}}'
        self.vc = VarCollection((f'{owner}{name}', var) for name, var in vc.items())
    self.__wrapped__ = f
class AnalyzeUserVariablesNodeTransformer(ast.NodeTransformer):
    """AST pass that records Objax variables and modules referenced by user code.

    Loaded names are resolved first against ``closure_vars``, then against
    ``global_vars``. Attribute chains on resolved values are followed as well.
    Any resolved ``Module`` or ``BaseVar`` is added to ``self.vc``, and the
    node's ``value`` annotation is then cleared.
    """

    def __init__(self, closure_vars, global_vars):
        self.closure_vars = closure_vars
        self.global_vars = global_vars
        self.vc = VarCollection()

    def check_objax_var_module(self, node):
        """Record ``node``'s annotated value in ``self.vc`` if it is a Module or a BaseVar.

        Raises:
            ValueError: if a different variable was already recorded under the same name.
        """
        if not hasanno(node, 'value'):
            return
        value = getanno(node, 'value')
        name = getanno(node, 'name')
        if value is None:
            return
        if isinstance(value, Module):
            self.vc.update(value.vars(scope=name + '.'))
            setanno(node, 'value', None)
        if isinstance(value, BaseVar):
            if name in self.vc and self.vc[name] is not value:
                # This generally should not happen and probably indicates a bug somewhere.
                raise ValueError(
                    f'Variable tracing failed because two variables were found with the same name {name}'
                )
            self.vc[name] = value
            setanno(node, 'value', None)

    def visit_Name(self, node):
        node = self.generic_visit(node)
        if not isinstance(node.ctx, ast.Load):
            return node
        # Closure variables shadow globals, so they are searched first.
        for lookup in (self.closure_vars, self.global_vars):
            if node.id in lookup:
                setanno(node, 'name', node.id)
                setanno(node, 'value', lookup[node.id])
                self.check_objax_var_module(node)
                break
        return node

    def visit_Attribute(self, node):
        node = self.generic_visit(node)
        if not (isinstance(node.ctx, ast.Load) and hasanno(node.value, 'value')):
            return node
        parent_value = getanno(node.value, 'value')
        if parent_value is None or not hasattr(parent_value, node.attr):
            return node
        setanno(node, 'name', getanno(node.value, 'name') + '.' + node.attr)
        setanno(node, 'value', getattr(parent_value, node.attr))
        self.check_objax_var_module(node)
        return node
class Function(Module):
    """Turn a function into a Module by keeping the vars it uses."""

    def __init__(self, f: Callable, vc: VarCollection):
        """Function constructor.

        Args:
            f: the function or the module to represent.
            vc: the VarCollection of variables used by the function. When ``f``
                has a ``__name__``, names are stored as ``{<f.__name__>}.<name>``.
        """
        if not hasattr(f, '__name__'):
            self.vc = VarCollection(vc)
        else:
            owner = f.__name__
            self.vc = VarCollection((f'{{{owner}}}.{name}', var) for name, var in vc.items())
        self.__wrapped__ = f

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with the given arguments."""
        return self.__wrapped__(*args, **kwargs)

    def vars(self, scope: str = '') -> VarCollection:
        """Return the VarCollection of the variables used by the function.

        Args:
            scope: optional string prepended to every variable name.
        """
        if not scope:
            return VarCollection(self.vc)
        return VarCollection((scope + name, var) for name, var in self.vc.items())

    @staticmethod
    def with_vars(vc: VarCollection):
        """Method to use as decorator in function definitions.

        Returns a decorator that wraps the decorated function in a ``Function``
        bound to ``vc``.
        """
        def from_function(f: Callable):
            return Function(f, vc)

        return from_function
def __init__(self,
             f: Callable,
             variables: Optional[VarCollection],
             input_argnums: Optional[Tuple[int, ...]] = None):
    """Constructs an instance to compute the gradient of f w.r.t. variables.

    Args:
        f: the function for which to compute gradients.
        variables: the variables for which to compute gradients.
        input_argnums: input indexes, if any, on which to compute gradients.
    """
    variables = variables or VarCollection()
    super().__init__(variables)
    self.input_argnums = input_argnums or tuple()

    def f_func(inputs_and_train_tensors: List[JaxArray],
               state_tensors: List[JaxArray],
               list_args: List):
        # The differentiated list holds the selected inputs first, then the trainable tensors.
        num_inputs = len(self.input_argnums)
        inputs = inputs_and_train_tensors[:num_inputs]
        train_tensors = inputs_and_train_tensors[num_inputs:]
        original_vc = self.vc.tensors()
        try:
            self.vc.subset(TrainVar).assign(train_tensors)
            self.vc.subset(BaseState).assign(state_tensors)
            for i, arg in zip(self.input_argnums, inputs):
                list_args[i] = arg
            outputs = f(*list_args)
            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            return outputs[0], (outputs, variables.tensors(BaseState))
        finally:
            # Restore the original tensors even when f raises, so traced values
            # never remain assigned to the model's variables.
            self.vc.assign(original_vc)

    self.f = jax.grad(f_func, has_aux=True)
def load_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Loads values of all variables in the given variables collection from file.

    Values loaded from file will replace old values in the variables collection.
    If variable exists in the file, but does not exist in the variables collection it will be ignored.
    If variable exists in the variables collection, but not found in the file then exception will be raised.
    When ``file`` is a filename, the file opened here is closed even if loading fails.

    Args:
        file: filename or python file handle of the input file.
        vc: variables collection which will be loaded from file.

    Raises:
        ValueError: if variable from variables collection is not found in the input file.
    """
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    try:
        data = np.load(file, allow_pickle=False)
        name_index = {k: str(i) for i, k in enumerate(data['names'])}
        # Group all names that refer to the same underlying variable (TrainRefs are
        # resolved to their target), keyed by identity. Any one of those names
        # present in the file is enough to restore the variable.
        name_vars = collections.defaultdict(list)
        var_by_id = {}
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            name_vars[id(v)].append(k)
            var_by_id[id(v)] = v
        for var_id, names in name_vars.items():
            v = var_by_id[var_id]
            for name in names:
                index = name_index.get(name)
                if index is not None:
                    v.assign(jn.array(data[index]))
                    break
            else:
                raise ValueError(f'Missing value for variables {names}')
    finally:
        if do_close:
            file.close()
def save_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Saves variables collection into file.

    When ``file`` is a filename, data is first written to ``<file>.tmp``. It is
    then moved over ``file`` with ``os.replace``, so a job killed mid-save never
    leaves a truncated checkpoint. If writing fails, the temporary file is
    closed and the destination is left untouched.

    Args:
        file: filename or python file handle of the file where variables will be saved.
        vc: variables collection which will be saved into file.
    """
    do_close = isinstance(file, str)
    if do_close:
        # Save to a temporary in case the job is killed while saving.
        filename, file = file, open(file + '.tmp', 'wb')
    try:
        data, names, seen = {}, [], set()
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            # Each underlying variable is stored once, under the first name seen.
            if id(v) not in seen:
                names.append(k)
                data[str(len(data))] = v.value
                seen.add(id(v))
        np.savez(file, names=np.array(names), **data)
    finally:
        if do_close:
            file.close()
    if do_close:
        # os.replace overwrites atomically on POSIX and, unlike os.rename,
        # also succeeds on Windows when the destination already exists.
        os.replace(filename + '.tmp', filename)
def vars(self, scope: str = '') -> VarCollection:
    """Collect all the variables (and their names) contained in the list and its submodules.

    Elements are named ``<scope>(<ClassName>)[<position>]``. Elements that are
    neither variables nor modules are skipped.

    Args:
        scope: string to prefix to the variable names.
    Returns:
        A VarCollection of all the variables.
    """
    prefix = scope + f'({self.__class__.__name__})'
    collected = VarCollection()
    for position, item in enumerate(self):
        entry_name = f'{prefix}[{position}]'
        if isinstance(item, BaseVar):
            collected[entry_name] = item
        elif isinstance(item, Module):
            collected.update(item.vars(scope=entry_name))
    return collected
def __init__(self, vc: VarCollection):
    """Constructor for SGD optimizer.

    Args:
        vc: collection of variables to optimize; only its TrainVar entries are kept.
    """
    trainable = vc.subset(TrainVar)
    self.train_vars = ModuleList(TrainRef(var) for var in trainable)
def vars(self, scope: str = '') -> VarCollection:
    """Collect all the variables (and their names) contained in the module and its submodules.

    Important: Variables and modules stored in Python structures such as dict or list are not collected. See
    ModuleList if you need such a feature.

    Args:
        scope: string to prefix to the variable names.
    Returns:
        A VarCollection of all the variables.
    """
    prefix = scope + f'({self.__class__.__name__}).'
    collected = VarCollection()
    # Only direct attributes are inspected; submodules recurse with an extended prefix.
    for attr_name, attr_value in self.__dict__.items():
        if isinstance(attr_value, BaseVar):
            collected[prefix + attr_name] = attr_value
        elif isinstance(attr_value, Module):
            collected.update(attr_value.vars(scope=prefix + attr_name))
    return collected
def vars(self, scope: str = '') -> VarCollection:
    """Collect all the variables (and their names) contained in the VarCollection.

    Args:
        scope: string to prefix to the variable names.
    Returns:
        A new VarCollection with every stored name prefixed by ``scope``.
    """
    prefixed = ((scope + name, var) for name, var in self.vc.items())
    return VarCollection(prefixed)
def load_var_collection(file: Union[str, IO[BinaryIO]],
                        vc: VarCollection,
                        renamer: Optional[Renamer] = None):
    """Loads values of all variables in the given variables collection from file.

    Values loaded from file will replace old values in the variables collection.
    If variable exists in the file, but does not exist in the variables collection it will be ignored.
    If variable exists in the variables collection, but not found in the file then exception will be raised.
    When ``file`` is a filename, the file opened here is closed even if loading fails.

    Args:
        file: filename or python file handle of the input file.
        vc: variables collection which will be loaded from file.
        renamer: optional renamer to pre-process variables names from the file being read.

    Raises:
        ValueError: if variable from variables collection is not found in the input file.
        AssertionError: if assigning a loaded value to a variable raises AssertionError;
            re-raised with the variable name added.
    """
    renamer = renamer or (lambda x: x)
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    try:
        data = np.load(file, allow_pickle=False)
        name_index = {renamer(k): str(i) for i, k in enumerate(data['names'])}
        # Group names by the identity of the underlying variable (TrainRefs resolved to their target).
        var_names = collections.defaultdict(list)
        var_values = {}
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            var_names[id(v)].append(k)
            var_values[id(v)] = v
        misses = []
        used_vars = set()
        for var_id, names in var_names.items():
            v = var_values[var_id]
            for name in names:
                index = name_index.get(name)
                if index is not None:
                    used_vars.add(name)
                    try:
                        v.assign(jn.array(data[index]))
                    except AssertionError as e:
                        raise AssertionError(
                            f'Error when restoring variable {name}: {str(e)}'
                        ) from None
                    break
            else:
                misses += names
        if misses:
            not_used = set(name_index.keys()) - used_vars
            raise ValueError(
                f'Missing value for variables currently in the model: {misses}. '
                f'The following variables on disk were not used, '
                f'maybe the missing variable was renamed from one of these: {not_used}.'
            )
    finally:
        if do_close:
            file.close()
def __init__(self,
             vc: VarCollection,
             beta1: float = 0.9,
             beta2: float = 0.999,
             eps: float = 1e-8):
    """Constructor for Adam optimizer class.

    Args:
        vc: collection of variables to optimize.
        beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
        beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
        eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
    """
    self.beta1, self.beta2, self.eps = beta1, beta2, eps
    # uint32 step counter; its reduce keeps element 0 (presumably one copy when replicated — TODO confirm).
    self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda values: values[0])
    self.train_vars = ModuleList(TrainRef(var) for var in vc.subset(TrainVar))

    def zero_slots():
        # One StateVar per trainable variable, initialized with jn.zeros_like of its value.
        return ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)

    self.m = zero_slots()
    self.v = zero_slots()
def __init__(self, vc: VarCollection, momentum: float = 0.9, nesterov: bool = False):
    """Constructor for momentum optimizer class.

    Args:
        vc: collection of variables to optimize.
        momentum: the momentum hyperparameter.
        nesterov: bool indicating whether to use the Nesterov method.
    """
    self.momentum = momentum
    self.nesterov = nesterov
    trainable = vc.subset(TrainVar)
    self.train_vars = ModuleList(TrainRef(var) for var in trainable)
    # Per-variable state (presumably the momentum accumulators), zero-initialized like each value.
    self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
class ModuleWrapper(Module):
    """Module whose sole purpose is to store a collectable VarCollection.

    This class is exclusively used internally by Objax, for example in Jit, Vectorize and Parallel."""

    def __init__(self, vc: VarCollection):
        super().__init__()
        # Every stored name is prefixed with "(<ClassName>)".
        tag = f'({self.__class__.__name__})'
        self.vc = VarCollection((f'{tag}{name}', var) for name, var in vc.items())

    def vars(self, scope: str = '') -> VarCollection:
        """Collect all the variables (and their names) contained in the VarCollection.

        Args:
            scope: string to prefix to the variable names.
        Returns:
            A VarCollection of all the variables.
        """
        prefixed = ((scope + name, var) for name, var in self.vc.items())
        return VarCollection(prefixed)
def __init__(self,
             derivative_fn: Callable,
             f: Union[Module, Callable],
             variables: Optional[VarCollection],
             input_argnums: Optional[Tuple[int, ...]] = None,
             return_all_f_outputs: bool = False):
    """Constructs an instance to compute the derivatives of f w.r.t. variables.

    Args:
        derivative_fn: JAX transformation which computes derivative.
        f: the function for which to compute derivatives.
        variables: the variables for which to compute derivatives.
        input_argnums: input indexes, if any, on which to compute derivatives.
        return_all_f_outputs: if True also return original outputs of the function along with derivatives.
    """
    self.f = f
    self.vc = variables = VarCollection(variables or ())
    if not isinstance(f, Module):
        f = Function(f, self.vc)
    assert isinstance(input_argnums, tuple) or input_argnums is None, \
        f'Must pass a tuple of indices to input_argnums; received {input_argnums}.'
    self.input_argnums = input_argnums or tuple()
    self.return_all_f_outputs = return_all_f_outputs

    def f_func(inputs_and_train_tensors: List[JaxArray],
               state_tensors: List[JaxArray],
               list_args: List,
               kwargs: Dict):
        # Selected inputs come first in the differentiated list, then the trainable tensors.
        split = len(self.input_argnums)
        inputs = inputs_and_train_tensors[:split]
        train_tensors = inputs_and_train_tensors[split:]
        snapshot = self.vc.tensors()
        try:
            self.vc.subset(TrainVar).assign(train_tensors)
            self.vc.subset(BaseState).assign(state_tensors)
            for position, value in zip(self.input_argnums, inputs):
                list_args[position] = value
            outputs = f(*list_args, **kwargs)
            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            head = outputs[0]
            if self.return_all_f_outputs:
                return head, (outputs, variables.tensors())
            return head, variables.tensors()
        finally:
            # Always put the original tensors back, whatever happened above.
            self.vc.assign(snapshot)

    self._call = derivative_fn(f_func)
def __init__(self,
             f: Union[Module, Callable],
             variables: Optional[VarCollection],
             input_argnums: Optional[Tuple[int, ...]] = None):
    """Constructs an instance to compute the gradient of f w.r.t. variables.

    The instance exposes ``f`` through ``__wrapped__`` and a ``__signature__``
    whose return annotation is ``Tuple[List[JaxArray], <f's return annotation>]``.

    Args:
        f: the function for which to compute gradients.
        variables: the variables for which to compute gradients.
        input_argnums: input indexes, if any, on which to compute gradients.
    """
    self.f = f
    self.vc = variables = VarCollection(variables or ())
    if not isinstance(f, Module):
        f = Function(f, self.vc)
    assert isinstance(input_argnums, tuple) or input_argnums is None, \
        f'Must pass a tuple of indices to input_argnums; received {input_argnums}.'
    self.input_argnums = input_argnums or tuple()

    def f_func(inputs_and_train_tensors: List[JaxArray],
               state_tensors: List[JaxArray],
               list_args: List,
               kwargs: Dict):
        # Selected inputs come first in the differentiated list, then the trainable tensors.
        split = len(self.input_argnums)
        inputs = inputs_and_train_tensors[:split]
        train_tensors = inputs_and_train_tensors[split:]
        snapshot = self.vc.tensors()
        try:
            self.vc.subset(TrainVar).assign(train_tensors)
            self.vc.subset(BaseState).assign(state_tensors)
            for position, value in zip(self.input_argnums, inputs):
                list_args[position] = value
            outputs = f(*list_args, **kwargs)
            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            return outputs[0], (outputs, variables.tensors())
        finally:
            # Always put the original tensors back, whatever happened above.
            self.vc.assign(snapshot)

    signature = inspect.signature(f)
    self.__wrapped__ = f
    self.__signature__ = signature.replace(
        return_annotation=Tuple[List[JaxArray], signature.return_annotation])
    self._call = jax.grad(f_func, has_aux=True)
def __init__(self,
             vc: VarCollection,
             momentum: float = 0.9,
             weight_decay: float = 1e-4,
             tc: float = 1e-3,
             eps: float = 1e-5):
    """Constructor for LARS optimizer.

    Args:
        vc: collection of variables to optimize.
        momentum: coefficient used for the moving average of the gradient.
        weight_decay: weight decay coefficient.
        tc: trust coefficient eta ( < 1) for trust ratio computation.
        eps: epsilon used for trust ratio computation.
    """
    self.momentum, self.weight_decay, self.tc, self.eps = momentum, weight_decay, tc, eps
    trainable = vc.subset(TrainVar)
    self.train_vars = ModuleList(TrainRef(var) for var in trainable)
    # Per-variable moving-average buffers, zero-initialized like each value.
    self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
def save_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Saves variables collection into file.

    When ``file`` is a filename, data is first written to ``<file>.tmp``. It is
    then moved over ``file`` with ``os.replace``, so a job killed mid-save never
    leaves a truncated checkpoint. If writing fails, the temporary file is
    closed and the destination is left untouched.

    Args:
        file: filename or python file handle of the file where variables will be saved.
        vc: variables collection which will be saved into file.
    """
    do_close = isinstance(file, str)
    if do_close:
        # Save to a temporary in case the job is killed while saving.
        filename, file = file, open(file + '.tmp', 'wb')
    try:
        data, names, seen, replicated = {}, [], set(), []
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            # Each underlying variable is stored once, under the first name seen.
            if id(v) not in seen:
                names.append(k)
                data[str(len(data))] = v.value
                seen.add(id(v))
                if isinstance(v.value, ShardedDeviceArray):
                    replicated.append(k)
        if replicated:
            print(
                'Warning: When saving VarCollection, some variables were replicated on multiple devices.'
            )
            print(
                ' While it is valid, in most use cases it is more disk efficient to save variables outside of '
            )
            print(' vars().replicate().')
        np.savez(file, names=np.array(names), **data)
    finally:
        if do_close:
            file.close()
    if do_close:
        # os.replace overwrites atomically on POSIX and, unlike os.rename,
        # also succeeds on Windows when the destination already exists.
        os.replace(filename + '.tmp', filename)
def __init__(self, closure_vars, global_vars):
    """Store the name lookup tables and start with an empty VarCollection.

    Args:
        closure_vars: closure variables of the analyzed function (presumably a
            name -> value mapping; TODO confirm against caller).
        global_vars: global variables of the analyzed function (same assumption).
    """
    self.closure_vars, self.global_vars = closure_vars, global_vars
    self.vc = VarCollection()
def vars(self, scope: str = '') -> VarCollection:
    """Return the VarCollection of the variables used.

    Args:
        scope: optional string prepended to every variable name.
    Returns:
        A new VarCollection (a copy when ``scope`` is empty).
    """
    if not scope:
        return VarCollection(self.vc)
    return VarCollection((scope + name, var) for name, var in self.vc.items())
def __init__(self, vc: VarCollection, base_optimizer: Callable, **kwargs):
    """Wrap ``base_optimizer`` built on the same variables.

    Args:
        vc: collection of variables to optimize.
        base_optimizer: optimizer class/factory, called as ``base_optimizer(vc, **kwargs)``.
        **kwargs: extra keyword arguments forwarded to ``base_optimizer``.
    """
    trainable = vc.subset(TrainVar)
    self.train_vars = ModuleList(TrainRef(var) for var in trainable)
    self.base_optimizer = base_optimizer(vc, **kwargs)
    # Auxiliary state: a dict per key, created on first access.
    self.state = defaultdict(dict)
def __init__(self, vc: VarCollection):
    """Store a copy of ``vc`` with every name prefixed by ``(<ClassName>)``.

    Args:
        vc: the variables to hold.
    """
    super().__init__()
    tag = f'({self.__class__.__name__})'
    self.vc = VarCollection((f'{tag}{name}', var) for name, var in vc.items())