def test_order_signals_views():
    """Check signal ordering when some reads are views into a shared base."""
    base = dummies.Signal(shape=(6,), label="base")
    sig = dummies.Signal(shape=(7,), label="sig")
    sig2 = dummies.Signal(shape=(7,), label="sig2")

    views = []
    for idx in range(5):
        views.append(
            dummies.Signal(
                shape=(1,), base_shape=(5,), offset=1 + idx, label="view_%d" % idx
            )
        )
    # point every view at the same explicit base signal
    for view in views:
        view.base = base

    first_group = (
        dummies.Op(reads=[base]),
        dummies.Op(reads=[views[1]]),
        dummies.Op(reads=[views[0]]),
        dummies.Op(reads=[sig2]),
    )
    second_group = (dummies.Op(reads=[base]), dummies.Op(reads=[sig]))
    third_group = tuple(dummies.Op(reads=[views[idx]]) for idx in (4, 3))
    fourth_group = (dummies.Op(reads=[views[4]]), dummies.Op(reads=[sig]))
    plan = [first_group, second_group, third_group, fourth_group]

    sigs, new_plan = order_signals(plan)
    assert contiguous([base, sig, sig2], sigs)
    for group_idx in range(4):
        assert ordered(new_plan[group_idx], sigs)
def test_planner_unmergeable(planner):
    """Check that non-mergeable operators are placed in separate groups."""
    # the two Copy ops differ only in the dtype of their output signal
    # (float32 vs int32), which should prevent merging
    input0 = dummies.Signal()
    operators = [
        Copy(input0, dummies.Signal(dtype=np.float32)),
        Copy(input0, dummies.Signal(dtype=np.int32)),
    ]
    plan = planner(operators)
    assert len(plan) == 2
    # exact type check: compare classes by identity, not equality
    assert type(plan[0][0]) is Copy
    assert len(plan[0]) == 1
    assert type(plan[1][0]) is Copy
    assert len(plan[1]) == 1
def test_planner_mergeable(planner):
    """Check that mergeable operators are placed in a single group."""
    # two incrementing Copy ops on independent signals
    input0 = dummies.Signal()
    input1 = dummies.Signal()
    output0 = dummies.Signal()
    output1 = dummies.Signal()
    operators = [Copy(input0, output0, inc=True), Copy(input1, output1, inc=True)]
    plan = planner(operators)
    assert len(plan) == 1
    # exact type check: compare classes by identity, not equality
    assert type(plan[0][0]) is Copy
    assert len(plan[0]) == 2
def test_noop_order_signals():
    """Check that noop_order_signals leaves the plan unchanged.

    Also checks the signals it returns: the two plain signals plus the base
    of the view signal "c".
    """
    inputs = [
        dummies.Signal(label="a"),
        dummies.Signal(label="b"),
        dummies.Signal(label="c", base_shape=(2,)),
    ]
    plan = [(dummies.Op(reads=[x]),) for x in inputs]

    sigs, new_plan = noop_order_signals(plan)

    # zip() stops at the shorter sequence, so compare lengths explicitly;
    # otherwise a truncated (or empty) new_plan would pass the check below
    assert len(new_plan) == len(plan)
    assert all(x == y for x, y in zip(plan, new_plan))

    assert len(sigs) == 3
    sigs.remove(inputs[0])
    sigs.remove(inputs[1])
    # only the base of the view "c" remains
    assert sigs[0].name == "c.base"
def test_planner_size():
    """greedy_planner should pick groups by the number of available ops."""
    shared_in = dummies.Signal()
    ops = [Copy(shared_in, dummies.Signal(), inc=True) for _ in range(2)]
    ops.append(Copy(shared_in, dummies.Signal()))
    ops.extend(
        DotInc(shared_in, dummies.Signal(), dummies.Signal()) for _ in range(3)
    )

    plan = greedy_planner(ops)
    assert len(plan) == 3
    # largest available group first: 3 DotIncs, 2 CopyIncs, then the Copy
    assert [len(group) for group in plan] == [3, 2, 1]
def test_planner_chain(planner):
    """A chain of dependent operators should be split into ordered groups."""
    a, b, c, d = (dummies.Signal(label=name) for name in "abcd")

    operators = (
        [Copy(a, b, inc=True) for _ in range(3)]
        + [SimPyFunc(c, lambda x: x, None, b)]
        + [Copy(c, d, inc=True) for _ in range(2)]
    )

    plan = planner(operators)
    assert [len(group) for group in plan] == [3, 1, 2]
