def _deduce_elements(elem, feed_dict, structure, batch_size, reuse, silent, context, engine):
    """Resolve batch size/structure, then parse one or more elements with Parser.

    Arguments:
        elem: a single element or a sequence of elements to deduce.
        feed_dict: mapping of variables to input data.
        structure: shape structure description (overridden by context.structure
            when a context is given).
        batch_size: explicit batch size; if None it is taken from the data size
            or from the context.
        reuse: accepted for interface compatibility; not used in this function.
        silent: when True, logging is suppressed (raised to CRITICAL) for the
            duration of this call and restored afterwards.
        context: optional DeduceContext from a previous call; when given, its
            batch_size, structure and engine are reused.
        engine: engine used to build the computation when no context is given.

    Returns:
        (deduced, context) — the parse result and the DeduceContext used.
    """
    log_level = logging.getLogger().level
    if silent:
        # Temporarily mute everything below CRITICAL.
        setup_log(logging.CRITICAL)
    if context is None:
        data_size = get_data_size(feed_dict)
        # Guard against silently using an enormous full-data batch.
        assert batch_size is not None or data_size is None or data_size < 10000, \
            "Got too big data size, need to point proper batch_size in arguments"
        if batch_size is None:
            assert data_size is not None, "Need to specify batch size"
            # Default to full-batch when the data is small enough.
            batch_size = data_size
    else:
        # Reuse the settings captured in the previous deduction context.
        batch_size = context.batch_size
        structure = context.structure
    # Normalize to a list of elements.
    if not is_sequence(elem):
        elements = [elem]
    else:
        elements = elem
    deduce_shapes(feed_dict, structure)
    engine_to_run = engine
    if context is None:
        context = DeduceContext(engine, structure, batch_size, feed_dict)
    else:
        engine_to_run = context.engine
    p = Parser(batch_size, structure)
    deduced = p.parse(elements, engine_to_run)
    if silent:
        # BUGFIX: restore the caller's log level; it was captured in
        # log_level but never restored (cf. the setup_log(log_level)
        # pattern used by deduce_sequence_ctx).
        setup_log(log_level)
    return deduced, context
def deduce(self, elem, ctx=None):
    """Deduce a value for `elem`, memoizing results per (elem, ctx).

    Arguments:
        elem: the element to deduce; dispatched on its type via
            self.type_callbacks.
        ctx: deduction context; defaults to self.default_ctx when None.

    Returns:
        The previously memoized value if (elem, ctx) was already visited,
        otherwise the result of the single matching type callback.

    Raises:
        AssertionError: if no callback or more than one callback matches
            the element's type.
    """
    if ctx is None:
        ctx = self.default_ctx
    # Memoization: return the cached result for this (elem, ctx) pair.
    visited_value = self.get_visited_value(elem, ctx)
    if visited_value is not None:
        return visited_value
    logging.debug("level: {}, elem: {}".format(self.level, elem))
    if logging.getLogger().level == logging.DEBUG:
        # Indent debug output to reflect recursion depth.
        setup_log(logging.DEBUG, ident_level=self.level)
    # .items() instead of .iteritems() keeps this working on Python 3 too.
    cb_to_call = [
        cb for tp, cb in self.type_callbacks.items() if isinstance(elem, tp)
    ]
    # Exactly one callback must match the element's type.
    assert len(
        cb_to_call) > 0, "Deducer got unexpected element: {}".format(elem)
    assert len(
        cb_to_call
    ) == 1, "Got too many callback matches for element: {}".format(elem)
    self.level += 1
    result = cb_to_call[0](elem, ctx)
    self.update_visited_value(elem, ctx, result)
    self.level -= 1
    if logging.getLogger().level == logging.DEBUG:
        setup_log(logging.DEBUG, ident_level=self.level)
    logging.debug("level out: {}, result: {}".format(self.level, result))
    return result
# Example script: toy variational autoencoder built with the vilab DSL.
import logging
from vilab.log import setup_log
from vilab.api import *
from vilab.util import *
from vilab.deduce import deduce, maximize, Monitor
from vilab.datasets import load_toy_dataset
from vilab.env import Env
from vilab.engines.print_engine import PrintEngine
from vilab.engines.var_engine import VarEngine

setup_log(logging.INFO)

# Observed variable x, latent variable z; generative model p, recognition model q.
x, z = Variable("x"), Variable("z")
p, q = Model("p"), Model("q")

# Shared MLP backbone; mu/var parameterize the encoder, logit the decoder.
mlp = Function("mlp", act=softplus)
mu, var = Function("mu", mlp), Function("var", mlp)
logit = Function("logit", mlp)

# Model definition via overloaded `==` (registers densities as a side effect):
# encoder q(z|x) is Gaussian, decoder p(x|z) is Bernoulli.
q(z | x) == N(mu(x), var(x))
p(x | z) == B(logit(z))

