def remove_constant_copies(operators):
    """
    Change Copies with constant input to Resets.

    If a Copy has no dependencies, or just one Reset() dependency, then
    we can change it to an op that just directly sets the output signal to
    the Copy input value.

    The ``operators`` list passed in is not modified. A Reset whose value
    has been folded into a rewritten Copy is simply left out of the returned
    list.

    Parameters
    ----------
    operators : list of `~nengo.builder.Operator`
        Operators in the model

    Returns
    -------
    new_operators : list of `~nengo.builder.Operator`
        Modified list of operators
    """

    sets, incs, _, updates = signal_io_dicts(operators)

    # ids of Reset ops whose value has been absorbed into a rewritten Copy.
    # They are filtered out of the result at the end rather than removed
    # from ``operators`` during iteration: removing from the list being
    # iterated mutates the caller's data, can make the loop skip operators
    # when an already-visited Reset is removed, and raises ValueError when
    # the same Reset is folded into more than one Copy.
    folded_resets = set()

    new_operators = []
    for op in operators:
        if not isinstance(op, Copy):
            new_operators.append(op)
            continue

        src = op.src

        # check if the input is the output of a Node (in which case the
        # value might change, so we should never get rid of this op).
        # checking the name of the signal seems a bit fragile, but I can't
        # think of a better solution
        if src.name.startswith("<Node"):
            new_operators.append(op)
            continue

        pred = sets[src.base] + incs[src.base]
        if (len(pred) == 0 and not src.trainable
                and len(updates[src.base]) == 0):
            # no predecessors means that the src is constant. but we also
            # need to keep the bias signal if it is trainable (since
            # changing it to a reset op would make it not trainable).
            # we also need to check if anything is updating src (which
            # wouldn't be in the predecessors).
            val = src.initial_value[op.src_slice]
        elif len(pred) == 1 and type(pred[0]) is Reset:
            # if the only predecessor is a Reset, we can just use that
            # set value. the exact type comparison is deliberate:
            # subclasses of Reset do not qualify.
            val = pred[0].value
            folded_resets.add(id(pred[0]))
        else:
            new_operators.append(op)
            continue

        new_op = Reset(
            op.dst if op.dst_slice is None else op.dst[op.dst_slice])
        # note: we need to set the value separately to bypass the float()
        # casting in Reset
        new_op.value = val

        if op.inc:
            # an incrementing Copy becomes an incrementing reset: move the
            # destination from ``sets`` to ``incs`` and switch the class
            new_op.incs.extend(new_op.sets)
            new_op.sets = []
            new_op.__class__ = op_builders.ResetInc

        new_operators.append(new_op)

    # drop the folded Resets (whether they appeared before or after the
    # Copy that absorbed them), keeping the relative order of everything else
    return [op for op in new_operators if id(op) not in folded_resets]
def test_remove_identity_muls(Op):
    """Check which multiply ops ``remove_identity_muls`` rewrites as Copies.

    ``Op`` is the operator class under test; the body branches on ``DotInc``
    and ``ElementwiseInc`` (presumably supplied via parametrization -- the
    decorator is not visible here).
    """

    def vec_shape(A):
        # shape of the x/y signals matching multiplier ``A``
        return (1,) if isinstance(A, float) else A.shape[:1]

    # identity multipliers: a scalar, plus an identity matrix for DotInc or
    # a vector of ones otherwise
    As = [1.0, np.diag(np.ones(3)) if Op == DotInc else np.ones(3)]

    # check that identity input signals get removed
    for A in As:
        x = dummies.Signal(shape=vec_shape(A))
        y = dummies.Signal(shape=vec_shape(A))
        a = Signal(A)
        a.trainable = False
        operators = [Op(a, x, y)]
        new_operators = remove_identity_muls(operators)
        assert len(new_operators) == 1
        new_op = new_operators[0]
        assert isinstance(new_op, Copy)
        assert new_op.src is x
        assert new_op.dst is y
        assert new_op.inc

    # check that identity x gets removed for elementwiseinc
    if Op == ElementwiseInc:
        a = dummies.Signal()
        x = dummies.Signal(initial_value=1)
        y = dummies.Signal()
        operators = [Op(a, x, y)]
        new_operators = remove_identity_muls(operators)
        # assert on the optimizer's output (the input list trivially has
        # length 1, so checking it would verify nothing)
        assert len(new_operators) == 1
        new_op = new_operators[0]
        assert isinstance(new_op, Copy)
        assert new_op.src is a
        assert new_op.dst is y
        assert new_op.inc

    # check that reset inputs get removed
    for A in As:
        x = dummies.Signal(shape=vec_shape(A))
        y = dummies.Signal(shape=vec_shape(A))
        a = dummies.Signal(shape=(1,) if isinstance(A, float) else A.shape)
        r = Reset(a)
        # set the value directly on the op, after construction
        r.value = A
        operators = [Op(a, x, y), r]
        new_operators = remove_identity_muls(operators)
        assert len(new_operators) == 2
        # the Reset itself must be left in place
        assert new_operators[1:] == operators[1:]
        new_op = new_operators[0]
        assert isinstance(new_op, Copy)
        assert new_op.src is x
        assert new_op.dst is y
        assert new_op.inc

    # check that non-identity inputs don't get removed
    a = Signal(np.ones((3, 3)))
    a.trainable = False
    operators = [Op(a, dummies.Signal(shape=(3,)), dummies.Signal(shape=(3,)))]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that node inputs don't get removed
    x = dummies.Signal(label="<Node lorem ipsum")
    operators = [Op(x, dummies.Signal(), dummies.Signal())]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that identity inputs + trainable don't get removed
    x = Signal(1.0)
    x.trainable = True
    operators = [Op(x, dummies.Signal(), dummies.Signal())]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that updated input doesn't get removed
    x = dummies.Signal()
    operators = [Op(x, dummies.Signal(), dummies.Signal()),
                 dummies.Op(updates=[x])]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that inc'd input doesn't get removed
    x = dummies.Signal()
    operators = [Op(x, dummies.Signal(), dummies.Signal()),
                 dummies.Op(incs=[x])]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that set'd input doesn't get removed
    x = dummies.Signal()
    operators = [Op(x, dummies.Signal(), dummies.Signal()),
                 dummies.Op(sets=[x])]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators