예제 #1
0
def test_set_getter(clear_args):
    """Values assigned under a Sweep can be read back immediately.

    **Limitation**: this currently only works inside the ``each`` function
    and inside the set context.
    """

    class G(PrefixProto):
        seed = 10
        postfix = "seed-"

    with Sweep(G).product as sweep:
        G.seed = [10, 20, 30]

    # Inside a separate Sweep context a freshly assigned value is readable
    # right away, and an attribute that was not assigned still reads as its
    # class-level default.
    with Sweep(G):
        G.seed = 110
        assert G.seed == 110, "should be able to use the value"
        assert G.postfix == "seed-", "should be able to get original value"

    @sweep.each
    def each(G):
        G.postfix = f"G.seed-({G.seed})"
        G.some_value = "random-string"
        assert G.some_value == "random-string", "should be able to get the updated value right away"

    configs = [*sweep]

    # One generated config per seed, each carrying the postfix computed in `each`.
    for index, seed in enumerate((10, 20, 30)):
        assert configs[index]['G.postfix'] == f"G.seed-({seed})"
    assert configs[0]['G.some_value'] == "random-string"
예제 #2
0
def test_multiple_configs(clear_args):
    """Sweep over two prefixed configs at the same time.

    When using a prefix key, attribute keys that do not exist on the class
    get written anyway.
    """

    class G(ParamsProto, cli_parse=False, prefix=True):
        a = 5

    class DEBUG(ParamsProto, prefix=True):
        b = 10

    with Sweep(G, DEBUG) as sweep:
        with sweep.zip:
            G.a = range(100)
            DEBUG.b = range(100)

    # Both configs advance in lockstep while iterating the zipped sweep.
    for step, _ in enumerate(sweep):
        assert G.a == step
        assert DEBUG.b == step

    # Assigning on only one of the two registered configs still works.
    with Sweep(G, DEBUG) as sweep:
        with sweep.set:
            DEBUG.nonexist_gets_written = "hey"
        with sweep.product:
            DEBUG.b = range(100)

    # note: does not support "_" (underscore) prefix.
    for step, overrides in enumerate(sweep):
        expected = {"DEBUG.b": step, "DEBUG.nonexist_gets_written": "hey"}
        assert overrides == expected
        assert DEBUG.b == step
        assert DEBUG.nonexist_gets_written == "hey"
예제 #3
0
def test_multiple_configs_no_prefix(clear_args):
    """Sweep over two un-prefixed configs at the same time.

    Without a prefix, a key that does not exist on the class still shows up
    in the yielded overrides, but is not set as an attribute on the class.
    """

    class G(ParamsProto, cli_parse=False):
        a = 5

    class DEBUG(ParamsProto):
        b = 10

    with Sweep(G, DEBUG) as sweep:
        with sweep.zip:
            G.a = range(100)
            DEBUG.b = range(100)

    # Both configs advance in lockstep while iterating the zipped sweep.
    for step, _ in enumerate(sweep):
        assert G.a == step
        assert DEBUG.b == step

    # Assigning on only one of the two registered configs still works.
    with Sweep(G, DEBUG) as sweep:
        with sweep.set:
            DEBUG.does_not_exist = "hey"
        with sweep.product:
            DEBUG.b = range(100)

    # note: does not support "_" (underscore) prefix.
    for step, overrides in enumerate(sweep):
        expected = {"b": step, "does_not_exist": "hey"}
        assert overrides == expected
        assert DEBUG.b == step
        assert not hasattr(DEBUG, 'does_not_exist')
예제 #4
0
def test_load(clear_args):
    """``Sweep.load`` raises ``KeyError`` on unknown keys unless ``strict=False``."""

    class P(PrefixProto):
        config_1 = False

    class G(ParamsProto):
        config_2 = False

    # Known keys load fine; unknown keys are accepted when strict=False.
    Sweep(P, G).load([{'P.config_1': True}, {'config_2': True}])
    Sweep(P, G).load([{'P.config_2': True}], strict=False)
    Sweep(P, G).load([{'does_not_exist': True}], strict=False)

    # With the default arguments the same unknown keys must raise.
    for unknown in ({'P.config_2': True}, {'does_not_exist': True}):
        with pytest.raises(KeyError):
            Sweep(P, G).load([unknown])
