def _check_user_params(self, **kwargs):
    """Validate the user-supplied flow parameters in ``self.user_params``.

    The value of ``self.user_params`` is also stored on
    ``self._user_params`` as a side effect.

    Parameters
    ----------
    **kwargs
        Must contain the key ``f``: the formula passed by the caller, or
        ``None`` when no formula was given. A missing ``f`` raises
        ``KeyError``.

    Returns
    -------
    bool
        ``False`` if ``self.user_params`` is ``None`` (nothing to check),
        ``True`` if parameters were supplied and passed every check.

    Raises
    ------
    opvi.ParametrizationError
        If a formula is given together with user params, if the dict keys
        are not exactly ``0 .. len(params) - 1``, or if the keys of any
        per-flow dict differ from the ``__param_spec__`` of the flow
        returned by ``flows.flow_for_params`` for it.
    TypeError
        If the params are not a dict, or have non-``int`` keys.
    """
    params = self._user_params = self.user_params
    formula = kwargs.pop('f')
    if params is None:
        return False
    if formula is not None:
        raise opvi.ParametrizationError('No formula is allowed if user params are provided')
    if not isinstance(params, dict):
        raise TypeError('params should be a dict')
    if not all(isinstance(k, int) for k in params):
        raise TypeError('params should be a dict with `int` keys')
    needed = set(range(len(params)))
    givens = set(params)
    if givens != needed:
        # `formula` is always None at this point (a non-None formula raised
        # above), so the expected upper index must come from `params` itself.
        # `params` is non-empty here: an empty dict yields equal empty sets.
        raise opvi.ParametrizationError(
            'Passed parameters do not have a needed set of keys, '
            'they should be equal, needed {needed}, got {givens}'.format(
                givens=sorted(givens),
                needed='[0, 1, ..., %d]' % (len(params) - 1)))
    # Walk the indices in ascending order so the first mismatch reported is
    # deterministic rather than dependent on set iteration order.
    for i in range(len(params)):
        flow = flows.flow_for_params(params[i])
        flow_keys = set(flow.__param_spec__)
        user_keys = set(params[i].keys())
        if flow_keys != user_keys:
            raise opvi.ParametrizationError(
                'Passed parameters for flow `{i}` ({cls}) do not have a needed set of keys, '
                'they should be equal, needed {needed}, got {givens}'.format(
                    givens=user_keys, needed=flow_keys, i=i, cls=flow.__name__))
    return True
def __init_group__(self, group):
    """Resolve the flow specification and build ``self.flow`` for ``group``.

    The ``flow`` entry of ``self._kwargs`` (falling back to ``self._vfam``)
    may be one of:

    * ``None`` -- nothing was specified;
    * a string formula;
    * a ``flows.Formula`` instance.

    If ``self._check_user_params`` reports that user parameters were
    supplied, the formula is rebuilt from them by joining the
    ``short_name`` of each parameter set's flow with ``"-"``. Otherwise a
    ``None`` specification falls back to ``self.default_flow``. Whatever
    results is wrapped in ``flows.Formula`` unless it already is one.

    Raises
    ------
    TypeError
        If the ``flow`` specification is neither ``None``, a string, nor a
        ``flows.Formula``.
    """
    super().__init_group__(group)
    flow_spec = self._kwargs.get("flow", self._vfam)
    jitter = self._kwargs.get("jitter", 1)

    # Reduce the specification to the value the parameter check expects.
    if flow_spec is None or isinstance(flow_spec, str):
        formula_text = flow_spec
    elif isinstance(flow_spec, flows.Formula):
        formula_text = flow_spec.formula
    else:
        raise TypeError(
            "Wrong type provided for NormalizingFlow as `flow` argument, "
            "expected Formula or string"
        )
    params_given = self._check_user_params(f=formula_text)

    if params_given:
        # User params are keyed 0..n-1; derive the formula string from them.
        flow_spec = "-".join(
            flows.flow_for_params(self.user_params[idx]).short_name
            for idx in range(len(self.user_params))
        )
    elif flow_spec is None:
        flow_spec = self.default_flow

    if isinstance(flow_spec, flows.Formula):
        build_flow = flow_spec
    else:
        build_flow = flows.Formula(flow_spec)

    # Batch size passed to the formula: -1 for local groups, ``self.bdim``
    # for batched ones, ``None`` otherwise.
    batch_size = None
    if self.local:
        batch_size = -1
    elif self.batched:
        batch_size = self.bdim

    self.flow = build_flow(
        dim=self.ddim,
        z0=self.symbolic_initial,
        jitter=jitter,
        params=self.user_params,
        batch_size=batch_size,
    )
    self._finalize_init()
def __init_group__(self, group):
    """Initialise the group and construct its flow, stored on ``self.flow``.

    The flow specification is read from ``self._kwargs['flow']`` (default:
    ``self._vfam``) and the jitter from ``self._kwargs['jitter']``
    (default: ``1``). Accepted specifications are ``None``, a string
    formula, or a ``flows.Formula``; anything else raises ``TypeError``.

    When ``self._check_user_params`` returns a true value, the formula
    string is assembled from ``self.user_params`` (one ``short_name`` per
    entry, joined with ``'-'``). When it does not and the specification is
    ``None``, ``self.default_flow`` is used instead.
    """
    super(NormalizingFlowGroup, self).__init_group__(group)
    spec = self._kwargs.get('flow', self._vfam)
    jitter = self._kwargs.get('jitter', 1)

    # ``None`` and plain strings are handed to the check as-is; a Formula
    # contributes its ``formula`` attribute.
    is_plain = spec is None or isinstance(spec, str)
    if not is_plain and not isinstance(spec, flows.Formula):
        raise TypeError('Wrong type provided for NormalizingFlow as `flow` argument, '
                        'expected Formula or string')
    user_supplied = self._check_user_params(f=spec if is_plain else spec.formula)

    if user_supplied:
        short_names = [
            flows.flow_for_params(self.user_params[k]).short_name
            for k in range(len(self.user_params))
        ]
        spec = '-'.join(short_names)
    elif spec is None:
        spec = self.default_flow

    if not isinstance(spec, flows.Formula):
        spec = flows.Formula(spec)

    # -1 for local groups, ``self.bdim`` for batched ones, otherwise None.
    batch = -1 if self.local else (self.bdim if self.batched else None)

    self.flow = spec(
        dim=self.ddim,
        z0=self.symbolic_initial,
        jitter=jitter,
        params=self.user_params,
        batch_size=batch,
    )
    self._finalize_init()