def __init__(self, name, learnable, ranges, is_observed=False, **kwargs):
    """Set up the variable's bookkeeping state and connect it to its inputs.

    Args:
        name: Stored as ``self.name``.
        learnable: Forwarded to ``construct_deterministic_parents``.
        ranges: Forwarded to ``construct_deterministic_parents``.
        is_observed: Initial value of ``self._observed``.
        **kwargs: Named inputs. Each value is passed through ``var2link``
            and the whole mapping is handed to ``LinkConstructor``.
    """
    # TODO: code duplication here
    self.name = name
    self._evaluated = False
    self._observed = is_observed
    self._observed_value = None
    self._current_value = None

    # Runs before parents/link are derived and receives the very same
    # ``kwargs`` mapping.
    # NOTE(review): presumably it rewrites raw values in ``kwargs`` in
    # place -- confirm against its definition before reordering.
    self.construct_deterministic_parents(learnable, ranges, kwargs)

    # Direct parents: union of the variables referenced by every input link.
    parent_sets = [var2link(value).vars for value in kwargs.values()]
    self.parents = join_sets_list(parent_sets)

    # Ancestors: the parents themselves plus everything above each parent.
    ancestor_sets = [self.parents]
    ancestor_sets.extend([parent.ancestors for parent in self.parents])
    self.ancestors = join_sets_list(ancestor_sets)

    self.link = LinkConstructor(**kwargs)

    self.samples = None
    self.ranges = {}
    self.dataset = None
    self.has_random_dataset = False
    self.has_observed_value = False
    self.is_normalized = True

    # One link per named input, keyed by the input's keyword name.
    self.partial_links = {
        key: var2link(value) for key, value in kwargs.items()
    }
def __init__(self, name, learnable, ranges, is_observed=False, has_bias=False, is_policy=False, is_reward=False, **kwargs):
    """Set up the variable's state, its parents/link and its RL descriptors.

    Args:
        name: Stored as ``self.name``.
        learnable: Forwarded to ``construct_biases`` (when ``has_bias``)
            and to ``construct_deterministic_parents``.
        ranges: Forwarded to the same two helpers.
        is_observed: Initial value of ``self._observed``.
        has_bias: When true, ``construct_biases`` is called before the
            deterministic parents are built.
        is_policy: Stored as ``self._policy``.
        is_reward: Stored as ``self._reward``.
        **kwargs: Named inputs. Kept (same object, not a copy) as
            ``self._input``, passed through ``var2link`` and handed to
            ``Link``.

    Raises:
        ValueError: If both ``is_policy`` and ``is_reward`` are true.
    """
    # Reject contradictory flags up front so that no attribute is set and no
    # bias/parent is constructed for an object that is going to be refused.
    if is_policy and is_reward:
        raise ValueError("A variable cannot be both a RL policy and a RL reward")

    self._input = kwargs
    self.name = name
    self._evaluated = False
    self._observed = is_observed
    self._observed_value = None
    self._current_value = None

    # NOTE: the ``self._check_for_stochastic_process_arguments()`` early
    # return is commented out in this version of the constructor.

    if has_bias:
        self.construct_biases(learnable, ranges, kwargs)
    # Receives the same ``kwargs`` mapping that is used below.
    # NOTE(review): presumably updates it in place -- confirm.
    self.construct_deterministic_parents(learnable, ranges, kwargs)

    # Direct parents: union of the variables referenced by every input link.
    self.parents = join_sets_list([var2link(x).vars for x in kwargs.values()])
    self.link = Link(**kwargs)
    # Ancestors: the parents plus everything above each parent.
    self.ancestors = join_sets_list(
        [self.parents] + [parent.ancestors for parent in self.parents]
    )

    self.samples = None
    self.ranges = {}
    self.dataset = None
    self.has_random_dataset = False
    self.has_observed_value = False
    self.is_normalized = True
    self.silenced = False
    # One link per named input, keyed by the input's keyword name.
    self.partial_links = {
        key: var2link(link) for key, link in kwargs.items()
    }

    # Reinforcement learning descriptors
    self._policy = is_policy
    self._reward = is_reward
