예제 #1
0
파일: deduce.py 프로젝트: alexeyche/vilab
def _deduce_elements(elem, feed_dict, structure, batch_size, reuse, silent,
                     context, engine):
    """Deduce one element or a sequence of elements with a ``Parser``.

    Args:
        elem: a single element, or a sequence of elements (checked with
            ``is_sequence``), to deduce.
        feed_dict: input data; passed to ``get_data_size``, ``deduce_shapes``
            and, when a new context is created, to ``DeduceContext``.
        structure: structure description; replaced by ``context.structure``
            when ``context`` is given.
        batch_size: batch size. When None and no context is given, it falls
            back to the data size reported by ``get_data_size``. Replaced by
            ``context.batch_size`` when ``context`` is given.
        reuse: not used by this function.
        silent: when true, logging is raised to CRITICAL for the duration of
            the call and the previous level is restored afterwards (also on
            error).
        context: a context returned by a previous call, or None to create a
            new ``DeduceContext``.
        engine: engine used for parsing when a new context is created; an
            existing context's own engine is used otherwise.

    Returns:
        ``(deduced, context)``: the parser output and the (new or reused)
        context.

    Raises:
        AssertionError: if the data size is too large without an explicit
            ``batch_size``, or if no batch size can be determined.
    """
    log_level = logging.getLogger().level
    if silent:
        setup_log(logging.CRITICAL)

    try:
        if context is None:
            data_size = get_data_size(feed_dict)

            # Refuse to silently treat a large dataset as one single batch.
            assert (batch_size is not None or data_size is None
                    or data_size < 10000), \
                "Got too big data size, need to point proper batch_size in arguments"

            if batch_size is None:
                assert data_size is not None, "Need to specify batch size"
                batch_size = data_size
        else:
            batch_size = context.batch_size
            structure = context.structure

        elements = elem if is_sequence(elem) else [elem]

        deduce_shapes(feed_dict, structure)

        engine_to_run = engine
        if context is None:
            context = DeduceContext(engine, structure, batch_size, feed_dict)
        else:
            engine_to_run = context.engine

        p = Parser(batch_size, structure)
        deduced = p.parse(elements, engine_to_run)
    finally:
        # Restore the caller's logging level; otherwise silent=True would
        # leave logging globally silenced after this call (or after an error).
        if silent:
            setup_log(log_level)

    return deduced, context
예제 #2
0
    def deduce(self, elem, ctx=None):
        """Deduce ``elem`` by dispatching it to its matching type callback.

        A value previously stored via ``update_visited_value`` is looked up
        first with ``get_visited_value``; a non-None hit is returned without
        calling any callback.

        Args:
            elem: element to deduce; exactly one key of ``self.type_callbacks``
                must match it via ``isinstance``.
            ctx: deduction context; defaults to ``self.default_ctx``.

        Returns:
            The matching callback's result for ``(elem, ctx)``.

        Raises:
            AssertionError: if no callback, or more than one, matches ``elem``.
        """
        if ctx is None:
            ctx = self.default_ctx

        visited_value = self.get_visited_value(elem, ctx)
        if visited_value is not None:
            return visited_value

        logging.debug("level: {}, elem: {}".format(self.level, elem))
        if logging.getLogger().level == logging.DEBUG:
            setup_log(logging.DEBUG, ident_level=self.level)

        # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
        cb_to_call = [
            cb for tp, cb in self.type_callbacks.items()
            if isinstance(elem, tp)
        ]

        assert len(cb_to_call) > 0, \
            "Deducer got unexpected element: {}".format(elem)
        assert len(cb_to_call) == 1, \
            "Got too many callback matches for element: {}".format(elem)

        self.level += 1
        try:
            result = cb_to_call[0](elem, ctx)
            self.update_visited_value(elem, ctx, result)
        finally:
            # Undo the nesting level (and log indentation) even when the
            # callback raises, so later deductions are not left mis-nested.
            self.level -= 1
            if logging.getLogger().level == logging.DEBUG:
                setup_log(logging.DEBUG, ident_level=self.level)

        logging.debug("level out: {}, result: {}".format(self.level, result))
        return result
