Exemplo n.º 1
0
    def get_theano_variables(self, inputs=None, outputs=None):
        """
        Returns a tuple (inputs, outputs, graph) for the Theano version of the
        pyfn, where all inputs are packed into a single vector.

        Every flattened symbolic input is replaced by a slice of one symbolic
        vector named 'theta'. Each slice is reshaped to the matching entry of
        ``self.all_init_args`` and cast to the input's own dtype and
        broadcastable pattern. The output graph is then cloned with those
        substitutions.

        Parameters
        ----------
        inputs : optional
            Objects whose symbolic variables act as inputs. When empty or None,
            the values of ``self.s_inputs`` are used.
            NOTE(review): slices are laid out from ``self.all_init_args``
            regardless of ``inputs``; this assumes the two line up one-to-one.
            TODO confirm for callers passing a subset.
        outputs : optional
            Objects whose symbolic variables act as outputs. When empty or
            None, the values of ``self.s_outputs`` are used.

        Returns
        -------
        f_in, f_out, graph
            ``f_in``/``f_out`` are what ``self.finalize`` returns for the
            'theta' vector and the cloned output. ``graph`` is the
            ``clone_get_equiv`` mapping from original to cloned variables.

        Raises
        ------
        ValueError
            If there is not exactly one output.
        """
        inputs = utils.as_seq(inputs, tuple)
        outputs = utils.as_seq(outputs, tuple)

        if inputs:
            sym_inputs = [self.get_symbolic(x) for x in inputs]
        else:
            sym_inputs = self.s_inputs.values()

        if outputs:
            sym_outputs = [self.get_symbolic(x) for x in outputs]
        else:
            # materialize so len() and [0] work on any mapping's values view
            sym_outputs = list(self.s_outputs.values())

        # Exactly one output is supported. Checking "!= 1" (not "> 1") also
        # rejects zero outputs, which would otherwise fail later with an
        # opaque IndexError at graph[sym_outputs[0]].
        if len(sym_outputs) != 1:
            raise ValueError(
                'VectorArg functions should return a single output.')

        # map each symbolic input to its slice of the flat input vector
        s_memo = OrderedDict()
        sym_args = utils.flat_from_doc(sym_inputs)
        real_args = utils.flat_from_doc(self.all_init_args)

        # create a symbolic vector, then split it up into symbolic input args
        inputs_dtype = self.vector_from_args(self.all_init_args).dtype
        theano_input = tt.vector(name='theta', dtype=inputs_dtype)
        offset = 0
        for sa, ra in zip(sym_args, real_args):
            if sa.ndim > 0:
                vector_arg = theano_input[offset: offset + ra.size].reshape(
                    ra.shape)
            else:
                vector_arg = theano_input[offset]
            # restore the replaced variable's dtype and broadcastable pattern
            # so the substitute matches the variable it stands in for
            s_memo[sa] = tt.patternbroadcast(
                vector_arg.astype(str(sa.dtype)),
                broadcastable=sa.broadcastable)
            offset += ra.size

        # get new graph, replacing the original inputs with vector slices
        # (s_memo is not reused afterwards; the copy is defensive)
        graph = theano.gof.graph.clone_get_equiv(
            theano.gof.graph.inputs(sym_outputs),
            sym_outputs,
            memo=s_memo.copy())

        # the single symbolic output, as it appears in the cloned graph
        theano_outputs = graph[sym_outputs[0]]

        f_in, f_out = self.finalize(theano_input, theano_outputs, graph)

        return f_in, f_out, graph
