def test_sparsedotinc_mergeable():
    """Check that two separately built `SparseDotInc` ops are mergeable.

    Both operators are constructed from fresh dummy signals, with a sparse
    dummy signal in the first (weight) position.
    """
    candidate = SparseDotInc(
        dummies.Signal(sparse=True), dummies.Signal(), dummies.Signal()
    )
    existing_group = [
        SparseDotInc(dummies.Signal(sparse=True), dummies.Signal(), dummies.Signal())
    ]

    assert mergeable(candidate, existing_group)
def build_sparse(model, transform, sig_in, decoders=None, encoders=None, rng=np.random):
    """
    Build a `.Sparse` transform object.

    Parameters
    ----------
    model
        Builder model; every operator created here is registered through
        ``model.add_op``.
    transform
        The transform being built. Its ``size_out`` gives the output shape
        and its ``sample(rng=...)`` method supplies the weights.
    sig_in
        Input signal that the sampled weights are applied to.
    decoders
        Must be ``None``; sparse transforms cannot be applied to decoded
        connections.
    encoders
        Expected to be ``None`` (checked with an assertion).
    rng
        Random number source forwarded to ``transform.sample``.

    Returns
    -------
    weighted : Signal
        Output signal that is reset and then incremented by the
        `SparseDotInc` operator.
    weight_sig : Signal
        Read-only signal holding the sampled weights.

    Raises
    ------
    BuildError
        If ``decoders`` is not ``None``.
    """
    if decoders is not None:
        raise BuildError(
            "Applying a sparse transform to a decoded connection is not supported"
        )

    # Encoders are only produced by a connection solver with weights=True,
    # and such solvers only apply to decoded connections, which were just
    # rejected above -- so a non-None value here indicates a builder bug.
    assert encoders is None

    # Output signal: zeroed by a Reset, then accumulated into below.
    out_name = "%s.weighted" % transform
    out_sig = Signal(shape=transform.size_out, name=out_name)
    model.add_op(Reset(out_sig))

    sampled = transform.sample(rng=rng)
    assert sampled.ndim == 2

    # Store the sampled weights in a read-only signal and add the operator
    # that applies them to the input.
    weights_name = "%s.weights" % transform
    weights_sig = Signal(sampled, name=weights_name, readonly=True)
    apply_tag = "%s.apply_weights" % transform
    apply_op = SparseDotInc(weights_sig, sig_in, out_sig, tag=apply_tag)
    model.add_op(apply_op)

    return out_sig, weights_sig
def test_sparsedotinc_builderror():
    """Check that `SparseDotInc` raises `BuildError` for a non-sparse weight signal."""
    dense_weights = Signal(np.ones(2))
    vec_in = Signal(np.ones(2))
    vec_out = Signal(np.ones(2))

    # The first argument is built from a dense array, so construction
    # must fail with the "sparse Signal" message.
    with pytest.raises(BuildError, match="must be a sparse Signal"):
        SparseDotInc(dense_weights, vec_in, vec_out)