예제 #3
0
import logging

from vilab.log import setup_log
from vilab.api import *
from vilab.util import *
from vilab.deduce import deduce, maximize, Monitor
from vilab.datasets import load_toy_dataset
from vilab.env import Env

from vilab.engines.print_engine import PrintEngine
from vilab.engines.var_engine import VarEngine

# Example script: configure logging at INFO level for the run.
setup_log(logging.INFO)

# Symbolic variables and models declared through the vilab DSL.
x, z = Variable("x"), Variable("z")
p, q = Model("p"), Model("q")

# Base function "mlp" with softplus activation.
mlp = Function("mlp", act=softplus)

# mu, var and logit are declared with mlp as their second argument
# (presumably as their underlying network -- confirm in vilab.api.Function).
mu, var = Function("mu", mlp), Function("var", mlp)
logit = Function("logit", mlp)

# NOTE(review): `==` appears to be the DSL's overloaded operator for declaring
# a model's distribution, not a comparison; these statements are kept for
# their side effects. Presumably q(z | x) is Normal(mu(x), var(x)) and
# p(x | z) is Bernoulli(logit(z)) -- confirm against vilab.api.
q(z | x) == N(mu(x), var(x))
p(x | z) == B(logit(z))

# Objective: -KL(q(z | x), N0) + log(p(x | z)). Presumably N0 is a standard
# normal prior and LL is the variational lower bound -- TODO confirm.
LL = -KL(q(z | x), N0) + log(p(x | z))

x_train, x_classes = load_toy_dataset()