def test_planner_cycle(planner):
    """Cyclic Copy dependencies (0 -> 1 -> 2 -> 0) should raise BuildError."""
    s0, s1, s2 = [dummies.Signal() for _ in range(3)]
    cycle = [Copy(s0, s1), Copy(s1, s2), Copy(s2, s0)]
    with pytest.raises(BuildError):
        planner(cycle)
def test_noop_planner():
    """noop_planner should give one group per op, respecting dependencies."""
    s0, s1, s2 = (dummies.Signal() for _ in range(3))
    copy_1_to_2 = Copy(s1, s2)
    copy_0_to_1 = Copy(s0, s1)
    operators = [copy_1_to_2, copy_0_to_1]

    plan = noop_planner(operators)
    assert len(plan) == len(operators)
    # copy_0_to_1 writes s1, which copy_1_to_2 reads, so it comes first
    assert plan[0] == (copy_0_to_1,)
    assert plan[1] == (copy_1_to_2,)
# NOTE(review): a second ``test_create_signals_views`` is defined later in this
# file; Python keeps only the last definition bound to a name, so pytest never
# collects this version. It also uses a different ``base_arrays_init`` access
# pattern (positional, no "non_trainable" key) and ``.indices`` rather than
# ``.slices``. It looks stale; consider deleting it.
def test_create_signals_views():
    """Check that two views and their bases end up in one base array."""
    sigs = [dummies.Signal(shape=(2, 2), base_shape=(4,)),
            dummies.Signal(shape=(2, 2), base_shape=(4,))]
    # add the base signals of the two views
    sigs += [sigs[0].base, sigs[1].base]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    # only the base signals are passed in
    graph.create_signals(sigs[2:])
    assert list(graph.base_arrays_init.values())[0][0].shape == (8, 10)
    # all four signals share the same base array
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key == graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key
    # each view occupies the same indices as its base
    assert np.all(graph.signals[sigs[0]].indices == (0, 1, 2, 3))
    assert np.all(graph.signals[sigs[1]].indices == (4, 5, 6, 7))
    assert np.all(graph.signals[sigs[0]].indices ==
                  graph.signals[sigs[2]].indices)
    assert np.all(graph.signals[sigs[1]].indices ==
                  graph.signals[sigs[3]].indices)
def test_order_signals_lowpass():
    """Lowpass outputs should be ordered in the same way as the reads."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    time = dummies.Signal()

    def lowpass_group(starts):
        # one Lowpass process per start index, reading k and writing k + 1
        return tuple(
            SimProcess(
                Lowpass(0.1), sig_list[k], sig_list[k + 1], time, mode="update"
            )
            for k in starts
        )

    plan = [lowpass_group(range(0, 4, 2)), lowpass_group(range(5, 9, 2))]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[1:5:2], sigs)
    assert contiguous(sig_list[6:10:2], sigs)
    for group_idx in (0, 1):
        for block in (1, 2):
            assert ordered(new_plan[group_idx], sigs, block=block)
def test_order_signals_noreads():
    """Check ordering when one op group has no reads (only sets)."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    readers = tuple(dummies.Op(reads=[sig_list[i]]) for i in range(5))
    setters = tuple(dummies.Op(sets=[target]) for target in sig_list[5:])
    plan = [readers, setters]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[:5], sigs)
    assert ordered(new_plan[0], sigs)
def test_create_signals_views():
    """Views and their bases should share one base array with equal slices."""
    view_a = dummies.Signal(shape=(2, 2), base_shape=(4,))
    view_b = dummies.Signal(shape=(2, 2), base_shape=(4,))
    sigs = [view_a, view_b, view_a.base, view_b.base]
    plan = [tuple(dummies.Op(reads=[s]) for s in sigs)]

    graph = dummies.TensorGraph(plan, tf.float32, 10)
    # only the two base signals are passed in
    graph.create_signals(sigs[2:])

    first_array = list(graph.base_arrays_init["non_trainable"].values())[0]
    assert first_array[1] == [(10, 4), (10, 4)]

    lookup = graph.signals
    # every signal lives in the same base array
    keys = [lookup[s].key for s in sigs]
    for left, right in zip(keys, keys[1:]):
        assert left == right

    assert lookup[view_a].slices == ((0, 4),)
    assert lookup[view_b].slices == ((4, 8),)
    assert lookup[view_a].slices == lookup[view_a.base].slices
    assert lookup[view_b].slices == lookup[view_b.base].slices