예제 #5
0
def test_product_zip(clear_args):
    """A ``zip`` nested inside a ``product`` acts as one axis of the product.

    ``start_seed`` (2 values) x ``discern_flag`` (2 values) x the zipped
    ``(env_name, batch_size)`` pairs (2 values) gives 8 configurations, with
    the zipped pair varying fastest.
    """

    class G(PrefixProto):
        start_seed = 10
        discern_flag = True
        env_name = "dm_lab"
        batch_size = 5
        # Spelled ``replicas_hint`` to match the other tests in this file.
        replicas_hint = ['yo', 'hey']

    with Sweep(G) as sweep:
        with sweep.product:
            G.start_seed = range(2)
            G.discern_flag = [True, False]

            with sweep.zip:
                G.env_name = ['small_env', 'large_env']
                G.batch_size = [10, 50]

    # Named ``configs`` rather than ``all`` to avoid shadowing the builtin.
    configs = [*sweep]
    assert len(configs) == 2 * 2 * 2
    assert configs == [
        {'G.batch_size': 10, 'G.discern_flag': True, 'G.env_name': 'small_env', 'G.start_seed': 0},
        {'G.batch_size': 50, 'G.discern_flag': True, 'G.env_name': 'large_env', 'G.start_seed': 0},
        {'G.batch_size': 10, 'G.discern_flag': False, 'G.env_name': 'small_env', 'G.start_seed': 0},
        {'G.batch_size': 50, 'G.discern_flag': False, 'G.env_name': 'large_env', 'G.start_seed': 0},
        {'G.batch_size': 10, 'G.discern_flag': True, 'G.env_name': 'small_env', 'G.start_seed': 1},
        {'G.batch_size': 50, 'G.discern_flag': True, 'G.env_name': 'large_env', 'G.start_seed': 1},
        {'G.batch_size': 10, 'G.discern_flag': False, 'G.env_name': 'small_env', 'G.start_seed': 1},
        {'G.batch_size': 50, 'G.discern_flag': False, 'G.env_name': 'large_env', 'G.start_seed': 1},
    ]
예제 #6
0
def test_incrementation(clear_args):
    """Show which kinds of counters survive the per-step reset of a Sweep.

    The Sweep resets the configuration at each step, to make sure that
    local overrides do not propagate to the next step. This also means
    that you cannot imperatively mutate a value step-by-step, such as
    incrementing a counter.

    There are a few patterns for accomplishing this; the assertions at the
    bottom pin down how each one behaves across the ten sweep steps.
    """
    from params_proto.neo_proto import Accumulant

    class G(ParamsProto):
        seed = None
        # Plain attribute: asserted below to read 11 on every step, i.e. the
        # increment does not carry over from one step to the next.
        static_counter = 10
        # Wrapped in Accumulant: asserted below to read 10 + seed, i.e. the
        # increment does carry over from one step to the next.
        dynamic_accumulant = Accumulant(10 - 1)

        # Runs on every `G()` call in the loop below.
        @classmethod
        def __init__(cls, ):
            cls.static_counter += 1
            # `dynamic_counter` is not declared in the class body; it starts
            # from the getattr default of -1 and is asserted below to track
            # the seed.
            cls.dynamic_counter = getattr(cls, "dynamic_counter", -1) + 1
            cls.dynamic_accumulant += 1

    with Sweep(G) as sweep:
        with sweep.product:
            G.seed = [i for i in range(10)]

    for deps in sweep:
        # Calling the class triggers the classmethod `__init__` above.
        G()
        assert G.static_counter == 11
        assert G.dynamic_counter == G.seed
        assert G.dynamic_accumulant == 10 + G.seed
예제 #7
0
def test_chaining_with_shared_root_set(clear_args):
    """Two chained sub-sweeps under a shared root-level assignment.

    Each chained branch sets its own ``level_name`` / ``some_prefix`` and
    sweeps ``start_seed`` over 15 values, so the chain yields 15 + 15 = 30
    configurations in total.
    """

    # usage
    class G(ParamsProto):
        root_set = False
        level_name = "dmlab"
        some_prefix = f"{level_name}/1"
        start_seed = 10

    with Sweep(G) as sweep:
        # Assigned at the root of the sweep, outside of the chain.
        G.root_set = True

        with sweep.chain:
            with sweep.set:
                G.level_name = "gotham"
                # Plain literal: the original f-string had no placeholders.
                G.some_prefix = "gotham/1"

                with sweep.product:
                    G.start_seed = range(15)

            with sweep.set:
                G.level_name = "dmlab"
                G.some_prefix = "dmlab/1"

                with sweep.product:
                    G.start_seed = range(15)

    # Named ``configs`` rather than ``all`` to avoid shadowing the builtin.
    configs = [*sweep]
    assert len(configs) == 30