Exemplo n.º 2
0
    def get_theano_variables(self, inputs=None, outputs=None):
        """
        Clone the traced graph so its inputs become fresh symbolic variables.

        Each flattened symbolic input is paired with a new variable of the
        same Theano type. The graph reaching the requested outputs is then
        cloned with those replacements.

        Returns a tuple ``(theano_inputs, theano_outputs, graph)``: the new
        input variables, the cloned outputs, and the ``clone_get_equiv``
        mapping from original to cloned variables.
        """
        sym_inputs = [self.get_symbolic(obj)
                      for obj in utils.as_seq(inputs, tuple)]
        sym_outputs = [self.get_symbolic(obj)
                       for obj in utils.as_seq(outputs, tuple)]

        # pair every flattened input with a fresh variable of the same type;
        # this mapping seeds the clone so those inputs get substituted
        replacements = OrderedDict()
        for old_var in utils.flat_from_doc(sym_inputs):
            replacements[old_var] = old_var.type()
        theano_inputs = tuple(replacements.values())

        # hand the clone a copy so `replacements` itself is left untouched
        graph = theano.gof.graph.clone_get_equiv(
            theano.gof.graph.inputs(sym_outputs),
            sym_outputs,
            memo=replacements.copy())

        theano_outputs = tuple(graph[out] for out in sym_outputs)
        return theano_inputs, theano_outputs, graph
Exemplo n.º 3
0
    def get_theano_fn(self, args, kwargs):
        """
        Trace the pyfn on ``args``/``kwargs`` and compile a Hessian-vector
        product function.

        The compiled function takes the traced inputs followed by one
        direction tensor per differentiation variable. It returns the
        R-operator of the gradients applied to those directions; a
        single-element result is unwrapped.

        Differentiation is with respect to ``self.wrt`` when set, otherwise
        with respect to every traced input.

        Raises TypeError if any traced output is not a scalar.
        """
        self.trace(*args, **kwargs)

        fn_inputs, fn_outputs, graph = self.get_theano_variables(
            self.s_inputs, self.s_outputs)

        if np.any([out.ndim != 0 for out in fn_outputs]):
            raise TypeError('HessianVector requires scalar outputs.')

        requested = utils.as_seq(self.wrt)
        if len(requested) == 0:
            wrt = list(fn_inputs)
        else:
            wrt = [graph[self.get_symbolic(w)] for w in requested]

        grads = utils.flat_from_doc([tt.grad(out, wrt=wrt)
                                     for out in fn_outputs])

        # one direction variable per wrt: same dtype and rank, no broadcasting
        directions = tuple(
            tt.TensorType(dtype=w.dtype, broadcastable=[False] * w.ndim)()
            for w in wrt)
        hess_vec = tt.Rop(grads, wrt, directions)
        if len(hess_vec) == 1:
            hess_vec = hess_vec[0]

        return theano.function(inputs=fn_inputs + directions,
                               outputs=hess_vec,
                               on_unused_input='ignore')
Exemplo n.º 4
0
def clean_int_args(*args, **kwargs):
    """
    Swap small Python ints in ``args``/``kwargs`` for ``numpy.int16`` copies.

    CPython caches the int objects in [-5, 256], so equal small ints share a
    single identity. The tracer looks symbolic variables up by ``id()``, so
    each such argument must be a distinct object. Only exact ``int``
    instances are replaced (``type(a) is int``), so e.g. ``bool`` values are
    left alone.

    Returns ``(clean_args, clean_kwargs)``, rebuilt with
    ``utils.doc_from_flat`` in the same structure as the inputs.
    """
    def replace_small_ints(flat):
        # rewrite qualifying entries in place and hand the list back
        for idx, value in enumerate(flat):
            if type(value) is int and -5 <= value <= 256:
                flat[idx] = np.int16(value)
        return flat

    clean_args = utils.doc_from_flat(
        args, replace_small_ints(utils.flat_from_doc(args)))
    clean_kwargs = utils.doc_from_flat(
        kwargs, replace_small_ints(utils.flat_from_doc(kwargs)))
    return clean_args, clean_kwargs