def test_order_signals_disjoint():
    """Two groups with disjoint reads: each block contiguous and ordered."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    plan = [
        tuple(dummies.Op(reads=[s]) for s in sig_list[:5]),
        tuple(dummies.Op(reads=[s]) for s in sig_list[5:]),
    ]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[:5], sigs)
    assert contiguous(sig_list[5:], sigs)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[1], sigs)
def test_order_signals_partial3():
    """Three read blocks, neighbouring ones sharing a single signal."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    read_indices = ([0, 1, 2, 3], [0, 4, 7], [5, 6, 7])
    plan = [
        tuple(dummies.Op(reads=[sig_list[i]]) for i in idxs)
        for idxs in read_indices
    ]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[:4], sigs)
    assert contiguous([sig_list[i] for i in (0, 4, 7)], sigs)
    assert contiguous(sig_list[5:8], sigs)
    for group_idx in range(3):
        assert ordered(new_plan[group_idx], sigs)
def test_order_signals_partial():
    """Two partially overlapping read blocks: (A, A/B, B)."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    plan = [
        tuple(dummies.Op(reads=[s]) for s in sig_list[0:4]),
        tuple(dummies.Op(reads=[s]) for s in sig_list[2:6]),
    ]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[:4], sigs)
    assert contiguous(sig_list[2:6], sigs)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[1], sigs)
def test_order_signals_subset():
    """One read block fully nested inside another: (A, A/B)."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    plan = [
        tuple(dummies.Op(reads=[s]) for s in sig_list),
        # the nested block reads its signals in reverse order
        tuple(dummies.Op(reads=[s]) for s in reversed(sig_list[:5])),
    ]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[:5], sigs)
    assert contiguous(sig_list[:10], sigs)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[1], sigs)
def test_create_signals_partition():
    """create_signals should partition signals according to the plan."""

    def assert_two_pairs(graph, sig_list):
        # signals 0/1 share a base array; 2/3 share a different one
        lookup = graph.signals
        assert lookup[sig_list[0]].key == lookup[sig_list[1]].key
        assert lookup[sig_list[1]].key != lookup[sig_list[2]].key
        assert lookup[sig_list[2]].key == lookup[sig_list[3]].key

    sigs = [dummies.Signal() for _ in range(4)]

    # signals read by different op groups are partitioned
    plan = [
        tuple(dummies.Op(reads=[x]) for x in sigs[:2]),
        tuple(dummies.Op(reads=[x]) for x in sigs[2:]),
    ]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert_two_pairs(graph, sigs)

    # signals in different read blocks of the same group are partitioned
    plan = [tuple(dummies.Op(reads=[sigs[i], sigs[2 + i]]) for i in range(2))]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert_two_pairs(graph, sigs)

    # signals with different roles (read vs set) are partitioned
    plan = [
        tuple(dummies.Op(reads=[sigs[i]], sets=[sigs[2 + i]]) for i in range(2))
    ]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert_two_pairs(graph, sigs)

    # Reset ops are ignored, so each signal gets its own base array
    fresh = [dummies.Signal() for _ in range(4)]
    plan = [tuple(Reset(x) for x in fresh)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(fresh)
    assert len(graph.base_arrays_init["non_trainable"]) == 4
def test_sparsedotinc_mergeable():
    """Two SparseDotInc ops (sparse first input) should be mergeable."""

    def make_op():
        return SparseDotInc(
            dummies.Signal(sparse=True), dummies.Signal(), dummies.Signal()
        )

    candidate = make_op()
    group = [make_op()]
    assert mergeable(candidate, group)
def test_order_signals_duplicate_read_blocks():
    """order_signals should prioritize read blocks repeated across groups."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]

    def paired_group():
        return tuple(
            dummies.Op(reads=[sig_list[i], sig_list[5 + i]]) for i in range(3)
        )

    plan = [
        paired_group(),
        paired_group(),
        tuple(
            dummies.Op(reads=[sig_list[5 + i], sig_list[4 - i]]) for i in range(5)
        ),
    ]

    sigs, new_plan = order_signals(plan)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[1], sigs)
    # the last group can be ordered in at most one of its two read blocks
    last = new_plan[2]
    assert ordered(last, sigs, block=0) or ordered(last, sigs, block=1)
    assert not ordered(last, sigs)
def test_order_signals_multiread_complex():
    """Ops reading several signals, overlapping a single-read group.

    Layout is (C, B/C, A), where A and B are read by the same op.
    """
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    plan = [
        tuple(dummies.Op(reads=[sig_list[i], sig_list[5 + i]]) for i in range(3)),
        tuple(dummies.Op(reads=[s]) for s in sig_list[5:10]),
    ]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[:3], sigs)
    assert contiguous(sig_list[5:], sigs)
    assert contiguous(sig_list[5:8], sigs)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[1], sigs)
def test_order_signals_partial2():
    """A chain of partial overlaps: (A, A/B, B/C, C)."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    spans = [(0, 5), (2, 6), (5, 8)]
    plan = [
        tuple(dummies.Op(reads=[s]) for s in sig_list[lo:hi]) for lo, hi in spans
    ]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[:5], sigs)
    assert contiguous(sig_list[5:8], sigs)
    assert contiguous(sig_list[2:6], sigs)
    for group_idx in range(3):
        assert ordered(new_plan[group_idx], sigs)