예제 #8
0
def test_negative_subscription(clear_args):
    """Slicing a Sweep with negative bounds selects configs counted from the end.

    The slice is compared as a whole list rather than pairwise through
    ``zip``: ``zip`` stops at the shorter input, so a slice that came back
    short (or empty) would let a pairwise check pass without verifying
    anything.
    """

    class G(ParamsProto):
        start_seed = 10

    with Sweep(G) as sweep:
        with sweep.product:
            G.start_seed = list(range(100))

    # using sweep as a sliced generator.
    assert list(sweep[-10:-5]) == [{"start_seed": seed} for seed in range(90, 95)]
예제 #9
0
def test_subscription(clear_args):
    """Indexing or slicing a Sweep yields the matching sub-range of configs."""

    class G(ParamsProto):
        start_seed = 10

    with Sweep(G) as sweep:
        with sweep.product:
            G.start_seed = list(range(100))

    def seeds(*span):
        # Expected overrides for the seeds in ``range(*span)``.
        return [{"start_seed": seed} for seed in range(*span)]

    # using sweep as a sliced generator.
    assert list(sweep[:5]) == seeds(5)
    assert list(sweep[10:20:3]) == seeds(10, 20, 3)
    assert list(sweep[30]) == [{"start_seed": 30}]
    assert list(sweep[1:]) == seeds(1, 100)
예제 #10
0
def test_each(clear_args):
    """A function registered with ``sweep.each`` runs for every configuration.

    Useful for setting values that dynamically depend on others.
    """

    class G(PrefixProto):
        seed = 10
        postfix = "seed-"

    with Sweep(G).product as sweep:
        G.seed = [10, 20, 30]

    # An assignment made under a different Sweep; the asserts below show the
    # postfix is still computed from this sweep's seeds (10, 20, 30).
    with Sweep(G):
        G.seed = 110

    @sweep.each
    def each(G):
        G.postfix = f"G.seed-({G.seed})"

    configs = [*sweep]

    for index, seed in enumerate((10, 20, 30)):
        assert configs[index]['G.postfix'] == f"G.seed-({seed})"
예제 #11
0
def test_set(clear_args):
    """Plain assignments under a Sweep are yielded as overrides and applied to the class."""

    class G(PrefixProto):
        start_seed = 10
        discern_flag = True
        env_name = "dm_lab"
        batch_size = 5
        replicas_hint = ['yo', 'hey']

    with Sweep(G) as sweep:
        G.start_seed = 20
        G.discern_flag = False

    expected = {'G.discern_flag': False, 'G.start_seed': 20}
    for overrides in sweep:
        assert G.discern_flag is False
        assert overrides == expected
예제 #12
0
def test_zip(clear_args):
    """``sweep.zip`` pairs the assigned value lists element by element."""

    # usage
    class G(ParamsProto):
        start_seed = 10
        discern_flag = True
        env_name = "dm_lab"
        batch_size = 5
        # Spelled ``replicas_hint`` to match the other tests in this file.
        replicas_hint = ['yo', 'hey']

    with Sweep(G) as sweep:
        with sweep.zip:
            G.env_name = ['small_env', 'large_env']
            G.batch_size = [10, 50]

    # Named ``configs`` rather than ``all`` to avoid shadowing the builtin.
    configs = [*sweep]
    assert configs == [
        {'batch_size': 10, 'env_name': 'small_env'},
        {'batch_size': 50, 'env_name': 'large_env'},
    ]
예제 #13
0
def test_set_advanced(clear_args):
    """A root-level assignment combined with a two-element zip yields two configs."""

    # usage
    class G(ParamsProto):
        start_seed = 10
        discern_flag = True
        env_name = "dm_lab"
        batch_size = 5
        replicas_hint = ['yo', 'hey']

    with Sweep(G) as sweep:
        G.replicas_hint = 12

        with sweep.zip:
            G.env_name = ['small', 'large']
            G.batch_size = [10, 50]

    configs = [*sweep]
    assert len(configs) == 2
