示例#1
0
文件: sgmcmc.py 项目: t-triobox/pymc3
    def __init__(
        self,
        vars=None,
        batch_size=None,
        total_size=None,
        step_size=1.0,
        model=None,
        random_seed=None,
        minibatches=None,
        minibatch_tensors=None,
        **kwargs
    ):
        """Set up the state shared by stochastic-gradient step methods.

        Parameters
        ----------
        vars : list, optional
            Variables to sample. Defaults to ``model.value_vars``; otherwise
            each entry is mapped through ``model.rvs_to_values`` (entries
            without a mapping are used as given).
        batch_size : int
            Size of one minibatch of training data. Required.
        total_size : int
            Total size of the training data. Required.
        step_size : float
            Stored on the instance as ``self.step_size``.
        model : optional
            Resolved through ``modelcontext``.
        random_seed : optional
            Passed to ``at_rng`` when given; otherwise ``at_rng()`` is called
            with no arguments.
        minibatches : optional
            Stored as ``self.minibatches`` when ``minibatch_tensors`` is given.
        minibatch_tensors : list, optional
            Tensors that receive the minibatch data. Shared variables among
            them are swapped for fresh variables of the same type, which are
            appended to ``self.inarray``.
        **kwargs
            Accepted but not used by this constructor.

        Raises
        ------
        Whatever ``_value_error`` raises (presumably ``ValueError`` -- confirm
        against its definition) when ``total_size`` or ``batch_size`` is
        missing.
        """
        warnings.warn(EXPERIMENTAL_WARNING)

        model = modelcontext(model)

        if vars is None:
            vars = model.value_vars
        else:
            vars = [model.rvs_to_values.get(var, var) for var in vars]

        vars = inputvars(vars)

        self.model = model
        self.vars = vars
        self.batch_size = batch_size
        self.total_size = total_size
        # Both sizes are needed for the division below, so require both
        # (the previous `or` let a single missing value slip through and
        # fail later with an unrelated TypeError).
        _value_error(
            total_size is not None and batch_size is not None,
            "total_size and batch_size of training data have to be specified",
        )
        self.expected_iter = int(total_size / batch_size)

        # set random stream
        if random_seed is None:
            self.random = at_rng()
        else:
            self.random = at_rng(random_seed)

        self.step_size = step_size

        shared = make_shared_replacements(vars, model)

        self.updates = OrderedDict()
        # XXX: This needs to be refactored
        self.q_size = None  # int(sum(v.dsize for v in self.vars))

        # This seems to be the only place that `Model.flatten` is used.
        # TODO: Why not _actually_ flatten the variables?
        # E.g. `flat_vars = at.concatenate([var.ravel() for var in vars])`
        # or `set_subtensor` the `vars` into a `at.vector`?

        flat_view = model.flatten(vars)
        self.inarray = [flat_view.input]

        self.dlog_prior = prior_dlogp(vars, model, flat_view)
        self.dlogp_elemwise = elemwise_dlogL(vars, model, flat_view)

        # NOTE: `minibatches` / `minibatch_tensors` are deliberately left
        # unset when no tensors are given; code elsewhere may test for the
        # attribute's presence.
        if minibatch_tensors is not None:
            _check_minibatches(minibatch_tensors, minibatches)
            self.minibatches = minibatches

            # Replace input shared variables with tensors
            def is_shared(t):
                return isinstance(t, aesara.compile.sharedvalue.SharedVariable)

            tensors = [(t.type() if is_shared(t) else t) for t in minibatch_tensors]
            # Map each shared variable to the fresh tensor standing in for it.
            updates = OrderedDict(
                (t, t_) for t, t_ in zip(minibatch_tensors, tensors) if is_shared(t)
            )
            self.minibatch_tensors = tensors
            self.inarray += self.minibatch_tensors
            self.updates.update(updates)

        self._initialize_values()
        super().__init__(vars, shared)
示例#2
0
 def setup_method(self):
     """Seed the random sources from ``self.random_seed``.

     Seeds ``nr``, stores the current result of ``at_rng()`` on the
     instance as ``old_at_rng``, and then installs a ``RandomStream`` built
     from the same seed via ``set_at_rng``.
     """
     seed = self.random_seed
     nr.seed(seed)
     # Capture the current stream before it is replaced below.
     previous_stream = at_rng()
     self.old_at_rng = previous_stream
     seeded_stream = RandomStream(seed)
     set_at_rng(seeded_stream)