def test_remove_reset_incs():
    """Exercise `remove_reset_incs` on each incrementing operator type.

    The first group of scenarios pairs a zero `Reset` with a single
    incrementing operator and checks that the pair collapses into one
    "set"-style operator writing to the same signal. The remaining
    scenarios cover multiple incrementers, cases where the operator list
    must come back unchanged, and incrementer types the pass does not
    recognise.
    """
    # Reset + ElementwiseInc -> ElementwiseSet
    target = dummies.Signal()
    ops = [
        Reset(target),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), target),
    ]
    optimized = remove_reset_incs(ops)
    assert len(optimized) == 1
    converted = optimized[0]
    assert isinstance(converted, op_builders.ElementwiseSet)
    assert converted.Y is target
    assert converted.incs == []
    assert converted.sets == [target]

    # Reset + DotInc -> DotSet
    target = dummies.Signal()
    ops = [Reset(target), DotInc(dummies.Signal(), dummies.Signal(), target)]
    optimized = remove_reset_incs(ops)
    assert len(optimized) == 1
    converted = optimized[0]
    assert isinstance(converted, op_builders.DotSet)
    assert converted.Y is target

    # Reset + incrementing Copy -> non-incrementing Copy
    target = dummies.Signal()
    ops = [Reset(target), Copy(dummies.Signal(), target, inc=True)]
    optimized = remove_reset_incs(ops)
    assert len(optimized) == 1
    converted = optimized[0]
    assert not converted.inc
    assert converted.dst is target

    # Reset + SimProcess in "inc" mode -> SimProcess in "set" mode
    target = dummies.Signal()
    ops = [
        Reset(target),
        SimProcess(None, dummies.Signal(), target, dummies.Signal(), mode="inc"),
    ]
    optimized = remove_reset_incs(ops)
    assert len(optimized) == 1
    converted = optimized[0]
    assert converted.mode == "set"
    assert converted.output is target

    # Reset + ConvInc -> ConvSet
    target = dummies.Signal()
    ops = [
        Reset(target),
        ConvInc(dummies.Signal(), dummies.Signal(), target, None),
    ]
    optimized = remove_reset_incs(ops)
    assert len(optimized) == 1
    converted = optimized[0]
    assert isinstance(converted, transform_builders.ConvSet)
    assert converted.Y is target

    # Reset + SparseDotInc -> SparseDotSet
    target = dummies.Signal()
    ops = [
        Reset(target),
        SparseDotInc(dummies.Signal(sparse=True), dummies.Signal(), target, None),
    ]
    optimized = remove_reset_incs(ops)
    assert len(optimized) == 1
    converted = optimized[0]
    assert isinstance(converted, op_builders.SparseDotSet)
    assert converted.Y is target

    # Reset + ResetInc -> plain Reset carrying the ResetInc's value
    target = dummies.Signal()
    zero_reset = Reset(target)
    reset_inc = op_builders.ResetInc(target)
    ops = [zero_reset, reset_inc]
    reset_inc.value = np.ones((2, 3))
    optimized = remove_reset_incs(ops)
    assert len(optimized) == 1
    converted = optimized[0]
    # Exact type comparison rather than isinstance: the result must be a
    # plain Reset (NOTE(review): presumably ResetInc derives from Reset, in
    # which case isinstance would pass even without conversion -- confirm).
    assert type(converted) == Reset
    assert np.allclose(converted.value, 1)
    assert converted.dst is target

    # Several incrementers on one signal: only the first becomes a "set"
    target = dummies.Signal()
    ops = [
        Reset(target),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), target),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), target),
    ]
    optimized = remove_reset_incs(ops)
    assert len(optimized) == 2
    assert isinstance(optimized[0], op_builders.ElementwiseSet)
    assert isinstance(optimized[1], ElementwiseInc)

    # A Reset to a nonzero value must be left alone
    target = dummies.Signal()
    ops = [
        Reset(target, value=1),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), target),
    ]
    optimized = remove_reset_incs(ops)
    assert ops == optimized

    # A Reset with no incrementer (the Copy is a plain set) is left alone
    target = dummies.Signal()
    ops = [
        Reset(target),
        Copy(dummies.Signal(), target, inc=False),
    ]
    optimized = remove_reset_incs(ops)
    assert ops == optimized

    # An incrementer touching only a slice of the reset signal is left alone
    target = Signal(shape=(10,))
    ops = [
        Reset(target),
        Copy(dummies.Signal(), target[:5], inc=True),
    ]
    optimized = remove_reset_incs(ops)
    assert ops == optimized

    # Incrementer of a type the pass does not recognise
    class NewCopy(Copy):
        pass

    target = dummies.Signal()
    ops = [
        Reset(target),
        NewCopy(dummies.Signal(), target, inc=True),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), target),
    ]
    with pytest.warns(UserWarning, match="Unknown incer type"):
        optimized = remove_reset_incs(ops)
    assert len(optimized) == 2
    # The recognised ElementwiseInc is the one converted; the unrecognised
    # NewCopy instance is passed through untouched.
    assert isinstance(optimized[0], op_builders.ElementwiseSet)
    assert optimized[1] is ops[1]

    # Same target signal, but now the only incrementer is unrecognised:
    # a warning is still issued and nothing is optimised.
    ops = [
        Reset(target),
        NewCopy(dummies.Signal(), target, inc=True),
    ]
    with pytest.warns(UserWarning, match="Unknown incer type"):
        optimized = remove_reset_incs(ops)
    assert optimized == ops