def clone(
    output, replace=None, strict=True, share_inputs=True, copy_inputs=DEPRECATED_ARG
):
    """
    Function that allows replacing subgraphs of a computational graph.

    It returns a copy of the initial subgraph with the corresponding
    substitutions.

    Parameters
    ----------
    output : Theano Variables (or Theano expressions)
        Theano expression that represents the computational graph.
    replace : dict, list, tuple or None
        Describes which subgraphs should be replaced by what, either as a
        dictionary or as a list/tuple of ``(original, replacement)`` pairs.
        ``None`` means that nothing is replaced.
    strict : bool
        Forwarded unchanged to both `rebuild_collect_shared` calls.
    share_inputs : bool
        If True, use the same inputs (and shared variables) as the original
        graph. If False, clone them. Note that cloned shared variables still
        use the same underlying storage, so they will always have the same
        value.
    copy_inputs
        Deprecated, use share_inputs. When given, its value overrides
        `share_inputs` (which must then be left at its default).

    Returns
    -------
    The outputs produced by the second `rebuild_collect_shared` pass.

    Raises
    ------
    ValueError
        If `replace` is neither a dictionary, list, tuple nor None.
    AssertionError
        If `copy_inputs` is given together with a falsy `share_inputs`.

    """
    if copy_inputs is not DEPRECATED_ARG:
        # stacklevel=2 makes the warning point at the caller that still
        # passes `copy_inputs`, not at this line inside the library.
        warnings.warn(
            "In `clone()` function, the argument `copy_inputs` has been deprecated and renamed into `share_inputs`",
            stacklevel=2,
        )
        assert share_inputs  # since we used `copy_inputs` we should have default value for `share_inputs`
        share_inputs = copy_inputs

    if isinstance(replace, dict):
        items = list(replace.items())
    elif isinstance(replace, (list, tuple)):
        items = replace
    elif replace is None:
        items = []
    else:
        raise ValueError(
            (
                "replace is neither a dictionary, list, "
                f"tuple or None ! The value provided is {replace}, "
                f"of type {type(replace)}"
            )
        )

    # The substitution is done in two passes through placeholders:
    #   pass 1: original variable -> placeholder built by `original.type()`
    #   pass 2: placeholder       -> requested replacement
    # NOTE(review): presumably the indirection keeps a replacement that itself
    # contains one of the replaced variables from being rewritten a second
    # time -- confirm against `rebuild_collect_shared`.
    tmp_replace = [(original, original.type()) for original, _ in items]
    new_replace = [
        (placeholder, replacement)
        for (_, placeholder), (_, replacement) in zip(tmp_replace, items)
    ]

    _, _outs, _ = rebuild_collect_shared(
        output, [], tmp_replace, [], strict, share_inputs
    )
    _, outs, _ = rebuild_collect_shared(
        _outs, [], new_replace, [], strict, share_inputs
    )

    return outs
def __init__(
    self,
    inputs,
    outputs,
    inline=False,
    lop_overrides="default",
    grad_overrides="default",
    rop_overrides="default",
    connection_pattern=None,
    name=None,
    **kwargs,
):
    """Build the op from the graph going from `inputs` to `outputs`.

    Parameters
    ----------
    inputs : list of Variable
        Inputs of the wrapped graph.
    outputs : list of Variable
        Outputs of the wrapped graph.
    inline : bool
        Stored as ``self.is_inline``.
    lop_overrides
        Forwarded to ``self.set_lop_overrides``; ``"default"`` means no
        override was requested. Mutually exclusive with `grad_overrides`.
    grad_overrides
        Alternative to `lop_overrides`; also forwarded to
        ``self.set_lop_overrides``, with ``self._lop_type`` set to ``"grad"``.
    rop_overrides
        Forwarded to ``self.set_rop_overrides``.
    connection_pattern
        Stored as ``self._connection_pattern``.
    name : str or None
        Stored as ``self.name``.
    **kwargs
        Stored as ``self.kwargs``. ``updates`` and ``givens`` are rejected.

    Raises
    ------
    TypeError
        If `inputs` or `outputs` is not a list, if any of their elements is
        not a ``gof.Variable``, or if ``updates``/``givens`` is passed.
    ValueError
        If both `lop_overrides` and `grad_overrides` are given.

    """
    if not isinstance(outputs, list):
        raise TypeError(f"outputs must be list, got {type(outputs)}")
    # `inputs` is concatenated with lists below; reject anything else here
    # with a clear message instead of an opaque concatenation TypeError.
    if not isinstance(inputs, list):
        raise TypeError(f"inputs must be list, got {type(inputs)}")
    for i in inputs + outputs:
        if not isinstance(i, gof.Variable):
            raise TypeError("inputs and outputs must be Variable instances", i)
    if "updates" in kwargs or "givens" in kwargs:
        raise TypeError("updates and givens are not allowed here")
    self.is_inline = inline

    # To correctly support shared variables the inner fct should
    # not see them. Otherwise there is a problem with the gradient.
    self.shared_inputs = [
        var for var in gof.graph.inputs(outputs) if isinstance(var, SharedVariable)
    ]
    # One stand-in variable per shared input; they are appended to the
    # explicit inputs of the inner graph and substituted for the shared ones.
    shared_vars = [var.type() for var in self.shared_inputs]
    new = rebuild_collect_shared(
        outputs,
        inputs=inputs + shared_vars,
        replace=dict(zip(self.shared_inputs, shared_vars)),
        copy_inputs_over=False,
    )
    (
        local_inputs,
        local_outputs,
        [clone_d, update_d, update_expr, shared_inputs],
    ) = new

    # Sanity checks on the rebuilt graph: same arity as requested, and no
    # updates or shared inputs left over.
    assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
    assert len(local_outputs) == len(outputs)
    assert not update_d
    assert not update_expr
    assert not shared_inputs

    self.local_inputs = local_inputs
    self.local_outputs = local_outputs
    self.inputs = inputs
    self.outputs = outputs
    self.kwargs = kwargs
    self.input_types = [inp.type for inp in inputs]
    self.output_types = [out.type for out in outputs]

    if lop_overrides != "default":
        if grad_overrides != "default":
            raise ValueError("lop_overrides and grad_overrides are mutually exclusive")
        self.set_lop_overrides(lop_overrides)
        self._lop_type = "lop"
    elif grad_overrides != "default":
        self.set_lop_overrides(grad_overrides)
        self._lop_type = "grad"
    else:
        self.set_lop_overrides("default")
        self._lop_type = "lop"

    self.set_rop_overrides(rop_overrides)

    self._connection_pattern = connection_pattern

    if name is not None:
        assert isinstance(name, str), "name must be None or string object"
    self.name = name