コード例 #1
0
ファイル: optimizer.py プロジェクト: spacexcorp/objax
    def _get_loss(self, loss_name: str):
        if loss_name == 'logistic':
            x = TrainVar(jn.zeros(2))
            model_vars = VarCollection({'x': x})

            def loss():
                return jn.log(jn.exp(-jn.sum(x.value)) + 1)

            return model_vars, loss
        if loss_name == 'square':
            # loss = x*x + y*y.
            x = TrainVar(jn.ones(2))
            y = TrainVar(jn.ones(3))
            model_vars = VarCollection({'x': x, 'y': y})

            def loss():
                return jn.dot(x.value, x.value) + jn.dot(y.value, y.value)

            return model_vars, loss
        if loss_name == 'rastrigin':
            d = 2
            x = TrainVar(jn.ones(d))
            model_vars = VarCollection({'x': x})

            def loss():
                return 10 * d + jn.dot(x.value, x.value) - 10 * jn.sum(
                    jn.cos(2 * math.pi * x.value))

            return model_vars, loss
        raise ValueError
コード例 #2
0
    def __init__(self, f: Callable, vc: VarCollection):
        """Bind a function (or module) to the variables it uses.

        Args:
            f: the function or the module to represent.
            vc: the VarCollection of variables used by the function.

        If ``f`` has a ``__name__`` attribute, every variable name from ``vc`` is
        prefixed with ``{<f.__name__>}``; otherwise ``vc`` is copied with its names
        unchanged. ``f`` itself is kept as ``self.__wrapped__``.
        """
        if not hasattr(f, '__name__'):
            self.vc = VarCollection(vc)
        else:
            renamed = ((f'{{{f.__name__}}}{name}', var) for name, var in vc.items())
            self.vc = VarCollection(renamed)
        self.__wrapped__ = f
コード例 #3
0
ファイル: gradient.py プロジェクト: rwightman/objax
    def __init__(self, f: Callable,
                 variables: Optional[VarCollection],
                 input_argnums: Optional[Tuple[int, ...]] = None):
        """Constructs an instance to compute the gradient of f w.r.t. variables.

        Args:
            f: the function for which to compute gradients.
            variables: the variables for which to compute gradients; a falsy value
                (e.g. None) is replaced by an empty VarCollection.
            input_argnums: input indexes, if any, on which to compute gradients.
        """
        variables = variables or VarCollection()
        super().__init__(variables)
        self.input_argnums = input_argnums or tuple()

        def f_func(inputs_and_train_tensors: List[JaxArray],
                   state_tensors: List[JaxArray],
                   list_args: List):
            # The first len(input_argnums) entries replace positional arguments of f;
            # the remaining entries are the values assigned to the TrainVars of self.vc.
            inputs = inputs_and_train_tensors[:len(self.input_argnums)]
            train_tensors = inputs_and_train_tensors[len(self.input_argnums):]
            original_vc = self.vc.tensors()
            try:
                self.vc.subset(TrainVar).assign(train_tensors)
                self.vc.subset(BaseState).assign(state_tensors)
                for i, arg in zip(self.input_argnums, inputs):
                    list_args[i] = arg
                outputs = f(*list_args)
                if not isinstance(outputs, (list, tuple)):
                    outputs = [outputs]
                # The return expression (including the state tensors) is evaluated
                # before the `finally` clause restores the original values.
                return outputs[0], (outputs, variables.tensors(BaseState))
            finally:
                # Restore even when f raises, so the variables are not left holding
                # the substituted tensors.
                self.vc.assign(original_vc)

        self.f = jax.grad(f_func, has_aux=True)
コード例 #4
0
ファイル: module.py プロジェクト: spacexcorp/objax
    def vars(self, scope: str = '') -> VarCollection:
        """Return a new VarCollection holding the variables of ``self.vc``.

        Args:
            scope: string prepended to every variable name.
        Returns:
            A VarCollection with the same variables, whose names are ``scope + name``.
        """
        scoped_items = ((scope + name, var) for name, var in self.vc.items())
        return VarCollection(scoped_items)
コード例 #5
0
ファイル: module.py プロジェクト: lberrada/objax
    def vars(self, scope: str = '') -> VarCollection:
        """Gather the variables held directly in this list and in its submodules.

        Args:
            scope: string prepended to every variable name.
        Returns:
            A VarCollection in which each element is named
            ``<scope>(<ClassName>)[<index>]``. Items that are neither variables
            nor modules are skipped.
        """
        prefix = scope + f'({self.__class__.__name__})'
        collected = VarCollection()
        for index, item in enumerate(self):
            key = f'{prefix}[{index}]'
            if isinstance(item, BaseVar):
                collected[key] = item
            elif isinstance(item, Module):
                # Submodules name their own variables under this element's key.
                collected.update(item.vars(scope=key))
        return collected
