Example #1
    def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):

        # note this should handle transforms correctly via distribution_to_data
        raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
        # Register the sample dimensions to the left of the existing batch dims.
        for d, name in zip(range(len(sample_inputs), 0, -1),
                           sample_inputs.keys()):
            dim_to_name[-d - len(raw_dist.batch_shape)] = name

        if value_name not in sampled_vars:
            return self

        sample_shape = tuple(v.size for v in sample_inputs.values())
        # The torch backend samples with just sample_shape; other backends
        # (e.g. JAX/NumPyro) also require an explicit rng_key.
        sample_args = (sample_shape, ) if get_backend() == "torch" else (
            rng_key, sample_shape)
        if self.has_rsample:
            # Reparameterized draw: gradients flow through the sample itself.
            raw_value = raw_dist.rsample(*sample_args)
        else:
            # Non-reparameterized draw: detach it and add a dice factor below
            # so gradients still reach the distribution's parameters.
            raw_value = ops.detach(raw_dist.sample(*sample_args))

        funsor_value = to_funsor(raw_value,
                                 output=value_output,
                                 dim_to_name=dim_to_name)
        funsor_value = funsor_value.align(
            tuple(sample_inputs) +
            tuple(inp for inp in self.inputs if inp in funsor_value.inputs))
        result = funsor.delta.Delta(value_name, funsor_value)
        if not self.has_rsample:
            # scaling of dice_factor by num samples should already be handled by Funsor.sample
            raw_log_prob = raw_dist.log_prob(raw_value)
            dice_factor = to_funsor(raw_log_prob - ops.detach(raw_log_prob),
                                    output=self.output,
                                    dim_to_name=dim_to_name)
            result = result + dice_factor
        return result
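The non-reparameterized branch above preserves gradients with a DiCE-style correction: log_prob - detach(log_prob) is zero in value but carries the score-function gradient. Below is a minimal standalone sketch of that trick in plain PyTorch; the names rate, dice, and surrogate are only for this sketch and are not funsor API.

import torch

rate = torch.tensor(2.0, requires_grad=True)
p = torch.distributions.Poisson(rate)      # no rsample available
x = p.sample()                             # non-differentiable draw

log_prob = p.log_prob(x)
dice = log_prob - log_prob.detach()        # 0.0 in value, but differentiable
surrogate = x.detach() * torch.exp(dice)   # equals x in value

surrogate.backward()
# rate.grad == x * d/d(rate) log p(x): an unbiased one-sample score-function
# estimate of d/d(rate) E[x] (whose true value for a Poisson is 1).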
Example #2
    def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
        params = OrderedDict(self.params)
        value = params.pop("value")
        assert all(isinstance(v, (Number, Tensor)) for v in params.values())
        assert isinstance(value, Variable) and value.name in sampled_vars
        inputs_, tensors = align_tensors(*params.values())
        inputs = OrderedDict(sample_inputs.items())
        inputs.update(inputs_)
        sample_shape = tuple(v.size for v in sample_inputs.values())

        raw_dist = self.dist_class(**dict(zip(self._ast_fields[:-1], tensors)))
        # rng_key is None under the torch backend; JAX-style distributions take
        # (rng_key, sample_shape) rather than just sample_shape.
        sample_args = (sample_shape, ) if rng_key is None else (rng_key,
                                                                sample_shape)
        if getattr(raw_dist, "has_rsample", False):
            raw_sample = raw_dist.rsample(*sample_args)
        else:
            raw_sample = ops.detach(raw_dist.sample(*sample_args))

        result = funsor.delta.Delta(
            value.name, Tensor(raw_sample, inputs, value.output.dtype))
        if not getattr(raw_dist, "has_rsample", False):
            # scaling of dice_factor by num samples should already be handled by Funsor.sample
            raw_log_prob = raw_dist.log_prob(raw_sample)
            dice_factor = Tensor(raw_log_prob - ops.detach(raw_log_prob),
                                 inputs)
            result = result + dice_factor
        return result
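The has_rsample check above decides between a reparameterized draw and a detached draw plus dice factor. A quick illustration with torch.distributions (purely illustrative, not funsor code): Normal supports rsample, Poisson does not.

import torch

normal = torch.distributions.Normal(0.0, 1.0)
poisson = torch.distributions.Poisson(2.0)

assert normal.has_rsample            # supports reparameterized draws
assert not poisson.has_rsample       # falls back to sample() plus the dice factor

x = normal.rsample((5,))             # differentiable when loc/scale require grad
y = poisson.sample((5,))             # non-differentiable draw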
Example #3
def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.
    """
    if get_backend() != "jax":
        # NB: rename symbols to support NumPy, which allows only symbols a-z.
        symbols = sorted(set(equation) - set(',->'))
        rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz'))
        equation = ''.join(rename.get(s, s) for s in equation)

    inputs, output = equation.split('->')
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
        shift = ops.detach(operand)
        for i, dim in enumerate(dims):
            if dim not in output:
                shift = ops.amax(shift, i, keepdims=True)
        # Clamp to avoid NaN from (-inf) - (-inf) when a reduced slice is all -inf.
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))

        # permute shift to match output
        shift = shift.reshape(
            [size for size, dim in zip(operand.shape, dims) if dim in output])
        if len(shift.shape) > 0:
            shift = shift.reshape((1, ) * (len(output) - shift.ndim) +
                                  shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = ops.permute(shift, [dims.index(dim) for dim in output])
        shifts.append(shift)

    result = ops.log(ops.einsum(equation, *exp_operands))
    return sum(shifts + [result])
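The shift-and-exp pattern above is per-operand log-sum-exp stabilization. Here is a small NumPy check of the same idea for a log-space matrix product, i.e. C[i, k] = logsumexp_j(A[i, j] + B[j, k]); the names A, B, shift_A, shift_B exist only in this sketch.

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))                      # log-space operands
B = rng.normal(size=(3, 4))

shift_A = A.max(axis=1, keepdims=True)           # max over the contracted dim j
shift_B = B.max(axis=0, keepdims=True)
stable = np.log(np.einsum("ij,jk->ik", np.exp(A - shift_A), np.exp(B - shift_B)))
stable = stable + shift_A + shift_B              # add the shifts back, broadcasting to (2, 4)

naive = np.log(np.einsum("ij,jk->ik", np.exp(A), np.exp(B)))  # overflows for large entries
assert np.allclose(stable, naive)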
Example #4
File: tensor.py Project: ordabayevy/funsor
    def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
        assert self.output == Real
        sampled_vars = sampled_vars.intersection(self.inputs)
        if not sampled_vars:
            return self

        # Partition inputs into sample_inputs + batch_inputs + event_inputs.
        sample_inputs = OrderedDict(
            (k, d) for k, d in sample_inputs.items() if k not in self.inputs)
        sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
        batch_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if k not in sampled_vars)
        event_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if k in sampled_vars)
        be_inputs = batch_inputs.copy()
        be_inputs.update(event_inputs)
        sb_inputs = sample_inputs.copy()
        sb_inputs.update(batch_inputs)

        # Sample all variables in a single Categorical call.
        logits = align_tensor(be_inputs, self)
        batch_shape = logits.shape[:len(batch_inputs)]
        flat_logits = logits.reshape(batch_shape + (-1, ))

        backend = get_backend()
        if backend != "numpy":
            from importlib import import_module
            dist = import_module(
                funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend])
            sample_args = (sample_shape, ) if rng_key is None else (
                rng_key, sample_shape)
            flat_sample = dist.CategoricalLogits.dist_class(
                logits=flat_logits).sample(*sample_args)
        else:  # default numpy backend
            assert backend == "numpy"
            shape = sample_shape + flat_logits.shape[:-1]
            # Inverse-CDF sampling: normalize the logits into probabilities, then
            # count how many cumulative bins fall below each uniform draw.
            logit_max = np.amax(flat_logits, -1, keepdims=True)
            probs = np.exp(flat_logits - logit_max)
            probs = probs / np.sum(probs, -1, keepdims=True)
            s = np.cumsum(probs, -1)
            r = np.random.rand(*shape)
            flat_sample = np.sum(s < np.expand_dims(r, -1), axis=-1)

        assert flat_sample.shape == sample_shape + batch_shape
        # Decode the flat draw into one Delta per event variable, peeling off the
        # fastest-varying (last) variable first via mod/floordiv.
        results = []
        mod_sample = flat_sample
        for name, domain in reversed(list(event_inputs.items())):
            size = domain.dtype
            point = Tensor(mod_sample % size, sb_inputs, size)
            mod_sample = mod_sample // size
            results.append(Delta(name, point))

        # Account for the log normalizer factor.
        # Derivation: Let f be a nonnormalized distribution (a funsor), and
        #   consider operations in linear space (source code is in log space).
        #   Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|.
        #                              f(x0) / |f|      # dice numerator
        #   Let g = delta(x=x0) |f| -----------------
        #                           detach(f(x0)/|f|)   # dice denominator
        #                       |detach(f)| f(x0)
        #         = delta(x=x0) -----------------  be a dice approximation of f.
        #                         detach(f(x0))
        #   Then g is an unbiased estimator of f in value and all derivatives.
        #   In the special case f = detach(f), we can simplify to
        #       g = delta(x=x0) |f|.
        if (backend == "torch"
                and flat_logits.requires_grad) or backend == "jax":
            # Apply a dice factor to preserve differentiability.
            index = [
                ops.new_arange(self.data,
                               n).reshape((n, ) + (1, ) *
                                          (len(flat_logits.shape) - i - 2))
                for i, n in enumerate(flat_logits.shape[:-1])
            ]
            index.append(flat_sample)
            log_prob = flat_logits[tuple(index)]
            assert log_prob.shape == flat_sample.shape
            results.append(
                Tensor(
                    ops.logsumexp(ops.detach(flat_logits), -1) +
                    (log_prob - ops.detach(log_prob)), sb_inputs))
        else:
            # This is the special case f = detach(f).
            results.append(Tensor(ops.logsumexp(flat_logits, -1),
                                  batch_inputs))

        return reduce(ops.add, results)
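The final loop decodes the single flattened Categorical draw into one index per discrete variable via repeated mod/floordiv. A tiny NumPy round-trip illustrating that decoding for two hypothetical variables (size_x and size_y are made up for this sketch):

import numpy as np

size_x, size_y = 3, 4                       # two discrete variables; y varies fastest
flat = np.arange(size_x * size_y)           # stand-in for the flat Categorical samples
y = flat % size_y                           # last event variable decoded first
x = (flat // size_y) % size_x               # then the next one
assert np.all(x * size_y + y == flat)       # re-flattening recovers the joint index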