def __init__(self, name, learnable, ranges, is_observed=False, has_bias=False, **kwargs):
    """Set up the variable's state and connect it to its inputs.

    Args:
        name: Stored as ``self.name``.
        learnable: Stored as ``self.learnable`` and forwarded to
            ``construct_biases`` / ``construct_deterministic_parents``.
        ranges: Forwarded to the same two helpers.
        is_observed: Initial value of ``self._observed``.
        has_bias: Stored as ``self._bias``; when true ``construct_biases``
            is called before the deterministic parents are built.
        **kwargs: Named inputs. A shallow copy taken on entry is kept as
            ``self._input``; the mapping itself is passed through
            ``var2link`` and handed to ``Link``.
    """
    # Snapshot of the inputs exactly as the caller supplied them.
    self._input = kwargs.copy()
    self.name = name
    self.learnable = learnable
    self._bias = has_bias
    self._evaluated = False
    self._observed = is_observed
    self._observed_value = None
    self._current_value = None

    # When this check is truthy nothing below is executed, so parents, link
    # and the remaining attributes are not set by this constructor.
    is_stochastic_process = self._check_for_stochastic_process_arguments()
    if is_stochastic_process:
        return

    if has_bias:
        self.construct_biases(learnable, ranges, kwargs)
    # Receives the same ``kwargs`` mapping that is used below.
    # NOTE(review): presumably updates it in place -- confirm.
    self.construct_deterministic_parents(learnable, ranges, kwargs)

    # Direct parents: union of the variables referenced by every input link.
    parent_sets = [var2link(value).vars for value in kwargs.values()]
    self.parents = join_sets_list(parent_sets)

    self.link = Link(**kwargs)

    # Ancestors: the parents plus everything above each parent.
    ancestor_sets = [self.parents]
    ancestor_sets.extend([parent.ancestors for parent in self.parents])
    self.ancestors = join_sets_list(ancestor_sets)

    self.samples = None
    self.ranges = {}
    self.dataset = None
    self.has_random_dataset = False
    self.has_observed_value = False
    self.is_normalized = True

    # One link per named input, keyed by the input's keyword name.
    self.partial_links = {
        key: var2link(value) for key, value in kwargs.items()
    }
def __call__(self, *args, **kwargs):
    """Apply this operator to (possibly deferred) arguments.

    Every positional and keyword argument is passed through ``var2link``.
    The result is a ``PartialLink`` whose function, given ``values``,
    resolves each ``PartialLink`` argument with its own ``fn(values)`` and
    then calls ``self.fn`` with the resolved arguments.

    Args:
        *args: Positional arguments for ``self.fn``.
        **kwargs: Keyword arguments for ``self.fn``.

    Returns:
        ``PartialLink(vars, fn, self.links)`` where ``vars`` is the union of
        the ``vars`` of all ``PartialLink`` arguments.
    """
    link_args = [var2link(arg) for arg in args]
    link_kwargs = {name: var2link(arg) for name, arg in kwargs.items()}

    # Variables referenced by the positional and by the keyword arguments.
    arg_vars = {
        var
        for link in link_args
        if isinstance(link, PartialLink)
        for var in link.vars
    }
    kwarg_vars = {
        var
        for link in link_kwargs.values()
        if isinstance(link, PartialLink)
        for var in link.vars
    }

    def fn(values):
        resolved_args = [
            x.fn(values) if isinstance(x, PartialLink) else x
            for x in link_args
        ]
        # Build the mapping directly. Going through a set of (name, value)
        # tuples would require every resolved value to be hashable and
        # raises TypeError for e.g. numpy arrays, lists or dicts.
        resolved_kwargs = {
            name: (x.fn(values) if isinstance(x, PartialLink) else x)
            for name, x in link_kwargs.items()
        }
        return self.fn(*resolved_args, **resolved_kwargs)

    return PartialLink(arg_vars.union(kwarg_vars), fn, self.links)
def __init__(self, value, name, log_determinant=None, learnable=False, has_bias=False, is_observed=False, variable_range=geometric_ranges.UnboundedRange(), is_policy=False, is_reward=False):
    """Build a deterministic node around ``value``.

    Args:
        value: Forwarded to the parent constructor as ``value``.
        name: Forwarded to the parent constructor as the variable name.
        log_determinant: Forwarded as ``log_determinant``. ``None`` is
            replaced by a float tensor of zeros with shape ``(1, 1)`` moved
            to ``device``.
        learnable: Forwarded to the parent constructor.
        has_bias: Forwarded to the parent constructor.
        is_observed: Forwarded to the parent constructor.
        variable_range: Range registered under the ``"value"`` key of the
            ``ranges`` mapping given to the parent constructor.
        is_policy: Forwarded to the parent constructor.
        is_reward: Forwarded to the parent constructor.
    """
    self._type = "Deterministic node"

    if log_determinant is None:
        zero_block = np.zeros((1, 1))
        log_determinant = torch.tensor(zero_block).float().to(device)
    if not isinstance(log_determinant, PartialLink):
        # NOTE(review): the return value is discarded, so this call only
        # matters for whatever var2link does internally -- confirm whether
        # the result was meant to be assigned back.
        var2link(log_determinant)

    ranges = {
        "value": variable_range,
        "log_determinant": geometric_ranges.UnboundedRange(),
    }
    super().__init__(
        name,
        value=value,
        log_determinant=log_determinant,
        learnable=learnable,
        has_bias=has_bias,
        ranges=ranges,
        is_observed=is_observed,
        is_policy=is_policy,
        is_reward=is_reward,
    )
    self.distribution = distributions.DeterministicDistribution()
def __init__(self):
    """Register every link found in the captured ``kwargs`` with the parent.

    ``kwargs`` is not a parameter of this method: it is a free variable
    looked up outside the method body, so it must be provided by the
    surrounding scope in which this method is defined.
    """
    self.kwargs = kwargs

    # Flatten the ``links`` of each input, in the order the inputs appear.
    collected = []
    for partial_link in kwargs.values():
        collected.extend(var2link(partial_link).links)

    super().__init__(*collected)