コード例 #6
0
    def __init__(self,
                 derivative_fn: Callable,
                 f: Union[Module, Callable],
                 variables: Optional[VarCollection],
                 input_argnums: Optional[Tuple[int, ...]] = None,
                 return_all_f_outputs: bool = False):
        """Set up computation of derivatives of ``f`` with respect to ``variables``.

        Args:
            derivative_fn: JAX transformation which computes the derivative.
            f: the function or module to differentiate. A non-Module callable is
                wrapped in ``Function`` together with the selected variables.
            variables: the variables for which to compute derivatives (None means none).
            input_argnums: tuple of input indexes, if any, on which to compute derivatives.
            return_all_f_outputs: if True, also return the original outputs of the
                function along with the derivatives.
        """
        self.f = f
        self.vc = variables = VarCollection(variables or ())
        target = f if isinstance(f, Module) else Function(f, self.vc)

        assert input_argnums is None or isinstance(input_argnums, tuple), \
            f'Must pass a tuple of indices to input_argnums; received {input_argnums}.'
        self.input_argnums = input_argnums or tuple()
        self.return_all_f_outputs = return_all_f_outputs

        def f_func(inputs_and_train_tensors: List[JaxArray],
                   state_tensors: List[JaxArray], list_args: List,
                   kwargs: Dict):
            split_at = len(self.input_argnums)
            inputs = inputs_and_train_tensors[:split_at]
            train_tensors = inputs_and_train_tensors[split_at:]
            snapshot = self.vc.tensors()
            try:
                self.vc.subset(TrainVar).assign(train_tensors)
                self.vc.subset(BaseState).assign(state_tensors)
                for position, value in zip(self.input_argnums, inputs):
                    list_args[position] = value
                result = target(*list_args, **kwargs)
                outputs = result if isinstance(result, (list, tuple)) else [result]
                primary = outputs[0]
                if not self.return_all_f_outputs:
                    return primary, variables.tensors()
                return primary, (outputs, variables.tensors())
            finally:
                # Always put back the tensors that were present before the call.
                self.vc.assign(snapshot)

        self._call = derivative_fn(f_func)
コード例 #7
0
    def vars(self, scope: str = '') -> VarCollection:
        """Gather the variables stored as attributes of this module and its submodules.
        Important: Variables and modules stored in Python structures such as dict or list are not collected. See
        ModuleList if you need such a feature.

        Args:
            scope: string prepended to every variable name.
        Returns:
            A VarCollection whose names have the form ``<scope>(<ClassName>).<attribute>``.
        """
        prefix = scope + f'({self.__class__.__name__}).'
        collected = VarCollection()
        for attr_name, attr_value in self.__dict__.items():
            if isinstance(attr_value, BaseVar):
                collected[prefix + attr_name] = attr_value
            elif isinstance(attr_value, Module):
                collected.update(attr_value.vars(scope=prefix + attr_name))
        return collected
コード例 #8
0
    def __init__(self,
                 f: Union[Module, Callable],
                 variables: Optional[VarCollection],
                 input_argnums: Optional[Tuple[int, ...]] = None):
        """Set up computation of the gradient of ``f`` with respect to ``variables``.

        Args:
            f: the function or module to differentiate. A non-Module callable is
                wrapped in ``Function`` together with the selected variables.
            variables: the variables for which to compute gradients (None means none).
            input_argnums: tuple of input indexes, if any, on which to compute gradients.
        """
        self.f = f
        self.vc = variables = VarCollection(variables or ())
        target = f if isinstance(f, Module) else Function(f, self.vc)

        def f_func(inputs_and_train_tensors: List[JaxArray],
                   state_tensors: List[JaxArray], list_args: List,
                   kwargs: Dict):
            split_at = len(self.input_argnums)
            inputs = inputs_and_train_tensors[:split_at]
            train_tensors = inputs_and_train_tensors[split_at:]
            snapshot = self.vc.tensors()
            try:
                self.vc.subset(TrainVar).assign(train_tensors)
                self.vc.subset(BaseState).assign(state_tensors)
                for position, value in zip(self.input_argnums, inputs):
                    list_args[position] = value
                result = target(*list_args, **kwargs)
                outputs = result if isinstance(result, (list, tuple)) else [result]
                primary = outputs[0]
                return primary, (outputs, variables.tensors())
            finally:
                # Always put back the tensors that were present before the call.
                self.vc.assign(snapshot)

        assert input_argnums is None or isinstance(input_argnums, tuple), \
            f'Must pass a tuple of indices to input_argnums; received {input_argnums}.'
        self.input_argnums = input_argnums or tuple()

        # Advertise the wrapped callable's signature, with the return annotation
        # extended to (gradients, original return).
        wrapped_signature = inspect.signature(target)
        self.__wrapped__ = target
        grad_return = Tuple[List[JaxArray], wrapped_signature.return_annotation]
        self.__signature__ = wrapped_signature.replace(return_annotation=grad_return)
        self._call = jax.grad(f_func, has_aux=True)
コード例 #9
0
 def __init__(self, closure_vars, global_vars):
     """Keep closure_vars and global_vars as given and start with an empty VarCollection."""
     self.closure_vars, self.global_vars = closure_vars, global_vars
     self.vc = VarCollection()
コード例 #10
0
ファイル: gradient.py プロジェクト: qingliaowu/objax
 def vars(self, scope: str = '') -> VarCollection:
     """Return a new VarCollection of the variables used, names prefixed by ``scope`` if given."""
     if not scope:
         return VarCollection(self.vc)
     prefixed = ((scope + name, var) for name, var in self.vc.items())
     return VarCollection(prefixed)
コード例 #11
0
ファイル: module.py プロジェクト: spacexcorp/objax
 def __init__(self, vc: VarCollection):
     """Store a copy of ``vc`` with every name prefixed by ``(<ClassName>)``."""
     super().__init__()
     class_tag = self.__class__.__name__
     self.vc = VarCollection((f'({class_tag}){name}', var) for name, var in vc.items())