Exemplo n.º 5
0
    def trace(self, *args, **kwargs):
        """
        Run the Python function through the tracing context and record its
        symbolic inputs and outputs.

        Small integer arguments are first swapped out by ``clean_int_args``.
        The function is then executed via ``self.context.call``.

        State read and written:
            self.s_vars    : {id(obj): sym_var} mapping used to find the
                             symbolic variable of each argument and result.
                             For arguments that are symbolic inputs, an
                             extra {argument name: sym_var} entry is added.
            self.s_inputs  : tuple of symbolic variables for the flattened
                             call arguments. A leading bound 'self'/'cls'
                             argument is excluded.
            self.s_outputs : tuple of symbolic variables for the results.

        Returns whatever the traced function returned.
        """
        c_args, c_kwargs = clean_int_args(*args, **kwargs)
        results = self.context.call(self.pyfn, c_args, c_kwargs)

        callargs = utils.orderedcallargs(self.pyfn, *c_args, **c_kwargs)
        flat_args = utils.flat_from_doc(callargs)

        # bound methods (and calls whose first argument is a class) carry an
        # implicit first argument that is not a real input
        skip_first = inspect.ismethod(self.pyfn) or (
            len(flat_args) > 0 and type(flat_args[0]) is type)
        if skip_first:
            flat_args = flat_args[1:]

        self.s_inputs = tuple(self.s_vars[id(a)] for a in flat_args)
        self.s_outputs = tuple(self.s_vars[id(r)]
                               for r in utils.as_seq(results))

        # also index symbolic inputs by their parameter name
        for name, value in callargs.items():
            sym = self.s_vars.get(id(value), None)
            if sym in self.s_inputs:
                self.s_vars[name] = sym

        return results
Exemplo n.º 6
0
    def compile_gradient(self,
                         inputs=None,
                         outputs=None,
                         wrt=None,
                         reduction=None):
        """
        Compile a Theano function returning gradients of the traced outputs.

        Parameters
        ----------
        inputs, outputs : optional
            Forwarded to ``self.get_theano_variables`` to select the graph's
            inputs and outputs.
        wrt : optional
            Object(s) to differentiate with respect to. When empty or None,
            the gradient is taken with respect to every graph input.
        reduction : str or callable, optional
            Applied to each output to make it scalar. Accepted forms:
            - one of 'sum', 'max', 'mean', 'min', 'prod', 'std', 'var',
              looked up on ``theano.tensor``;
            - a NumPy reduction, mapped to the ``theano.tensor`` function of
              the same name (NumPy's 'amax'/'amin' map to 'max'/'min');
            - any other callable accepting a symbolic output.

        Returns
        -------
        The compiled ``theano.function``. It returns a single gradient when
        there is exactly one, otherwise the flattened collection of gradients.

        Raises
        ------
        TypeError
            If any output is non-scalar after the optional reduction.
        """
        fn_inputs, fn_outputs, graph = self.get_theano_variables(inputs,
                                                                 outputs)
        wrt = utils.as_seq(wrt)

        if reduction in ['sum', 'max', 'mean', 'min', 'prod', 'std', 'var']:
            reduction = getattr(theano.tensor, reduction)
        if callable(reduction):
            # not every callable defines __module__, and it may be None;
            # the previous unguarded access crashed on such callables
            module = getattr(reduction, '__module__', None) or ''
            if 'numpy' in module:
                # NumPy 1.x exposes np.max/np.min as aliases whose __name__
                # is 'amax'/'amin'; theano.tensor names them 'max'/'min'
                tensor_name = {'amax': 'max', 'amin': 'min'}.get(
                    reduction.__name__, reduction.__name__)
                reduction = getattr(theano.tensor, tensor_name)
            fn_outputs = [reduction(o) for o in fn_outputs]

        if np.any([o.ndim != 0 for o in fn_outputs]):
            raise TypeError('Gradient requires either scalar outputs or a '
                            'reduction that returns a scalar.')

        # get wrt variables. If none were specified, use inputs.
        if len(wrt) == 0:
            wrt = list(fn_inputs)
        else:
            wrt = [graph[self.get_symbolic(w)] for w in wrt]

        grads = utils.flat_from_doc([tt.grad(o, wrt=wrt) for o in fn_outputs])

        # unwrap a lone gradient so the compiled function returns it directly
        if len(grads) == 1:
            grads = grads[0]

        # compile function
        fn = theano.function(inputs=fn_inputs,
                             outputs=grads,
                             on_unused_input='ignore')

        return fn