def __init__(self, name, learnable, ranges, is_observed=False, **kwargs):
    """Set up the variable's state and build its ``VarLink`` container.

    Args:
        name: Stored as ``self.name``.
        learnable: Forwarded to ``construct_deterministic_parents``.
        ranges: Forwarded to ``construct_deterministic_parents``.
        is_observed: Initial value of ``self._observed``.
        **kwargs: Named inputs. They are captured by the local ``VarLink``
            class and passed through ``var2link`` to find the parents.
    """

    class VarLink(chainer.ChainList):
        """Chain list over the links of the enclosing call's ``kwargs``."""

        def __init__(self):
            # ``kwargs`` is the enclosing constructor's keyword mapping.
            self.kwargs = kwargs
            collected = []
            for partial_link in kwargs.values():
                collected.extend(var2link(partial_link).links)
            super().__init__(*collected)

        def __call__(self, values):
            """Evaluate every input link on ``values``, keyed by its name."""
            outputs = {}
            for key, source in self.kwargs.items():
                outputs[key] = var2link(source).fn(values)
            return outputs

    self.name = name
    self._evaluated = False
    self._observed = is_observed
    self._observed_value = None
    self._current_value = None

    # Receives the same ``kwargs`` mapping that VarLink reads afterwards.
    # NOTE(review): presumably updates it in place -- confirm.
    self.construct_deterministic_parents(learnable, ranges, kwargs)

    # Direct parents: union of the variables referenced by every input link.
    parent_sets = [var2link(value).vars for value in kwargs.values()]
    self.parents = join_sets_list(parent_sets)

    self.link = VarLink()

    self.samples = []
    self.ranges = {}
    self.dataset = None
    self.has_random_dataset = False
    self.has_observed_value = False
def __init__(self, **kwargs):
    """Keep the named inputs and hand all of their links to the parent.

    Args:
        **kwargs: Named inputs. Stored as ``self.kwargs``; each value is
            passed through ``var2link`` and its ``links`` are collected.
    """
    self.kwargs = kwargs

    # Flatten the ``links`` of each input, in the order the inputs appear.
    modules = []
    for partial_link in kwargs.values():
        modules.extend(var2link(partial_link).links)

    # TODO: assert that the specified links are valid pytorch modules
    super().__init__(modules)
def __init__(self, function, links, *args, **kwargs):
    """Record a function together with its link-wrapped arguments.

    Args:
        function: Stored as ``self.function``.
        links: Forwarded to the parent constructor as ``links``.
        *args: Positional arguments; each is passed through ``var2link``
            and kept in ``self.link_args``.
        **kwargs: Keyword arguments; each is passed through ``var2link``
            and kept in ``self.link_kwargs`` under the same name.
    """
    self.function = function
    self.link_args = [var2link(arg) for arg in args]
    self.link_kwargs = {key: var2link(arg) for key, arg in kwargs.items()}

    # Union of the variables referenced by every PartialLink argument,
    # positional first and then keyword.
    referenced_vars = set()
    for link in self.link_args:
        if isinstance(link, PartialLink):
            referenced_vars.update(link.vars)
    for link in self.link_kwargs.values():
        if isinstance(link, PartialLink):
            referenced_vars.update(link.vars)

    super().__init__(vars=referenced_vars, links=links)
def __call__(self, values):
    """Evaluate every stored input on ``values``.

    Args:
        values: Passed unchanged to the ``fn`` of each input's link.

    Returns:
        A dict with the same keys as ``self.kwargs``; each entry holds
        ``var2link(input).fn(values)`` for the corresponding input.
    """
    outputs = {}
    for key, source in self.kwargs.items():
        outputs[key] = var2link(source).fn(values)
    return outputs
def _get_string(self, *args, **kwargs):
    """Render a call-like string: ``name(arg1, arg2, ...)``.

    Args:
        *args: Positional operands, rendered first.
        **kwargs: Keyword operands; only their values are rendered, after
            the positional ones.

    Returns:
        ``self.name`` followed by the comma-separated string form of every
        operand's link, wrapped in parentheses.
    """
    prefix = self.name + "("
    operands = list(args) + list(kwargs.values())
    rendered = [str(var2link(operand)) for operand in operands]
    return prefix + ", ".join(rendered) + ")"
def __str__(self):
    """Render a call-like string: ``function(arg1, arg2, ...)``.

    Returns:
        ``str(self.function)`` followed by the comma-separated string form
        of every link in ``self.link_args`` and then of every value in
        ``self.link_kwargs``, wrapped in parentheses.
    """
    prefix = str(self.function) + "("
    operands = list(self.link_args) + list(self.link_kwargs.values())
    rendered = [str(var2link(operand)) for operand in operands]
    return prefix + ", ".join(rendered) + ")"