コード例 #1
0
from default_args import get_training_args, DomainAndProblemConfiguration
from train import train_wrapper

# Ferry training set: every pairing of {2, 3, 4} locations with {1, 2, 3}
# cars, i.e. 3 x 3 = 9 problems. The comprehension walks locations in the
# outer loop and cars in the inner loop. This reproduces the hand-written
# order l2-c1, l2-c2, ..., l4-c3.
_CONFIGURATION = DomainAndProblemConfiguration(
    base_directory="../benchmarks/ferry",
    domain_pddl="ferry.pddl",
    problem_pddls=[
        f"train/ferry-l{num_locations}-c{num_cars}.pddl"
        for num_locations in (2, 3, 4)
        for num_cars in (1, 2, 3)
    ],
)
# Sanity check: all nine problems should end up in the configuration.
assert len(_CONFIGURATION.problems) == 9

if __name__ == "__main__":
    # Ferry-only run with a 3-minute budget split over 5 folds
    # (presumably k-fold splitting inside get_training_args; confirm there).
    ferry_training_args = get_training_args(
        configurations=[_CONFIGURATION],
        max_training_time=3 * 60,  # seconds, i.e. 3 minutes
        num_folds=5,
    )
    train_wrapper(args=ferry_training_args, domain_name="ferry")
コード例 #2
0
ファイル: blocksworld.py プロジェクト: yaaig-ufrgs/STRIPS-HGN
        "blocks4/task03.pddl",
        "blocks4/task04.pddl",
        "blocks4/task05.pddl",
        "blocks4/task06.pddl",
        "blocks4/task07.pddl",
        "blocks4/task08.pddl",
        "blocks4/task09.pddl",
        "blocks4/task10.pddl",
        "blocks5/task01.pddl",
        "blocks5/task02.pddl",
        "blocks5/task03.pddl",
        "blocks5/task04.pddl",
        "blocks5/task05.pddl",
        "blocks5/task06.pddl",
        "blocks5/task07.pddl",
        "blocks5/task08.pddl",
        "blocks5/task09.pddl",
        "blocks5/task10.pddl",
    ],
)
# Sanity check on the problem list defined above: 30 tasks expected.
assert len(_CONFIGURATION.problems) == 30

if __name__ == "__main__":
    # Single-configuration run with a 10-minute training budget.
    blocks_training_args = get_training_args(
        configurations=[_CONFIGURATION],
        max_training_time=10 * 60,  # seconds, i.e. 10 minutes
    )
    train_wrapper(args=blocks_training_args, domain_name="blocks")
コード例 #3
0
# Zenotravel training set: five 2-city and five 3-city problems (10 total).
# Each file name encodes the city, plane and person counts plus a numeric
# suffix. The tuples below are expanded into those paths in listing order.
_ZENOTRAVEL_CONFIGURATION = DomainAndProblemConfiguration(
    base_directory="../benchmarks/zenotravel",
    domain_pddl="domain.pddl",
    problem_pddls=[
        f"train/zenotravel-cities{cities}-planes{planes}"
        f"-people{people}-{suffix}.pddl"
        for cities, planes, people, suffix in (
            (2, 1, 3, "8798"),
            (2, 2, 3, "9145"),
            (2, 3, 3, "3417"),
            (2, 4, 2, "4892"),
            (2, 4, 4, "6874"),
            (3, 1, 3, "4791"),
            (3, 2, 3, "8752"),
            (3, 2, 5, "7306"),
            (3, 3, 3, "1826"),
            (3, 3, 5, "4582"),
        )
    ],
)
# Sanity check: all ten problems should end up in the configuration.
assert len(_ZENOTRAVEL_CONFIGURATION.problems) == 10

if __name__ == "__main__":
    # One training run over both the Gripper and Zenotravel problem sets.
    # NOTE(review): _GRIPPER_CONFIGURATION must be defined earlier in this
    # module; its definition is not visible in this excerpt.
    joint_configurations = [_GRIPPER_CONFIGURATION, _ZENOTRAVEL_CONFIGURATION]
    indepgz_training_args = get_training_args(
        configurations=joint_configurations,
        max_training_time=10 * 60,  # seconds, i.e. 10 minutes
    )
    train_wrapper(args=indepgz_training_args, domain_name="indepgz")
コード例 #4
0
        "train/zenotravel-cities3-planes3-people3-1826.pddl",
        "train/zenotravel-cities3-planes3-people5-4582.pddl",
    ],
)
# Sanity check on the Zenotravel problem list defined above.
assert len(_ZENOTRAVEL_CONFIGURATION.problems) == 10

# The first three Gripper problems: gripper-n1, gripper-n2, gripper-n3.
_GRIPPER_CONFIGURATION = DomainAndProblemConfiguration(
    base_directory="../benchmarks/gripper",
    domain_pddl="domain.pddl",
    problem_pddls=[f"problems/gripper-n{index}.pddl" for index in range(1, 4)],
)
# Sanity check: all three problems should end up in the configuration.
assert len(_GRIPPER_CONFIGURATION.problems) == 3

if __name__ == "__main__":
    # One training run over Blocksworld, Zenotravel and Gripper, in that order.
    # NOTE(review): _BLOCKSWORLD_CONFIGURATION must be defined earlier in this
    # module; its definition is not visible in this excerpt.
    multi_domain_args = get_training_args(
        configurations=[
            _BLOCKSWORLD_CONFIGURATION,
            _ZENOTRAVEL_CONFIGURATION,
            _GRIPPER_CONFIGURATION,
        ],
        max_training_time=15 * 60,  # seconds, i.e. 15 minutes
    )
    train_wrapper(args=multi_domain_args, domain_name="multi")
コード例 #5
0
ファイル: gripper.py プロジェクト: williamshen-nz/STRIPS-HGN
from default_args import get_training_args, DomainAndProblemConfiguration
from train import train_wrapper

# Gripper training set: one problem each for 1, 2 and 3 balls (3 total).
_CONFIGURATION = DomainAndProblemConfiguration(
    base_directory="../benchmarks/gripper",
    domain_pddl="domain.pddl",
    problem_pddls=[
        f"problems/gripper-n{num_balls}.pddl" for num_balls in (1, 2, 3)
    ],
)
# Sanity check: all three problems should end up in the configuration.
assert len(_CONFIGURATION.problems) == 3

if __name__ == "__main__":
    # Short Gripper-only run: 90-second budget, 3 bins.
    gripper_training_args = get_training_args(
        configurations=[_CONFIGURATION],
        max_training_time=90,  # seconds
        num_bins=3,
    )
    train_wrapper(args=gripper_training_args)