Esempio n. 1
0
    def test_LocalGroupDB(self):
        """Exercise position bookkeeping in ``LocalGroupDB.register``.

        Registering ``"a"`` must leave an entry for it in the database's
        ``__position__`` table. Passing an arbitrary ``object()`` as the
        ``position`` keyword must raise a ``TypeError`` whose message matches
        ``"`position` must be.*"``.
        """
        db = LocalGroupDB()

        db.register("a", TestOpt(), 1)
        assert "a" in db.__position__

        invalid_position = object()
        with pytest.raises(TypeError, match=r"`position` must be.*"):
            db.register("b", TestOpt(), position=invalid_position)
Esempio n. 2
0
        import aesara.printing

        print("PrintCurrentFunctionGraph:", self.header)
        aesara.printing.debugprint(fgraph.outputs)


# Top-level sequence database. The third argument of each ``register`` call
# below is the entry's position in the sequence.
optdb = SequenceDB()
optdb.register(
    "merge1",
    MergeOptimizer(),
    0,
    "fast_run",
    "fast_compile",
    "merge",
)

# The "useless" pass sits at position 0.6: after the scan1 optimization (0.5)
# and before ShapeOpt (1).
# It should only remove nodes and must not do anything that needs shape
# inference. A new node it introduces may lack ``infer_shape`` only if the
# node it replaces also lacks ``infer_shape``.
local_useless = LocalGroupDB(apply_all_opts=True, profile=True)
optdb.register(
    "useless",
    TopoDB(
        local_useless,
        failure_callback=NavigatorOptimizer.warn_inplace,
    ),
    0.6,
    "fast_run",
    "fast_compile",
)

# A second merge pass at position 0.65, right after the "useless" pass (0.6).
optdb.register(
    "merge1.1",
    MergeOptimizer(),
    0.65,
    "fast_run",
    "fast_compile",
    "merge",
)

# rearranges elemwise expressions
optdb.register(
    "canonicalize",
    EquilibriumDB(ignore_newtrees=False),
Esempio n. 3
0
gpu_cut_copies = EquilibriumDB()

# This database is never run as an EquilibriumOptimizer itself; it only
# provides the "tracks" that GraphToGPUDB needs.
gpu_optimizer2 = EquilibriumDB()

gpu_seqopt = SequenceDB()

# Deliberately NOT tagged 'fast_run': that tag would always enable gpuarray
# mode. The sequence is placed one position before "add_destroy_handler",
# falling back to 49.5 for that entry's position when it is not registered.
optdb.register(
    "gpuarray_opt",
    gpu_seqopt,
    optdb.__position__.get("add_destroy_handler", 49.5) - 1,
    "gpuarray",
)

# Each family of local optimizations gets a plain LocalGroupDB plus a second
# one built with ``local_opt=GraphToGPULocalOptGroup``. The second database
# of each pair is given an explicit ``__name__``. Tuple elements are evaluated
# left to right, so construction order is unchanged.
pool_db, pool_db2 = (
    LocalGroupDB(),
    LocalGroupDB(local_opt=GraphToGPULocalOptGroup),
)
pool_db2.__name__ = "pool_db2"

matrix_ops_db, matrix_ops_db2 = (
    LocalGroupDB(),
    LocalGroupDB(local_opt=GraphToGPULocalOptGroup),
)
matrix_ops_db2.__name__ = "matrix_ops_db2"

abstract_batch_norm_db, abstract_batch_norm_db2 = (
    LocalGroupDB(),
    LocalGroupDB(local_opt=GraphToGPULocalOptGroup),
)
abstract_batch_norm_db2.__name__ = "abstract_batch_norm_db2"

abstract_batch_norm_groupopt = LocalGroupDB()
abstract_batch_norm_groupopt.__name__ = "gpuarray_batchnorm_opts"