예제 #14
0
def test_product(clear_args):
    """``sweep.product`` yields the cartesian product, last assigned attribute varying fastest."""

    # usage
    class G(ParamsProto):
        start_seed = 10
        discern_flag = True
        env_name = "dm_lab"
        batch_size = 5
        replicas_hint = ['yo', 'hey']

    with Sweep(G) as sweep:
        with sweep.product:
            G.start_seed = range(2)
            G.discern_flag = [True, False]

    # Outer loop over seeds, inner loop over flags: same order as the sweep.
    expected = [
        {'discern_flag': flag, 'start_seed': seed}
        for seed in (0, 1)
        for flag in (True, False)
    ]
    configs = [*sweep]
    assert configs == expected
예제 #15
0
def test_set_advanced_2(clear_args):
    """Root-level assignments plus a product containing a nested zip.

    15 seeds x 3 zipped ``(env_name, batch_size)`` pairs gives 45 configs.
    """

    # usage
    class G(PrefixProto):
        null_attribute = True
        start_seed = 10
        discern_flag = True
        env_name = "dm_lab"
        batch_size = 5
        replicas_hint = ['yo', 'hey']

    with Sweep(G) as sweep:
        G.null_attribute = None
        G.replicas_hint = 12

        with sweep.product:
            G.start_seed = range(15)

            with sweep.zip:
                G.env_name = ['small', 'large', 'xl']
                G.batch_size = [10, 50, 100]

    configs = [*sweep]
    assert len(configs) == 45
예제 #16
0
def test_jagged(clear_args):
    """Configs that override different keys always rewrite from the original.

    The two chained steps touch different attributes; on each step the
    attribute that the step does not set must read as its class default.
    """

    # usage
    class G(ParamsProto):
        config_1 = False
        config_2 = False

    with Sweep(G) as sweep:

        with sweep.chain:
            with sweep.set:
                G.config_1 = 10
            with sweep.set:
                G.config_2 = 20

    for step, _ in enumerate(sweep):
        if step == 0:
            assert G.config_1 == 10
            assert G.config_2 is False
        elif step == 1:
            assert G.config_1 is False
            assert G.config_2 == 20
예제 #17
0
    ms = [
        # "visiongpu54",
        # "improbable005",
        # "improbable006",
        # "improbable007",
        # "improbable008",
        # "improbable009",
        # "improbable010",
    ]

    if 'pydevd' in sys.modules:
        jaynes.config("local")
    else:
        jaynes.config("supercloud", )

    with Sweep(Args) as sweep, sweep.product:
        with sweep.zip:
            Args.domain = ['walker', 'cartpole', 'ball_in_cup', 'finger']
            Args.task = ['walk', 'swingup', 'catch', 'spin']
        Args.algo = ['pad', 'soda', 'curl', 'rad']
        Args.seed = [100, 300, 400]

    for i, args in enumerate(sweep):
        # jaynes.config("vision", launch=dict(ip=ms[i]))
        RUN.job_postfix = f"{Args.domain}-{Args.task}/{Args.algo}/{Args.seed}"
        thunk = instr(train, args)
        logger.log_text("""
                        keys:
                        - Args.domain
                        - Args.task
                        - Args.algo
예제 #18
0
if __name__ == '__main__':
    import jaynes, sys
    from ml_logger import logger
    from params_proto.neo_hyper import Sweep
    from dmc_gen.train import train
    from dmc_gen.config import Args
    from dmc_gen_analysis import instr, RUN

    with Sweep(Args) as sweep:

        Args.load_checkpoint = "/geyang/dmc_gen/01_baselines/train/ball_in_cup-catch/rad/100"

        Args.domain = 'ball_in_cup'
        Args.task = 'catch'

        Args.algo = 'rad'
        Args.seed = 100

    for i, args in enumerate(sweep):
        if 'pydevd' in sys.modules:
            jaynes.config('local')
        else:
            # jaynes.config("vision", launch=dict(ip=f"visiongpu{ms[i]:02d}"))
            jaynes.config("supercloud", )

        RUN.job_postfix = f"{Args.domain}-{Args.task}/{Args.algo}/{Args.seed}"

        thunk = instr(train, args, _job_counter=False)

        logger.log_text("""
                    keys: