def __init__(
    self,
    vars=None,
    batch_size=None,
    total_size=None,
    step_size=1.0,
    model=None,
    random_seed=None,
    minibatches=None,
    minibatch_tensors=None,
    **kwargs
):
    warnings.warn(EXPERIMENTAL_WARNING)

    model = modelcontext(model)

    if vars is None:
        vars = model.value_vars
    else:
        vars = [model.rvs_to_values.get(var, var) for var in vars]

    vars = inputvars(vars)

    self.model = model
    self.vars = vars
    self.batch_size = batch_size
    self.total_size = total_size
    # Both sizes are required below to compute the expected number of iterations.
    _value_error(
        total_size is not None and batch_size is not None,
        "total_size and batch_size of training data have to be specified",
    )
    self.expected_iter = int(total_size / batch_size)

    # Set the random stream.
    if random_seed is None:
        self.random = at_rng()
    else:
        self.random = at_rng(random_seed)

    self.step_size = step_size

    shared = make_shared_replacements(vars, model)

    self.updates = OrderedDict()
    # XXX: This needs to be refactored
    self.q_size = None  # int(sum(v.dsize for v in self.vars))

    # This seems to be the only place that `Model.flatten` is used.
    # TODO: Why not _actually_ flatten the variables?
    # E.g. `flat_vars = at.concatenate([var.ravel() for var in vars])`
    # or `set_subtensor` the `vars` into a `at.vector`?
    flat_view = model.flatten(vars)

    self.inarray = [flat_view.input]

    self.dlog_prior = prior_dlogp(vars, model, flat_view)
    self.dlogp_elemwise = elemwise_dlogL(vars, model, flat_view)

    if minibatch_tensors is not None:
        _check_minibatches(minibatch_tensors, minibatches)
        self.minibatches = minibatches

        # Replace input shared variables with plain tensors of the same type.
        def is_shared(t):
            return isinstance(t, aesara.compile.sharedvalue.SharedVariable)

        tensors = [(t.type() if is_shared(t) else t) for t in minibatch_tensors]
        updates = OrderedDict(
            {t: t_ for t, t_ in zip(minibatch_tensors, tensors) if is_shared(t)}
        )
        self.minibatch_tensors = tensors
        self.inarray += self.minibatch_tensors
        self.updates.update(updates)

    self._initialize_values()
    super().__init__(vars, shared)
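# A minimal usage sketch (an assumption, not taken from this module): a concrete
# subclass of this base step method would be constructed inside a model context.
# `SomeStochasticGradientStep`, `X_shared`/`y_shared` (Aesara shared variables
# holding one minibatch), and `minibatch_gen` (a generator yielding successive
# minibatches) are illustrative names only.
#
#     with pm.Model():
#         ...
#         step = SomeStochasticGradientStep(
#             batch_size=100,
#             total_size=10_000,  # both sizes are required; see the check above
#             minibatch_tensors=[X_shared, y_shared],
#             minibatches=minibatch_gen,
#         )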
def setup_method(self):
    # Seed NumPy's global RNG for reproducibility.
    nr.seed(self.random_seed)
    # Save the currently active Aesara RandomStream so it can be restored
    # after the test, then install a freshly seeded one.
    self.old_at_rng = at_rng()
    set_at_rng(RandomStream(self.random_seed))
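# A sketch of the matching teardown (not shown in this excerpt, but implied by
# the `self.old_at_rng` saved above): restore the RandomStream that was active
# before the test ran.
def teardown_method(self):
    set_at_rng(self.old_at_rng)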