# Requires x_train to have exactly two dimensions (assumed samples x features).
batch_size, ndim = x_train.shape
예제 #4
0
    def deduce_sequence_ctx(self, elements):
        """Collect the sequence (RNN) context required to deduce ``elements``.

        A throw-away ``Parser`` backed by a ``VarEngine`` is run over every
        element while logging is silenced to CRITICAL. While it runs, the
        ``get_var_shape`` callback registers engine inputs for input sequences
        (index offset 0) and state sequences (offset -1). The returned
        ``Parser.SequenceCtx`` then holds:

        * input variables and their provided engine inputs (``input_var`` /
          ``input_data``, kept index-aligned);
        * state variables, their start-value inputs and sizes;
        * the deduced output elements and the output variables found;
        * the next-step state variables (``output_state_var``), whose sizes
          are also written into ``self.structure``.

        Args:
            elements: non-empty list of elements; ``elements[0]`` is passed to
                the auxiliary parser's constructor.

        Returns:
            The populated ``Parser.SequenceCtx``.

        Raises:
            AssertionError: on missing or malformed feed data, or when a state
                sequence has no definition for its next step.
            Exception: for index offsets other than 0 or -1.
        """
        seq_ctx = Parser.SequenceCtx([], [], [], [], [], [], [], [], [], set(),
                                     set(), False)

        def has_var_data(var, feed_dict):
            # True if feed_dict holds var itself or, for a sequence part, its
            # parent sequence (as a Sequence key or via another part of it).
            if var in feed_dict:
                return True
            if not isinstance(var, PartOfSequence):
                return False
            fed_sequences = set(
                [k.get_seq() for k in feed_dict
                 if isinstance(k, PartOfSequence)] +
                [k for k in feed_dict if isinstance(k, Sequence)])
            return var.get_seq() in fed_sequences

        def get_var_shape(var, feed_dict):
            # Return the per-sample data shape for var. On first sight of an
            # input or state sequence, also provide an engine input for it and
            # record it in seq_ctx / self.engine_inputs.
            assert has_var_data(var, feed_dict)
            if not isinstance(var, PartOfSequence):
                assert var in feed_dict
                return feed_dict[var].shape[1:]

            idx = var.get_idx()
            assert isinstance(idx, Index)
            offset = idx.get_offset()

            if offset == 0:  # input data
                seq = var.get_seq()
                assert seq in feed_dict, "Expecting sequence data for {}".format(
                    var)
                input_shape = feed_dict[seq].shape
                assert len(
                    input_shape
                ) == 3, "Input data for sequence must have alignment time x batch x dimension"

                if seq not in seq_ctx.input_var_cache:
                    seq_ctx.input_var.append(var)
                    seq_ctx.input_var_cache.add(seq)

                    # Batch axis is replaced by the parser's batch size.
                    provided_input = self.engine.provide_input(
                        seq.get_name(),
                        (input_shape[0], self.batch_size, input_shape[2]))

                    self.engine_inputs[provided_input] = seq
                    seq_ctx.input_data.append(provided_input)

                return input_shape[2:]

            if offset == -1:  # state data
                h0 = var.get_seq()[0]
                assert h0 in feed_dict, "Expecting {} in feed dict as start value for state sequence {}".format(
                    h0, h0.get_seq())
                h0_shape = feed_dict[h0].shape

                if h0 not in seq_ctx.state_var_cache:
                    seq_ctx.state_var.append(var)
                    seq_ctx.state_var_cache.add(h0)

                    provided_input = self.engine.provide_input(
                        h0.get_scope_name(),
                        (self.batch_size, h0_shape[1]))

                    self.engine_inputs[provided_input] = h0

                    seq_ctx.state_start_data.append(provided_input)
                    seq_ctx.state_size.append(h0_shape[1])

                return h0_shape[1:]

            raise Exception(
                "Index offset that is not 0 or -1 is not supported yet, got {}"
                .format(offset))

        data_info_cb = Parser.DataInfoCb(self.data_info.get_feed_dict(),
                                         has_var_data, get_var_shape)

        # Silence logging while the auxiliary parser runs. The previous level
        # is restored in `finally` so a failing assertion below cannot leave
        # logging globally stuck at CRITICAL.
        log_level = logging.getLogger().level
        setup_log(logging.CRITICAL)
        try:
            var_parser = Parser(VarEngine(), elements[0], data_info_cb,
                                self.structure, self.batch_size)

            for elem in elements:
                var_seq_ctx = Parser.SequenceCtx([], [], [], [], [], [], [], [],
                                                 [], set(), set(), True)

                seq_ctx.output_elem.append(
                    var_parser.deduce(elem,
                                      ctx=Parser.get_ctx_with(
                                          self.default_ctx,
                                          sequence_ctx=var_seq_ctx,
                                      )))
                for ov in var_seq_ctx.output_var:
                    logging.debug(
                        "Found {} as output var, adding to RNN output".format(ov))
                    seq_ctx.output_var.append(ov)

            assert len(seq_ctx.output_var) == 0 or len(
                seq_ctx.state_var
            ) > 0, "Deducer failed to find any sequence related elements to calculate"

            for v, size in zip(seq_ctx.state_var, seq_ctx.state_size):
                seq = v.get_seq()
                seq_parts = seq.get_parts()
                next_idx = v.get_idx() + 1
                assert next_idx in seq_parts, "Need to define generation process for sequence {} (define {}[{}])".format(
                    seq, seq, next_idx)

                output_state = seq[next_idx]
                if output_state in seq_ctx.input_var:
                    # Remove it from input_var, the list the membership test
                    # above checked. input_var and input_data are appended in
                    # lockstep in get_var_shape, so drop the matching provided
                    # input too to keep the two lists aligned.
                    pos = seq_ctx.input_var.index(output_state)
                    del seq_ctx.input_var[pos]
                    del seq_ctx.input_data[pos]
                seq_ctx.output_state_var.append(output_state)
                self.structure[output_state] = size
        finally:
            setup_log(log_level)

        return seq_ctx
예제 #5
0
import logging

from vilab.log import setup_log
from vilab.api import *
from vilab.util import *
from vilab.deduce import deduce, maximize, Monitor
from vilab.env import Env
from vilab.datasets import load_mnist_realval
from vilab.engines.print_engine import PrintEngine