def test_order_signals_neuron_states():
    """Neuron state signals should be treated like reads when ordering."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    even_states = sig_list[2::2]
    odd_states = sig_list[3::2]

    def neuron_group(states):
        return tuple(
            SimNeurons(None, sig_list[0], sig_list[1], states=[s]) for s in states
        )

    plan = [neuron_group(even_states), neuron_group(odd_states)]

    sigs, new_plan = order_signals(plan)
    assert contiguous(even_states, sigs)
    assert contiguous(odd_states, sigs)
    # block 0 is a single signal, so it is trivially "ordered"; check block 1
    assert ordered(new_plan[0], sigs, block=1)
    assert ordered(new_plan[1], sigs, block=1)
def test_order_signals_partial_unsatisfiable():
    """A layout (A, A/B, A/C, B) that the ordering cannot fully satisfy.

    After A, the algorithm picks A/B (B is the next-biggest block). It could
    be satisfied by picking A/C next instead, but it is unclear how the
    algorithm could know that.
    """
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    plan = [
        tuple(dummies.Op(reads=[s]) for s in sig_list[:7]),
        tuple(dummies.Op(reads=[s]) for s in sig_list[5:10]),
        tuple(dummies.Op(reads=[s]) for s in sig_list[:3]),
    ]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[:7], sigs)
    assert not contiguous(sig_list[5:], sigs)
    assert contiguous(sig_list[:3], sigs)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[2], sigs)
def test_order_signals_duplicates():
    """Read blocks that contain duplicate signals."""
    s0, s1, s2, s3 = [dummies.Signal(label=str(i)) for i in range(4)]
    plan = [
        (dummies.Op(reads=[s0]), dummies.Op(reads=[s0]), dummies.Op(reads=[s2])),
        (dummies.Op(reads=[s1]), dummies.Op(reads=[s1]), dummies.Op(reads=[s3])),
    ]

    sigs, new_plan = order_signals(plan)
    assert contiguous([s0, s2], sigs)
    assert contiguous([s1, s3], sigs)
    # a block containing duplicates can never be in strictly increasing order
    assert not ordered(new_plan[0], sigs)
    assert not ordered(new_plan[1], sigs)
def test_order_signals_multiread_complex2():
    """Layout (B, B/A, A, A/C, C), where A and B are read by the same op."""
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    plan = [
        tuple(dummies.Op(reads=[sig_list[2 + i], sig_list[i]]) for i in range(4)),
        tuple(dummies.Op(reads=[s]) for s in sig_list[5:8]),
    ]

    sigs, new_plan = order_signals(plan)
    assert contiguous(sig_list[5:8], sigs)
    assert ordered(new_plan[1], sigs)

    # TODO: both blocks can always be ordered correctly in principle, but that
    # requires knowing which of the two equally sized blocks should have
    # priority, and it is unclear how to determine that.
    assert contiguous(sig_list[:4], sigs) or contiguous(sig_list[2:6], sigs)
    first = new_plan[0]
    assert ordered(first, sigs, block=0) or ordered(first, sigs, block=1)
def test_order_signals_multiread_unsatisfiable():
    """Layout (A, A/C, B, B/D) in which block C cannot be ordered.

    C conflicts with A, which gets priority because it is in a larger group.
    """
    sig_list = [dummies.Signal(label=str(i)) for i in range(10)]
    plan = [
        tuple(dummies.Op(reads=[sig_list[i], sig_list[5 + i]]) for i in range(5)),
        tuple(
            dummies.Op(reads=[sig_list[1 - i], sig_list[5 + i]]) for i in range(2)
        ),
    ]

    sigs, new_plan = order_signals(plan)
    for chunk in (sig_list[:5], sig_list[5:], sig_list[:2], sig_list[5:7]):
        assert contiguous(chunk, sigs)
    assert ordered(new_plan[0], sigs)
    second = new_plan[1]
    assert ordered(second, sigs, block=0) or ordered(second, sigs, block=1)
    assert not ordered(second, sigs)
def test_build(trainable, rng):
    """Check the base arrays built from signals with various initial values.

    The plan has three op groups, giving three base arrays. Trainable arrays
    have no minibatch dimension. Non-trainable arrays have a leading
    minibatch dimension (16 here).
    """
    sigs = [
        dummies.Signal(
            shape=(2, 1), dtype="float32", initial_value=0, trainable=trainable
        ),
        dummies.Signal(
            shape=(3, 1),
            dtype="float32",
            initial_value=np.zeros((3, 1)),
            trainable=trainable,
        ),
        dummies.Signal(
            shape=(4, 1), dtype="float32", initial_value=1, trainable=trainable
        ),
        dummies.Signal(
            shape=(5, 1),
            dtype="float32",
            initial_value=np.ones((5, 1)),
            trainable=trainable,
        ),
        dummies.Signal(
            shape=(6, 1),
            dtype="float32",
            initial_value=rng.uniform(size=(6, 1)),
            trainable=trainable,
        ),
        dummies.Signal(
            shape=(7, 1),
            dtype="float32",
            initial_value=rng.uniform(size=(7, 1)),
            trainable=trainable,
        ),
    ]
    plan = [
        tuple(dummies.Op(reads=[x]) for x in sigs[:2]),
        tuple(dummies.Op(reads=[x]) for x in sigs[2:4]),
        tuple(dummies.Op(reads=[x]) for x in sigs[4:]),
    ]
    graph = dummies.TensorGraph(plan=plan, dtype="float32", minibatch_size=16)
    graph.create_signals(sigs)
    graph.build()

    if trainable:
        assert len(graph.trainable_weights) == 3
        assert len(graph.non_trainable_weights) == 0
    else:
        assert len(graph.trainable_weights) == 0
        assert len(graph.non_trainable_weights) == 3

    # NOTE: the original ``assert x == a if trainable else b`` parsed as
    # ``assert (x == a) if trainable else b``; for non-trainable signals that
    # asserted a non-empty tuple and never checked the shape. Build the
    # expected shape explicitly instead.
    batch_shape = () if trainable else (16,)

    init0 = graph.weights[0].numpy()
    assert init0.shape == batch_shape + (5, 1)
    assert np.allclose(init0, 0)

    init1 = graph.weights[1].numpy()
    assert init1.shape == batch_shape + (9, 1)
    assert np.allclose(init1, 1)

    init2 = graph.weights[2].numpy()
    assert init2.shape == batch_shape + (13, 1)
    # the signal axis is second-to-last in both layouts
    assert np.allclose(init2[..., :6, :], sigs[4].initial_value)
    assert np.allclose(init2[..., 6:, :], sigs[5].initial_value)
def test_create_signals():
    """Check how create_signals groups signals into base arrays."""

    def build_graph(sig_list):
        # single op group in which every signal is read by its own op
        plan = [tuple(dummies.Op(reads=[x]) for x in sig_list)]
        graph = dummies.TensorGraph(plan, tf.float32, 10)
        graph.create_signals(sig_list)
        return graph

    # check that floats/ints get split into different arrays
    sigs = [
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.int32),
        dummies.Signal(dtype=np.int32),
    ]
    graph = build_graph(sigs)
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key != graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that floats all get converted to same precision and combined
    sigs = [
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.float64),
        dummies.Signal(dtype=np.float64),
    ]
    graph = build_graph(sigs)
    assert np.all([graph.signals[x].dtype == "float32" for x in sigs])
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key == graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that ints all get converted to same precision and combined
    sigs = [
        dummies.Signal(dtype=np.int32),
        dummies.Signal(dtype=np.int32),
        dummies.Signal(dtype=np.int64),
        dummies.Signal(dtype=np.int64),
    ]
    graph = build_graph(sigs)
    assert np.all([graph.signals[x].dtype == "int32" for x in sigs])
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key == graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that different shapes go in different groups: (10,) and (5,)
    # share one array, (10, 1) and (5, 1) share another; each entry gets a
    # leading minibatch dimension of 10
    sigs = [
        dummies.Signal(shape=(10,)),
        dummies.Signal(shape=(5,)),
        dummies.Signal(shape=(10, 1)),
        dummies.Signal(shape=(5, 1)),
    ]
    graph = build_graph(sigs)
    non_trainable = graph.base_arrays_init["non_trainable"]
    assert non_trainable[graph.signals[sigs[0]].key][1] == [(10, 10), (10, 5)]
    assert non_trainable[graph.signals[sigs[2]].key][1] == [
        (10, 10, 1),
        (10, 5, 1),
    ]
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key != graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check trainable (no minibatch dimension) vs non-trainable signals
    sigs = [
        dummies.Signal(trainable=True),
        dummies.Signal(trainable=True),
        dummies.Signal(trainable=False),
        dummies.Signal(trainable=False),
    ]
    graph = build_graph(sigs)
    assert graph.base_arrays_init["trainable"][graph.signals[sigs[0]].key][1] == [
        (1,),
        (1,),
    ]
    assert graph.base_arrays_init["non_trainable"][
        graph.signals[sigs[2]].key
    ][1] == [(10, 1), (10, 1)]
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key != graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that scalars get upsized
    sigs = [dummies.Signal(shape=()), dummies.Signal(shape=(4,))]
    graph = build_graph(sigs)
    assert list(graph.base_arrays_init["non_trainable"].values())[0][1] == [
        (10, 1),
        (10, 4),
    ]

    # check that boolean signals are handled correctly
    # NOTE: ``np.bool`` was just a deprecated alias for the builtin ``bool``
    # and was removed in NumPy 1.24 (AttributeError), so use ``bool`` directly
    sigs = [dummies.Signal(dtype=bool, shape=())]
    plan = [(dummies.Op(reads=sigs),)]
    graph = dummies.TensorGraph(plan, tf.float32, 1)
    graph.create_signals(sigs)
    assert list(graph.base_arrays_init["non_trainable"].values())[0][2] == "bool"
def test_remove_identity_muls(Op):
    """Check when remove_identity_muls replaces a multiply op with a Copy.

    ``Op`` is the multiply operator class under test; the body special-cases
    ``DotInc`` and ``ElementwiseInc``.
    """
    # check that identity input signals get removed
    As = [1.0, np.diag(np.ones(3)) if Op == DotInc else np.ones(3)]
    for A in As:
        # scalar identities act on length-1 vectors, array identities on
        # vectors matching their first dimension
        vec_shape = (1,) if isinstance(A, float) else A.shape[:1]
        x = dummies.Signal(shape=vec_shape)
        y = dummies.Signal(shape=vec_shape)
        a = Signal(A)
        a.trainable = False
        operators = [Op(a, x, y)]
        new_operators = remove_identity_muls(operators)
        assert len(new_operators) == 1
        new_op = new_operators[0]
        assert isinstance(new_op, Copy)
        assert new_op.src is x
        assert new_op.dst is y
        assert new_op.inc

    # check that identity x gets removed for elementwiseinc
    if Op == ElementwiseInc:
        a = dummies.Signal()
        x = dummies.Signal(initial_value=1)
        y = dummies.Signal()
        operators = [Op(a, x, y)]
        new_operators = remove_identity_muls(operators)
        # check the *result* length (the input list trivially has length 1)
        assert len(new_operators) == 1
        new_op = new_operators[0]
        assert isinstance(new_op, Copy)
        assert new_op.src is a
        assert new_op.dst is y
        assert new_op.inc

    # check that reset inputs get removed
    for A in As:
        vec_shape = (1,) if isinstance(A, float) else A.shape[:1]
        x = dummies.Signal(shape=vec_shape)
        y = dummies.Signal(shape=vec_shape)
        a = dummies.Signal(shape=(1,) if isinstance(A, float) else A.shape)
        r = Reset(a)
        r.value = A
        operators = [Op(a, x, y), r]
        new_operators = remove_identity_muls(operators)
        assert len(new_operators) == 2
        assert new_operators[1:] == operators[1:]
        new_op = new_operators[0]
        assert isinstance(new_op, Copy)
        assert new_op.src is x
        assert new_op.dst is y
        assert new_op.inc

    # check that non-identity inputs don't get removed
    a = Signal(np.ones((3, 3)))
    a.trainable = False
    operators = [Op(a, dummies.Signal(shape=(3,)), dummies.Signal(shape=(3,)))]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that node inputs don't get removed
    x = dummies.Signal(label="<Node lorem ipsum")
    operators = [Op(x, dummies.Signal(), dummies.Signal())]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that identity inputs + trainable don't get removed
    x = Signal(1.0)
    x.trainable = True
    operators = [Op(x, dummies.Signal(), dummies.Signal())]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that updated input doesn't get removed
    x = dummies.Signal()
    operators = [Op(x, dummies.Signal(), dummies.Signal()), dummies.Op(updates=[x])]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that inc'd input doesn't get removed
    x = dummies.Signal()
    operators = [Op(x, dummies.Signal(), dummies.Signal()), dummies.Op(incs=[x])]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that set'd input doesn't get removed
    x = dummies.Signal()
    operators = [Op(x, dummies.Signal(), dummies.Signal()), dummies.Op(sets=[x])]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators
def test_remove_constant_copies():
    """Check which Copy ops remove_constant_copies turns into Reset/ResetInc."""
    # check that Copy with no inputs gets turned into Reset
    # (no other op in ``operators`` writes the source signal)
    x = dummies.Signal()
    operators = [Copy(dummies.Signal(), x)]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst is x
    assert new_operators[0].value == 0

    # check that Copy with Node input doesn't get changed
    # NOTE(review): the "<Node" label prefix appears to be what marks x as a
    # Node output here; confirm against remove_constant_copies
    x = dummies.Signal(label="<Node lorem ipsum")
    operators = [Copy(x, dummies.Signal())]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check that Copy with trainable input doesn't get changed
    x = dummies.Signal()
    x.trainable = True
    operators = [Copy(x, dummies.Signal())]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check Copy with updated input doesn't get changed
    x = dummies.Signal()
    operators = [Copy(x, dummies.Signal()), dummies.Op(updates=[x])]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check Copy with inc'd input doesn't get changed
    x = dummies.Signal()
    operators = [Copy(x, dummies.Signal()), dummies.Op(incs=[x])]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check Copy with set input doesn't get changed
    x = dummies.Signal()
    operators = [Copy(x, dummies.Signal()), dummies.Op(sets=[x])]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check Copy with read input/output does get changed
    # (other ops that only read x or y do not block the conversion, and they
    # are kept as-is)
    x = dummies.Signal()
    y = dummies.Signal()
    operators = [Copy(x, y), dummies.Op(reads=[x]), dummies.Op(reads=[y])]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 3
    assert new_operators[1:] == operators[1:]
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst is y
    assert new_operators[0].value == 0

    # check Copy with Reset input does get changed
    # (the Copy plus the Reset of its source become a single Reset of the
    # destination, carrying the reset value)
    x = dummies.Signal()
    y = dummies.Signal()
    operators = [Copy(x, y), Reset(x, 2)]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst is y
    assert new_operators[0].value == 2

    # check that slicing is respected
    # (the resulting Reset targets a length-1 view of y at element offset 1)
    x = dummies.Signal()
    y = Signal(initial_value=[0, 0])
    operators = [Copy(x, y, dst_slice=slice(1, 2)), Reset(x, 2)]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst.shape == (1,)
    assert new_operators[0].dst.is_view
    assert new_operators[0].dst.elemoffset == 1
    assert new_operators[0].dst.base is y
    assert new_operators[0].value == 2

    # check that CopyInc gets turned into ResetInc
    x = dummies.Signal()
    y = dummies.Signal()
    operators = [Copy(x, y, inc=True), Reset(x, 2)]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], op_builders.ResetInc)
    assert new_operators[0].dst is y
    assert new_operators[0].value == 2
    # the ResetInc increments (rather than sets) its destination
    assert len(new_operators[0].incs) == 1
    assert len(new_operators[0].sets) == 0