Exemplo n.º 7
0
    def call(self, args, kwargs):
        """
        Execute ``self.func`` on ``args``/``kwargs`` by interpreting its
        bytecode one instruction at a time.

        Each opcode is dispatched to a method named ``op_<OPNAME>``; '+' in
        slice opcode names is spelled '_PLUS_'. Each handler's return value
        is sent back into the instruction iterator as ``jmp``, presumably a
        jump target (None to fall through).

        Exception classes are not interpreted; they are simply instantiated
        with the given arguments.

        Args:
            args: tuple of positional arguments.
            kwargs: dict of keyword arguments.

        Returns:
            ``self.rval`` once an opcode handler has set it.

        Raises:
            TypeError: if ``args`` is not a tuple or ``kwargs`` is not a dict.
            AttributeError: if no ``op_<OPNAME>`` handler exists for an
                opcode.
        """
        if not isinstance(args, tuple):
            raise TypeError('vm.call: args must be tuple', args)
        if not isinstance(kwargs, dict):
            raise TypeError('vm.call: kwargs must be dict', kwargs)

        func = self.func

        if isinstance(func, type) and issubclass(func, BaseException):
            # XXX not shadowing exception creation, because exceptions
            # do not have func_code. Is this OK? can we do better?
            return func(*args, **kwargs)

        func_code = self.func.func_code

        self._myglobals = {}
        self._locals = []

        # resolve every global/attribute name the bytecode references:
        # function globals first, then builtins; unresolved names are skipped
        for name in func_code.co_names:
            #print 'name', name
            try:
                self._myglobals[name] = func.func_globals[name]
            except KeyError:
                try:
                    self._myglobals[name] = __builtin__.__getattribute__(name)
                except AttributeError:
                    #print 'WARNING: name lookup failed', name
                    pass

        # get function arguments
        argspec = inspect.getargspec(func)

        # match function arguments to passed parameters
        callargs = orderedcallargs(func, *args, **kwargs)

        # The local slots below are filled in this order: named parameters,
        # then *args, then **kwargs, then the remaining locals. This
        # presumably mirrors the slot order of func_code.co_varnames.
        # named args => locals
        self._locals.extend(callargs[arg] for arg in argspec.args)

        # *args => locals
        if argspec.varargs:
            self._locals.append(callargs[argspec.varargs])

        # **kwargs => locals
        if argspec.keywords:
            self._locals.append(callargs[argspec.keywords])

        # other vars => locals (start as the Unassigned placeholder)
        no_unbound_args = len(func_code.co_varnames) - len(self._locals)
        self._locals.extend([Unassigned] * no_unbound_args)

        # shadow arguments not already tracked by the watcher
        for val in flat_from_doc(callargs):
            if id(val) not in self.watcher:
                self.add_shadow(val)

        self.code_iter = itercode(func_code.co_code)
        jmp = None
        # NOTE(review): the loop stops once a handler sets self.rval. If this
        # instance already has rval from an earlier call, the loop is skipped
        # entirely. If the bytecode runs out without rval being set, the
        # final `return self.rval` raises AttributeError.
        while not hasattr(self, 'rval'):
            try:
                i, op, arg = self.code_iter.send(jmp)
            except StopIteration:
                break
            name = opcode.opname[op]
            # method names can't have '+' in them
            name = {'SLICE+0': 'SLICE_PLUS_0',
                    'SLICE+1': 'SLICE_PLUS_1',
                    'SLICE+2': 'SLICE_PLUS_2',
                    'SLICE+3': 'SLICE_PLUS_3',
                    'STORE_SLICE+0': 'STORE_SLICE_PLUS_0',
                    'STORE_SLICE+1': 'STORE_SLICE_PLUS_1',
                    'STORE_SLICE+2': 'STORE_SLICE_PLUS_2',
                    'STORE_SLICE+3': 'STORE_SLICE_PLUS_3',
                    }.get(name, name)
            if self.print_ops:
                print 'OP: ', i, name
            if self.print_stack:
                print self.stack
            try:
                op_method = getattr(self, 'op_' + name)
            except AttributeError:
                raise AttributeError('FrameVM does not have a method defined '
                                     'for \'op_{0}\''.format(name))
            # NOTE(review): this bare re-raise is a no-op and could be removed
            except:
                raise
            jmp = op_method(i, op, arg)

        return self.rval