from vilab.parser import Parser


# Example script: verbose logging so every deduction step is printed.
setup_log(logging.DEBUG)

# Sequences x, y, h and the time index t used to address their elements.
x, y, h = Sequence("x"), Sequence("y"), Sequence("h")
t = Index("t")

# Configure Function defaults (weight_factor=0.1); presumably applied to
# functions created afterwards -- confirm in vilab.api.Function.
Function.configure(
    weight_factor=0.1
)

f = Function("f")

# NOTE(review): `==` appears to be the DSL's overloaded operator for
# declaring definitions, not a comparison. y[t] is built from x[t] and the
# previous step h[t - 1]; h[t] is built from y[t].
y[t] == f(x[t], h[t - 1])
h[t] == f(y[t])

# Negated summation of the squared loss between y[t] and x[t].
cost = -Summation(SquaredLoss(y[t], x[t]))

###########
예제 #6
0
    def _deduce(self, element, ctx, engine):
        """Deduce ``element`` in ``ctx`` using ``engine``, with engine caching.

        The engine cache is consulted first under the key
        ``(element, ctx.density_view)``. On a miss, a callback is chosen from
        ``self._callbacks``:

        * an exact type match wins;
        * otherwise an ``isinstance`` match is used;
        * otherwise ``self._default_callback`` is used.

        Results for non-``Variable`` elements are stored back into the engine
        cache.

        Args:
            element: element to deduce.
            ctx: deduction context; must expose ``density_view`` and
                ``_asdict()`` (presumably a namedtuple -- confirm).
            engine: engine providing ``get_cached`` / ``cache``; also passed to
                the selected callback.

        Returns:
            The cached value on a hit, otherwise the callback's result.

        Raises:
            AssertionError: if more than one callback matches at the chosen
                precedence level.
        """
        def describe_ctx():
            # One "field -> value" line per ctx field, for verbose logging.
            return "\n\t".join(
                "{} -> {}".format(k, v) for k, v in ctx._asdict().items())

        if self._verbose:
            logging.debug(
                "Deducing element: \n elem: {},\n ctx: \n\t{}".format(
                    element, describe_ctx()))

        cache_key = (element, ctx.density_view)
        cached = engine.get_cached(cache_key)

        if cached is not None:
            logging.debug("Engine: Cache hit for {}: {}".format(
                element, cached))
            return cached

        if self._verbose:
            logging.debug(
                "Can't find in the cache: \n elem: {},\n ctx: \n\t{}".format(
                    element, describe_ctx()))

        logging.debug("Deducing element `{}`".format(element))

        self._level += 1
        if logging.getLogger().level == logging.DEBUG:
            setup_log(logging.DEBUG, ident_level=self._level)

        try:
            strong_type_callbacks = [
                cb for tp, cb in self._callbacks.items() if type(element) is tp
            ]
            inherit_type_callbacks = [
                cb for tp, cb in self._callbacks.items()
                if isinstance(element, tp)
            ]

            if strong_type_callbacks:
                assert len(strong_type_callbacks) == 1, \
                    "Several callbacks match type of {}".format(element)
                callback = strong_type_callbacks[0]
            elif inherit_type_callbacks:
                assert len(inherit_type_callbacks) == 1, \
                    "Several callbacks match base types of {}".format(element)
                callback = inherit_type_callbacks[0]
            else:
                callback = self._default_callback

            result = callback(element, ctx, engine)
        finally:
            # Restore nesting level and log indentation even if callback
            # selection or the callback itself raises.
            self._level -= 1
            if logging.getLogger().level == logging.DEBUG:
                setup_log(logging.DEBUG, ident_level=self._level)

        logging.debug("Done: {}".format(element))

        # Variable results are deliberately not cached here (reason not shown
        # in this code -- confirm before changing).
        if not isinstance(element, Variable):
            engine.cache(cache_key, result)
        return result