Beispiel #1
0
    def __init__(self, inputs, outputs, **kwargs):
        """Wrap the graph ``inputs -> outputs`` so it can be used as one op.

        Shared variables feeding ``outputs`` are hidden from the inner graph:
        each one is replaced by a fresh variable of the same type, and those
        placeholders are appended after ``inputs`` in ``self.new_inputs``.
        The original shared variables are kept in ``self.shared_inputs``.

        Parameters
        ----------
        inputs : list of gof.Variable
            Explicit inputs of the inner graph.
        outputs : list of gof.Variable
            Outputs of the inner graph.
        **kwargs
            Stored unchanged as ``self.kwargs`` (presumably forwarded when the
            inner function is compiled -- not visible here, confirm).
            ``updates`` and ``givens`` are rejected.

        Raises
        ------
        TypeError
            If ``outputs`` or ``inputs`` is not a list, if one of their
            elements is not a ``gof.Variable``, or if ``updates``/``givens``
            is passed in ``kwargs``.
        """
        if not isinstance(outputs, list):
            raise TypeError("outputs must be list", outputs)
        # Both ``inputs + outputs`` and ``inputs + shared_vars`` below need a
        # real list; fail early with a clear message instead of an obscure
        # error from the concatenation (e.g. for a tuple or a lone Variable).
        if not isinstance(inputs, list):
            raise TypeError("inputs must be list", inputs)
        for i in inputs + outputs:
            if not isinstance(i, gof.Variable):
                raise TypeError("inputs and outputs must be Variable instances", i)
        if "updates" in kwargs or "givens" in kwargs:
            raise TypeError("updates and givens are not allowed in kwargs")

        # To correctly support shared variables the inner function must not
        # see them; otherwise there is a problem with the gradient. Each
        # shared variable is therefore swapped for a fresh variable of the
        # same type, which becomes an extra explicit input.
        self.shared_inputs = [var for var in gof.graph.inputs(outputs) if isinstance(var, SharedVariable)]
        shared_vars = [var.type() for var in self.shared_inputs]
        new = rebuild_collect_shared(
            outputs,
            inputs=inputs + shared_vars,
            replace=dict(izip(self.shared_inputs, shared_vars)),
            copy_inputs_over=False,
        )
        (new_inputs, new_outputs, [_clone_d, update_d, update_expr, shared_inputs]) = new
        # Sanity checks on the rebuilt graph: one input per explicit input
        # plus one per hidden shared variable, the same number of outputs,
        # and nothing collected as updates or as additional shared inputs.
        assert len(new_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(new_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        # Rebuilt (inner) graph endpoints.
        self.new_inputs = new_inputs
        self.new_outputs = new_outputs
        # Caller-supplied graph endpoints and options, kept as given.
        self.inputs = inputs
        self.outputs = outputs
        self.kwargs = kwargs
        self.input_types = [input.type for input in inputs]
        self.output_types = [output.type for output in outputs]
Beispiel #2
0
    def __init__(self, inputs, outputs, **kwargs):
        """Wrap the graph ``inputs -> outputs`` so it can be used as one op.

        Shared variables feeding ``outputs`` are hidden from the inner graph:
        each one is replaced by a fresh variable of the same type, and those
        placeholders are appended after ``inputs`` in ``self.new_inputs``.
        The original shared variables are kept in ``self.shared_inputs``.

        Parameters
        ----------
        inputs : list of gof.Variable
            Explicit inputs of the inner graph.
        outputs : list of gof.Variable
            Outputs of the inner graph.
        **kwargs
            Stored unchanged as ``self.kwargs`` (presumably forwarded when the
            inner function is compiled -- not visible here, confirm).
            ``updates`` and ``givens`` are rejected.

        Raises
        ------
        TypeError
            If ``outputs`` or ``inputs`` is not a list, if one of their
            elements is not a ``gof.Variable``, or if ``updates``/``givens``
            is passed in ``kwargs``.
        """
        if not isinstance(outputs, list):
            raise TypeError('outputs must be list', outputs)
        # Both ``inputs + outputs`` and ``inputs + shared_vars`` below need a
        # real list; fail early with a clear message instead of an obscure
        # error from the concatenation (e.g. for a tuple or a lone Variable).
        if not isinstance(inputs, list):
            raise TypeError('inputs must be list', inputs)
        for i in inputs + outputs:
            if not isinstance(i, gof.Variable):
                raise TypeError(
                    'inputs and outputs must be Variable instances', i)
        # NOTE(review): ``givens`` keyed on the caller's variables presumably
        # cannot match the inner graph, which is rebuilt below with
        # copy_inputs_over=False, so it is rejected up front like ``updates``.
        if 'updates' in kwargs or 'givens' in kwargs:
            raise TypeError('updates and givens are not allowed in kwargs')

        # To correctly support shared variables the inner function must not
        # see them; otherwise there is a problem with the gradient. Each
        # shared variable is therefore swapped for a fresh variable of the
        # same type, which becomes an extra explicit input.
        self.shared_inputs = [var for var in gof.graph.inputs(outputs)
                              if isinstance(var, SharedVariable)]
        shared_vars = [var.type() for var in self.shared_inputs]
        new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
                                     replace=dict(zip(self.shared_inputs,
                                                      shared_vars)),
                                     copy_inputs_over=False)
        (new_inputs, new_outputs,
         [_clone_d, update_d, update_expr, shared_inputs]) = new
        # Sanity checks on the rebuilt graph: one input per explicit input
        # plus one per hidden shared variable, the same number of outputs,
        # and nothing collected as updates or as additional shared inputs.
        assert len(new_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(new_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        # Rebuilt (inner) graph endpoints.
        self.new_inputs = new_inputs
        self.new_outputs = new_outputs
        # Caller-supplied graph endpoints and options, kept as given.
        self.inputs = inputs
        self.outputs = outputs
        self.kwargs = kwargs
        self.input_types = [input.type for input in inputs]
        self.output_types = [output.type for output in outputs]
Beispiel #3
0
    def __init__(
        self, inputs, outputs,
        inline=False,
        lop_overrides='default',
        grad_overrides='default',
        rop_overrides='default',
        name=None, **kwargs
    ):
        """Wrap the graph ``inputs -> outputs`` so it can be used as one op.

        Shared variables feeding ``outputs`` are hidden from the inner graph:
        each one is replaced by a fresh variable of the same type, and those
        placeholders are appended after ``inputs`` in ``self.local_inputs``.
        The original shared variables are kept in ``self.shared_inputs``.

        Parameters
        ----------
        inputs : list of gof.Variable
            Explicit inputs of the inner graph.
        outputs : list of gof.Variable
            Outputs of the inner graph.
        inline : bool
            Stored as ``self.is_inline`` (its effect is implemented elsewhere).
        lop_overrides, grad_overrides
            Mutually exclusive. Whichever is not ``'default'`` is passed to
            ``self.set_lop_overrides``; ``self._lop_type`` records which one
            (``'grad'`` for ``grad_overrides``, otherwise ``'lop'``).
        rop_overrides
            Passed to ``self.set_rop_overrides``.
        name : str or None
            Stored as ``self.name``.
        **kwargs
            Stored unchanged as ``self.kwargs``; ``updates`` and ``givens``
            are rejected.

        Raises
        ------
        TypeError
            If ``outputs`` or ``inputs`` is not a list, if one of their
            elements is not a ``gof.Variable``, if ``updates``/``givens`` is
            passed in ``kwargs``, or if ``name`` is neither None nor a str.
        ValueError
            If both ``lop_overrides`` and ``grad_overrides`` are given.
        """
        if not isinstance(outputs, list):
            raise TypeError('outputs must be list, got %s' % type(outputs))
        # Both ``inputs + outputs`` and ``inputs + shared_vars`` below need a
        # real list; fail early with a clear message instead of an obscure
        # error from the concatenation (e.g. for a tuple or a lone Variable).
        if not isinstance(inputs, list):
            raise TypeError('inputs must be list, got %s' % type(inputs))
        for i in inputs + outputs:
            if not isinstance(i, gof.Variable):
                raise TypeError(
                    'inputs and outputs must be Variable instances', i)
        if 'updates' in kwargs or 'givens' in kwargs:
            raise TypeError('updates and givens are not allowed here')
        self.is_inline = inline
        # To correctly support shared variables the inner fct should
        # not see them. Otherwise there is a problem with the gradient.
        self.shared_inputs = [var for var in gof.graph.inputs(outputs)
                              if isinstance(var, SharedVariable)]
        shared_vars = [var.type() for var in self.shared_inputs]

        new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
                                     replace=dict(izip(
                                         self.shared_inputs, shared_vars)),
                                     copy_inputs_over=False)
        (local_inputs, local_outputs,
         [_clone_d, update_d, update_expr, shared_inputs]) = new
        # Sanity checks on the rebuilt graph: one input per explicit input
        # plus one per hidden shared variable, the same number of outputs,
        # and nothing collected as updates or as additional shared inputs.
        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(local_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        self.local_inputs = local_inputs
        self.local_outputs = local_outputs
        self.inputs = inputs
        self.outputs = outputs
        self.kwargs = kwargs
        self.input_types = [inp.type for inp in inputs]
        self.output_types = [out.type for out in outputs]

        lop_given = lop_overrides != 'default'
        grad_given = grad_overrides != 'default'
        if lop_given and grad_given:
            raise ValueError(
                'lop_overrides and grad_overrides are mutually exclusive')
        if grad_given:
            self.set_lop_overrides(grad_overrides)
            self._lop_type = 'grad'
        else:
            # Either an explicit lop override or the 'default' fallback.
            self.set_lop_overrides(lop_overrides if lop_given else 'default')
            self._lop_type = 'lop'
        self.set_rop_overrides(rop_overrides)

        # Raise instead of assert so the check survives ``python -O``.
        if name is not None and not isinstance(name, str):
            raise TypeError('name must be None or string object')
        self.name = name
Beispiel #4
0
    def __init__(self,
                 inputs,
                 outputs,
                 inline=False,
                 grad_overrides='default',
                 rop_overrides='default',
                 name=None,
                 **kwargs):
        """Wrap the graph ``inputs -> outputs`` so it can be used as one op.

        Shared variables feeding ``outputs`` are hidden from the inner graph:
        each one is replaced by a fresh variable of the same type, and those
        placeholders are appended after ``inputs`` in ``self.local_inputs``.
        The original shared variables are kept in ``self.shared_inputs``.

        Parameters
        ----------
        inputs : list of gof.Variable
            Explicit inputs of the inner graph.
        outputs : list of gof.Variable
            Outputs of the inner graph.
        inline : bool
            Stored as ``self.is_inline`` (its effect is implemented elsewhere).
        grad_overrides
            Passed to ``self.set_grad_overrides``.
        rop_overrides
            Passed to ``self.set_rop_overrides``.
        name : str or None
            Stored as ``self.name``.
        **kwargs
            Stored unchanged as ``self.kwargs``; ``updates`` and ``givens``
            are rejected.

        Raises
        ------
        TypeError
            If ``outputs`` or ``inputs`` is not a list, if one of their
            elements is not a ``gof.Variable``, if ``updates``/``givens`` is
            passed in ``kwargs``, or if ``name`` is neither None nor a str.
        """
        if not isinstance(outputs, list):
            raise TypeError('outputs must be list, got %s' % type(outputs))
        # Both ``inputs + outputs`` and ``inputs + shared_vars`` below need a
        # real list; fail early with a clear message instead of an obscure
        # error from the concatenation (e.g. for a tuple or a lone Variable).
        if not isinstance(inputs, list):
            raise TypeError('inputs must be list, got %s' % type(inputs))
        for i in inputs + outputs:
            if not isinstance(i, gof.Variable):
                raise TypeError(
                    'inputs and outputs must be Variable instances', i)
        if 'updates' in kwargs or 'givens' in kwargs:
            raise TypeError('updates and givens are not allowed here')
        self.is_inline = inline
        # To correctly support shared variables the inner fct should
        # not see them. Otherwise there is a problem with the gradient.
        self.shared_inputs = [
            var for var in gof.graph.inputs(outputs)
            if isinstance(var, SharedVariable)
        ]
        shared_vars = [var.type() for var in self.shared_inputs]

        new = rebuild_collect_shared(outputs,
                                     inputs=inputs + shared_vars,
                                     replace=dict(
                                         izip(self.shared_inputs,
                                              shared_vars)),
                                     copy_inputs_over=False)
        (local_inputs, local_outputs,
         [_clone_d, update_d, update_expr, shared_inputs]) = new
        # Sanity checks on the rebuilt graph: one input per explicit input
        # plus one per hidden shared variable, the same number of outputs,
        # and nothing collected as updates or as additional shared inputs.
        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(local_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        self.local_inputs = local_inputs
        self.local_outputs = local_outputs
        self.inputs = inputs
        self.outputs = outputs
        self.kwargs = kwargs
        self.input_types = [inp.type for inp in inputs]
        self.output_types = [out.type for out in outputs]
        self.set_grad_overrides(grad_overrides)
        self.set_rop_overrides(rop_overrides)

        # Raise instead of assert so the check survives ``python -O``.
        if name is not None and not isinstance(name, str):
            raise TypeError('name must be None or string object')
        self.name = name
Beispiel #5
0
    def __init__(
        self,
        inputs,
        outputs,
        inline=False,
        lop_overrides="default",
        grad_overrides="default",
        rop_overrides="default",
        connection_pattern=None,
        name=None,
        **kwargs,
    ):
        """Wrap the graph ``inputs -> outputs`` so it can be used as one op.

        Shared variables feeding ``outputs`` are hidden from the inner graph:
        each one is replaced by a fresh variable of the same type, and those
        placeholders are appended after ``inputs`` in ``self.local_inputs``.
        The original shared variables are kept in ``self.shared_inputs``.

        Parameters
        ----------
        inputs : list of gof.Variable
            Explicit inputs of the inner graph.
        outputs : list of gof.Variable
            Outputs of the inner graph.
        inline : bool
            Stored as ``self.is_inline`` (its effect is implemented elsewhere).
        lop_overrides, grad_overrides
            Mutually exclusive. Whichever is not ``"default"`` is passed to
            ``self.set_lop_overrides``; ``self._lop_type`` records which one
            (``"grad"`` for ``grad_overrides``, otherwise ``"lop"``).
        rop_overrides
            Passed to ``self.set_rop_overrides``.
        connection_pattern
            Stored unchanged as ``self._connection_pattern``.
        name : str or None
            Stored as ``self.name``.
        **kwargs
            Stored unchanged as ``self.kwargs``; ``updates`` and ``givens``
            are rejected.

        Raises
        ------
        TypeError
            If ``outputs`` or ``inputs`` is not a list, if one of their
            elements is not a ``gof.Variable``, if ``updates``/``givens`` is
            passed in ``kwargs``, or if ``name`` is neither None nor a str.
        ValueError
            If both ``lop_overrides`` and ``grad_overrides`` are given.
        """
        if not isinstance(outputs, list):
            raise TypeError("outputs must be list, got %s" % type(outputs))
        # Both ``inputs + outputs`` and ``inputs + shared_vars`` below need a
        # real list; fail early with a clear message instead of an obscure
        # error from the concatenation (e.g. for a tuple or a lone Variable).
        if not isinstance(inputs, list):
            raise TypeError("inputs must be list, got %s" % type(inputs))
        for i in inputs + outputs:
            if not isinstance(i, gof.Variable):
                raise TypeError(
                    "inputs and outputs must be Variable instances", i)
        if "updates" in kwargs or "givens" in kwargs:
            raise TypeError("updates and givens are not allowed here")
        self.is_inline = inline
        # To correctly support shared variables the inner fct should
        # not see them. Otherwise there is a problem with the gradient.
        self.shared_inputs = [
            var for var in gof.graph.inputs(outputs)
            if isinstance(var, SharedVariable)
        ]
        shared_vars = [var.type() for var in self.shared_inputs]

        new = rebuild_collect_shared(
            outputs,
            inputs=inputs + shared_vars,
            replace=dict(zip(self.shared_inputs, shared_vars)),
            copy_inputs_over=False,
        )
        (
            local_inputs,
            local_outputs,
            [_clone_d, update_d, update_expr, shared_inputs],
        ) = new
        # Sanity checks on the rebuilt graph: one input per explicit input
        # plus one per hidden shared variable, the same number of outputs,
        # and nothing collected as updates or as additional shared inputs.
        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(local_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        self.local_inputs = local_inputs
        self.local_outputs = local_outputs
        self.inputs = inputs
        self.outputs = outputs
        self.kwargs = kwargs
        self.input_types = [inp.type for inp in inputs]
        self.output_types = [out.type for out in outputs]

        lop_given = lop_overrides != "default"
        grad_given = grad_overrides != "default"
        if lop_given and grad_given:
            raise ValueError(
                "lop_overrides and grad_overrides are mutually exclusive")
        if grad_given:
            self.set_lop_overrides(grad_overrides)
            self._lop_type = "grad"
        else:
            # Either an explicit lop override or the "default" fallback.
            self.set_lop_overrides(lop_overrides if lop_given else "default")
            self._lop_type = "lop"
        self.set_rop_overrides(rop_overrides)

        self._connection_pattern = connection_pattern

        # Raise instead of assert so the check survives ``python -O``.
        if name is not None and not isinstance(name, str):
            raise TypeError("name must be None or string object")
        self.name = name