# Variational lower bound: -KL(q(z|x) || N(0, I)) + E[log p(x|z)].
LL = -KL(q(z | x), N0) + log(p(x | z))

x_train, x_classes = load_toy_dataset()
# NOTE(review): the first axis of x_train is used as batch_size here —
# presumably full-batch training on the toy dataset.
batch_size, ndim = x_train.shape
def deduce_sequence_ctx(self, elements):
    """Build a Parser.SequenceCtx describing the RNN-style sequence graph.

    Runs a throwaway VarEngine parse over `elements` (with logging muted)
    to discover which variables act as sequence inputs, recurrent states
    and outputs, registering engine placeholders for each along the way.

    Returns:
        Parser.SequenceCtx populated with input/state/output variables,
        their provided engine inputs, and state sizes.
    """
    seq_ctx = Parser.SequenceCtx([], [], [], [], [], [], [], [], [], set(), set(), False)

    def has_var_data(var, feed_dict):
        # A variable "has data" if it is fed directly, or — for a part of a
        # sequence — if its parent sequence is fed (directly or via another part).
        if not var in feed_dict:
            if isinstance(var, PartOfSequence):
                return var.get_seq() in set([
                    k.get_seq() for k, v in feed_dict.iteritems()
                    if isinstance(k, PartOfSequence)
                ] + [k for k in feed_dict if isinstance(k, Sequence)])
            else:
                return False
        else:
            return True

    def get_var_shape(var, feed_dict):
        # Returns the per-step (non-time, non-batch) shape of `var`,
        # registering it as an input or state variable as a side effect.
        assert has_var_data(var, feed_dict)
        if isinstance(var, PartOfSequence):
            idx = var.get_idx()
            assert isinstance(idx, Index)
            if idx.get_offset() == 0:
                # input data: current-step element x[t], fed as a 3-D tensor.
                assert var.get_seq(
                ) in feed_dict, "Expecting sequence data for {}".format(
                    var)
                assert len(
                    feed_dict[var.get_seq()].shape
                ) == 3, "Input data for sequence must have alignment time x batch x dimension"
                input_shape = feed_dict[var.get_seq()].shape
                if not var.get_seq() in seq_ctx.input_var_cache:
                    # First sighting of this sequence: register it and
                    # provide an engine input placeholder for it.
                    seq_ctx.input_var.append(var)
                    seq_ctx.input_var_cache.add(var.get_seq())
                    provided_input = self.engine.provide_input(
                        var.get_seq().get_name(),
                        (input_shape[0], self.batch_size, input_shape[2]))
                    self.engine_inputs[provided_input] = var.get_seq()
                    seq_ctx.input_data.append(provided_input)
                # Per-step shape: drop the time and batch axes.
                return input_shape[2:]
            elif idx.get_offset() == -1:
                # state data: previous-step element h[t-1]; the initial
                # value h[0] must be fed.
                h0 = var.get_seq()[0]
                assert h0 in feed_dict, "Expecting {} in feed dict as start value for state sequence {}".format(
                    h0, h0.get_seq())
                h0_shape = feed_dict[h0].shape
                if not h0 in seq_ctx.state_var_cache:
                    seq_ctx.state_var.append(var)
                    seq_ctx.state_var_cache.add(h0)
                    provided_input = self.engine.provide_input(
                        h0.get_scope_name(), (self.batch_size, h0_shape[1]))
                    self.engine_inputs[provided_input] = h0
                    seq_ctx.state_start_data.append(provided_input)
                    seq_ctx.state_size.append(h0_shape[1])
                return h0_shape[1:]
            else:
                raise Exception(
                    "Index offset that is not 0 or -1 is not supported yet, got {}"
                    .format(idx.get_offset()))
        else:
            # Plain (non-sequence) variable: per-sample shape from the feed.
            assert var in feed_dict
            return feed_dict[var].shape[1:]

    data_info_cb = Parser.DataInfoCb(self.data_info.get_feed_dict(),
                                     has_var_data, get_var_shape)
    # Mute logging during the exploratory VarEngine parse, restore at the end.
    log_level = logging.getLogger().level
    setup_log(logging.CRITICAL)
    var_parser = Parser(VarEngine(), elements[0], data_info_cb,
                        self.structure, self.batch_size)
    for elem in elements:
        # Fresh per-element sequence context (final flag True marks the
        # exploratory pass) — populated as a side effect of deduce().
        var_seq_ctx = Parser.SequenceCtx([], [], [], [], [], [], [], [], [],
                                         set(), set(), True)
        seq_ctx.output_elem.append(
            var_parser.deduce(elem,
                              ctx=Parser.get_ctx_with(
                                  self.default_ctx,
                                  sequence_ctx=var_seq_ctx,
                              )))
        for ov in var_seq_ctx.output_var:
            logging.debug(
                "Found {} as output var, adding to RNN output".format(ov))
            seq_ctx.output_var.append(ov)
    assert len(seq_ctx.output_var) == 0 or len(
        seq_ctx.state_var
    ) > 0, "Deducer failed to find any sequence related elements to calculate"
    # For each recurrent state h[t-1], the model must define the next step
    # h[t]; register it as an output state of the RNN.
    for v, size in zip(seq_ctx.state_var, seq_ctx.state_size):
        seq = v.get_seq()
        seq_parts = seq.get_parts()
        next_idx = v.get_idx() + 1
        assert next_idx in seq_parts, "Need to define generation process for sequence {} (define {}[{}])".format(
            seq, seq, next_idx)
        output_state = seq[next_idx]
        if output_state in seq_ctx.input_var:
            # NOTE(review): checks membership in `input_var` but removes from
            # `input_variables` — looks inconsistent; confirm SequenceCtx
            # actually has both fields.
            seq_ctx.input_variables.remove(output_state)
        seq_ctx.output_state_var.append(output_state)
        self.structure[output_state] = size
    setup_log(log_level)
    return seq_ctx
# Example script: simple recurrent sequence model built with the vilab DSL.
import logging
from vilab.log import setup_log
from vilab.api import *
from vilab.util import *
from vilab.deduce import deduce, maximize, Monitor
from vilab.env import Env
from vilab.datasets import load_mnist_realval
from vilab.engines.print_engine import PrintEngine
from vilab.parser import Parser

setup_log(logging.DEBUG)

# x: input sequence, y: output sequence, h: recurrent state sequence.
x, y, h = Sequence("x"), Sequence("y"), Sequence("h")
t = Index("t")

Function.configure(
    weight_factor = 0.1
)

f = Function("f")

# Recurrence definition via overloaded `==` (registers the step as a side
# effect): y[t] depends on the current input and previous state; h[t] on y[t].
y[t] == f(x[t], h[t-1])
h[t] == f(y[t])

# Training objective: negative summed squared reconstruction loss.
cost = - Summation(SquaredLoss(y[t], x[t]))

###########
def _deduce(self, element, ctx, engine):
    """Deduce a single element, dispatching to the registered callback.

    Checks the engine's cache first; on a miss, picks a callback by exact
    type match (priority), then by isinstance match, falling back to the
    default callback. The result is cached for everything except Variable
    elements, which are re-deduced each time.

    Arguments:
        element: the element to deduce.
        ctx: deduction context; ctx.density_view is part of the cache key.
        engine: backend engine providing get_cached()/cache().

    Returns:
        The deduced (possibly cached) result for `element`.
    """
    if self._verbose:
        logging.debug(
            "Deducing element: \n elem: {},\n ctx: \n\t{}".format(
                element, "\n\t".join([
                    "{} -> {}".format(k, v)
                    for k, v in ctx._asdict().items()
                ])))
    cached = engine.get_cached((element, ctx.density_view))
    if cached is not None:
        logging.debug("Engine: Cache hit for {}: {}".format(
            element, cached))
        return cached
    if self._verbose:
        logging.debug(
            "Can't find in the cache: \n elem: {},\n ctx: \n\t{}".
            format(
                element, "\n\t".join([
                    "{} -> {}".format(k, v)
                    for k, v in ctx._asdict().items()
                ])))
    logging.debug("Deducing element `{}`".format(element))
    self._level += 1
    if logging.getLogger().level == logging.DEBUG:
        # Indent debug output to reflect recursion depth.
        setup_log(logging.DEBUG, ident_level=self._level)
    # Exact-type callbacks take priority over inherited-type callbacks.
    # .items() instead of .iteritems() keeps this working on Python 3 too.
    strong_type_callbacks = [
        v for k, v in self._callbacks.items() if type(element) is k
    ]
    inherit_type_callbacks = [
        v for k, v in self._callbacks.items() if isinstance(element, k)
    ]
    if strong_type_callbacks:
        assert len(strong_type_callbacks) == 1
        callback = strong_type_callbacks[0]
    elif inherit_type_callbacks:
        assert len(inherit_type_callbacks) == 1
        callback = inherit_type_callbacks[0]
    else:
        callback = self._default_callback
    result = callback(element, ctx, engine)
    self._level -= 1
    if logging.getLogger().level == logging.DEBUG:
        setup_log(logging.DEBUG, ident_level=self._level)
    logging.debug("Done: {}".format(element))
    # Variables are never cached; their value may differ between deductions.
    if not isinstance(element, Variable):
        engine.cache((element, ctx.density_view), result)
    return result