示例#1
0
def setup(log_dir):
    """Configure the logger for the test script and report where it writes.

    Configures the logger with the ``main_test_script`` prefix, calls
    ``logger.remove('')``, prints a greeting through the logger, pauses for
    one second, and finally echoes the joined log directory and prefix.

    :param log_dir: passed to ``logger.configure`` as its first argument
    """
    logger.configure(log_dir, prefix="main_test_script")
    logger.remove("")
    logger.print("hey")
    sleep(1.0)

    destination = pathJoin(logger.log_directory, logger.prefix)
    print(f"logging to {destination}")
示例#2
0
def launch(root, seed=None):
    """Configure ml_logger under a per-seed prefix and print a marker line.

    :param root: handed to ``logger.configure`` as ``root_dir``
    :param seed: interpolated into the run prefix; defaults to ``None``
    """
    from ml_logger import logger

    run_prefix = f"geyang/jaynes-demo/seed-{seed}"
    logger.configure(
        root_dir=root,
        prefix=run_prefix,
        register_experiment=True,
    )
    logger.print("this has be ran")
示例#3
0
def launch(**_G):
    """Call ``run_maml`` with ``CUDA_VISIBLE_DEVICES`` set to ``1``.

    If the run raises, the formatted traceback and ``U.ALREADY_INITIALIZED``
    are printed through the logger before the exception is re-raised.

    :param _G: keyword arguments forwarded to ``run_maml`` as ``_G``
    :raises Exception: whatever ``run_maml`` raised, after it has been logged
    """
    import baselines.common.tf_util as U
    import traceback
    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    try:
        run_maml(_G=_G)
    except Exception as err:
        trace_text = traceback.format_exc()
        logger.print(trace_text)
        logger.print(U.ALREADY_INITIALIZED)
        raise err
示例#4
0
    def thunk(*args, **kwargs):
        """Run ``fn`` with the stored arguments and record the run status.

        Positional arguments may be supplied either when the thunk was
        created (``ARGS``) or at call time, but not both. Keyword arguments
        given here override the stored ``KWARGS``.

        :return: whatever ``fn`` returns
        :raises Exception: re-raises any exception from the run, after the
            traceback and an "error" status have been logged
        """
        import time
        import traceback
        from ml_logger import logger

        assert not args or not ARGS, (
            "can not use position argument at both thunk creation as well as run.\n"
            f"_args: {args}\n"
            f"ARGS: {ARGS}\n"
        )

        logger.configure(
            root_dir=RUN.server,
            prefix=PREFIX,
            register_experiment=False,
            max_workers=10,
        )
        host_info = dict(hostname=logger.hostname)
        run_info = dict(
            status="running",
            startTime=logger.now(),
            job_id=logger.job_id,
        )
        logger.log_params(host=host_info, run=run_info)

        try:
            # call-time keyword arguments win over the ones stored at creation
            merged_kwargs = {**KWARGS, **kwargs}
            positional = args or ARGS

            results = fn(*positional, **merged_kwargs)

            logger.log_line("========== execution is complete ==========")
            finished = dict(status="completed", completeTime=logger.now())
            logger.log_params(run=finished)
            logger.flush()
            time.sleep(3)
        except Exception as err:
            trace_text = traceback.format_exc()
            # Make sure uploaded finished before termination.
            with logger.SyncContext():
                logger.print(trace_text, color="red")
                logger.log_text(trace_text, filename="traceback.err")
                failed = dict(status="error", exitTime=logger.now())
                logger.log_params(run=failed)
                logger.flush()
            time.sleep(3)
            raise err

        return results
示例#5
0
def sample_trajs(seed, env_id=Args.env_id):
    """Collect the ``'x'`` observations of rollouts driven by random actions.

    Seeds ``np.random`` with ``seed``, builds the environment with
    ``gym.make``, and runs ``Args.n_rollout`` rollouts. Each rollout records
    the ``'x'`` entry of the reset observation followed by that of
    ``Args.n_timesteps - 1`` steps, where every action is drawn with
    ``np.random.randint(low=0, high=7)``.

    :param seed: seed for NumPy's global random state
    :param env_id: environment id for ``gym.make``; defaults to ``Args.env_id``
    :return: ``np.array`` built from the list of per-rollout sequences
    """
    from ge_world import IS_PATCHED
    assert IS_PATCHED, "required for these envs."

    np.random.seed(seed)
    env = gym.make(env_id)
    env.reset()

    rollouts = []
    for _rollout in range(Args.n_rollout):
        first_obs = env.reset()
        xs = [first_obs['x']]
        for _step in range(Args.n_timesteps - 1):
            action = np.random.randint(low=0, high=7)
            step_obs, _reward, _done, _info = env.step(action)
            xs.append(step_obs['x'])
        rollouts.append(xs)

    stacked = np.array(rollouts)

    from ml_logger import logger
    logger.print(f'seed {seed} has finished sampling.', color="green")
    return stacked
示例#6
0
def launch(model=None, test_fn=None, **_G):
    """Seed the RNGs, configure logging, and run ``maml_supervised`` on ``Sine``.

    :param model: model to train; when falsy a ``Model`` is built from ``vars(G)``
    :param test_fn: forwarded to ``maml_supervised`` as ``test_fn``
    :param _G: overrides applied to ``G`` through ``G.update``
    """
    import matplotlib

    # select the Agg backend before anything else touches matplotlib
    matplotlib.use("Agg")

    G.update(_G)

    import numpy as np

    np.random.seed(G.seed)
    t.manual_seed(G.seed)
    t.cuda.manual_seed(G.seed)

    logger.configure(log_directory=G.log_dir, prefix=G.log_prefix)
    logger.log_params(G=vars(G))

    if not model:
        model = Model(**vars(G))
    logger.print(str(model))

    from playground.maml.maml_torch.tasks import Sine

    maml_supervised(model, Sine, test_fn=test_fn, **vars(G))
示例#7
0
def torch_upload():
    """Upload a torch-serialized array to the logging server with pycurl.

    Configures ml_logger against a fixed server and prefix, removes any
    previous ``upload/example.pt``, serializes ``np.ones([10_000_000])`` into
    a named temporary file, and posts that file as multipart form data to
    ``logger.root_dir``.

    The temporary file is flushed before the upload because libcurl opens
    it by name rather than reusing the Python file object. The curl handle
    is closed even when the transfer raises.
    """
    from ml_logger import logger
    import numpy as np

    logger.configure(
        root_dir="http://54.71.92.65:9080",
        prefix="geyang/ml_logger-debug/test-1",
        register_experiment=True,
    )
    logger.log_params(args={})

    with logger.Sync():
        import os  # only referenced by the commented-out proxy setup below
        import torch
        from pycurl import Curl
        from tempfile import NamedTemporaryFile

        logger.remove('upload/example.pt')

        with NamedTemporaryFile(delete=True) as f:
            torch.save(np.ones([10_000_000]), f)
            # torch.save(np.ones([1000_000]), f)
            # Push buffered bytes to disk: FORM_FILE makes libcurl read the
            # file via its path, so unflushed data would be missing.
            f.flush()
            logger.print(f.name)

            c = Curl()
            try:
                c.setopt(c.URL, logger.root_dir)
                # proxy = os.environ.get('HTTP_PROXY')
                # c.setopt(c.PROXY, proxy)
                # logger.print('proxy:', proxy)
                c.setopt(c.TIMEOUT, 100000)
                c.setopt(c.HTTPPOST, [
                    ('file', (
                        c.FORM_FILE, f.name,
                        c.FORM_FILENAME, logger.prefix + '/upload/example.pt',
                        c.FORM_CONTENTTYPE, 'plain/text',
                    )),
                ])
                c.perform()
            finally:
                # release the handle even if setopt/perform raised
                c.close()

        logger.print('done')

        # logger.remove(".")
        # a = np.ones([1, 1, 100_000_000 // 4])
        # logger.print(f"the size of the tensor is {a.size}")
        # data = dict(key="ok", large=a)
        # logger.torch_save(data, f"save/data-{logger.now('%H.%M.%S')}.pkl")
    logger.print('done')
示例#8
0
    # plt.legend(loc="upper left", bbox_to_anchor=(0.45, 0.8), framealpha=1, frameon=False, fontsize=12)
    plt.tight_layout()
    logger.savefig("../figures/maze_plans.png", dpi=300)
    plt.show()
    plt.close()

    # colors = ['#49b8ff', '#ff7575', '#66c56c', '#f4b247']
    fig = plt.figure(figsize=(3.8, 3), dpi=300)
    plt.title('Planning Cost')
    plt.bar(cache.cost.keys(), cache.cost.values(), color="gray", width=0.8)
    plt.ylim(0, max(cache.cost.values()) * 1.2)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.tight_layout()
    logger.savefig("../figures/maze_cost.png", dpi=300)
    plt.ylabel('# of distance lookup')
    plt.show()

    fig = plt.figure(figsize=(3.8, 3), dpi=300)
    plt.title('Plan Length')
    plt.bar(cache.len.keys(), cache.len.values(), color="gray", width=0.8)
    plt.ylim(0, max(cache.len.values()) * 1.2)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.tight_layout()
    logger.savefig("../figures/maze_length.png", dpi=300)
    plt.ylabel('Path Length')
    plt.show()

    logger.print('done', color="green")
示例#9
0
def instr(fn, *ARGS, __file=False, __silent=False, **KWARGS):
    """
    Configure the logger for a run of ``fn`` and return a thunk that executes it.

    At creation time this resolves the launching script, configures the
    logger under ``RUN.PREFIX``, uploads the script via
    ``logger.upload_file``, logs run / revision / function parameters, and
    calls ``logger.diff()``. The returned thunk reconfigures the logger,
    calls ``fn``, and logs a "completed" or "error" status.

    :param fn: function to be called
    :param *ARGS: position arguments for the call
    :param __file: script path, joined onto the current working directory;
        when falsy the script is taken from the module of the calling frame
        (original note: console mode, by-pass file related logging)
    :param __silent: logged as the ``silent`` parameter; not otherwise read
        in this function (original note: do not print)
    :param **KWARGS: keyword arguments for the call
    :return: a thunk that can be called without parameters
    """
    from ml_logger import logger

    if __file:
        caller_script = pJoin(os.getcwd(), __file)
    else:
        # stack()[1] is the frame that called instr, so this must stay in
        # instr's own body for the index to point at the launching module.
        launch_module = inspect.getmodule(inspect.stack()[1][0])
        __file = launch_module.__file__
        caller_script = abspath(__file)

    # note: for scripts in the `plan2vec` module this also works -- b/c we truncate fixed depth.
    # `__file__` below is this module's own path, not the `__file` argument.
    script_path = logger.truncate(caller_script,
                                  depth=len(__file__.split('/')) - 1)
    file_stem = logger.stem(script_path)
    file_name = basename(file_stem)

    # RUN is called before RUN.PREFIX is read below.
    RUN(file_name=file_name, file_stem=file_stem, now=logger.now())

    PREFIX = RUN.PREFIX

    # todo: there should be a better way to log these.
    # todo: we shouldn't need to log to the same directory, and the directory for the run shouldn't be fixed.
    logger.configure(
        root_dir=RUN.server,
        prefix=PREFIX,
        asynchronous=False,  # use sync logger
        max_workers=4,
        register_experiment=False)
    if RUN.restart:
        # remove "." under the configured prefix before logging anything new
        with logger.Sync():
            logger.remove(".")
    logger.upload_file(caller_script)
    # the tension is in between creation vs run. Code snapshot are shared, but runs need to be unique.
    # Only non-empty args/kwargs are added to the logged parameters.
    _ = dict()
    if ARGS:
        _['args'] = ARGS
    if KWARGS:
        _['kwargs'] = KWARGS

    logger.log_params(run=logger.run_info(status="created",
                                          script_path=script_path),
                      revision=logger.rev_info(),
                      fn=logger.fn_info(fn),
                      **_,
                      silent=__silent)

    logger.print(
        'taking diff, if this step takes too long, check if your '
        'uncommitted changes are too large.',
        color="green")
    logger.diff()
    if RUN.readme:
        logger.log_text(RUN.readme, "README.md", dedent=True)

    import jaynes  # now set the job name to prefix
    if jaynes.RUN.config and jaynes.RUN.mode != "local":
        runner_class, runner_args = jaynes.RUN.config['runner']
        if 'name' in runner_args:  # ssh mode does not have 'name'.
            runner_args['name'] = pJoin(file_name, RUN.JOB_NAME)
        # drop the local names; thunk re-imports `logger` for itself
        del logger, jaynes, runner_args, runner_class
        # NOTE(review): `__file` was reassigned to the caller module's
        # __file__ above when it was falsy, so this check looks like it can
        # only pass if that attribute is empty -- confirm the intent.
        if not __file:
            cprint(f'Set up job name', "green")

    def thunk(*args, **kwargs):
        # Runs `fn` with the stored ARGS/KWARGS (call-time kwargs override
        # the stored ones) and logs the run status around the call.
        import traceback
        from ml_logger import logger

        # positional arguments may come from creation or call time, not both
        assert not (args and ARGS), \
            f"can not use position argument at both thunk creation as well as run.\n" \
            f"_args: {args}\n" \
            f"ARGS: {ARGS}\n"

        logger.configure(root_dir=RUN.server,
                         prefix=PREFIX,
                         register_experiment=False,
                         max_workers=10)
        logger.log_params(host=dict(hostname=logger.hostname),
                          run=dict(status="running",
                                   startTime=logger.now(),
                                   job_id=logger.job_id))

        import time
        try:
            _KWARGS = {**KWARGS}
            _KWARGS.update(**kwargs)

            results = fn(*(args or ARGS), **_KWARGS)

            logger.log_line("========== execution is complete ==========")
            logger.log_params(
                run=dict(status="completed", completeTime=logger.now()))
            logger.flush()
            time.sleep(3)
        except Exception as e:
            tb = traceback.format_exc()
            with logger.SyncContext(
            ):  # Make sure uploaded finished before termination.
                logger.print(tb, color="red")
                logger.log_text(tb, filename="traceback.err")
                logger.log_params(
                    run=dict(status="error", exitTime=logger.now()))
                logger.flush()
            time.sleep(3)
            # re-raise after the error has been recorded
            raise e

        return results

    return thunk
示例#10
0
def train(deps=None, **kwargs):
    """Train the agent selected by ``Args.algo``, evaluating and checkpointing periodically.

    Builds a training and a test environment with ``wrappers.make_env``,
    creates the agent and a replay buffer, then steps the training
    environment from ``Args.start_step`` through ``Args.train_steps``.
    At episode boundaries it logs metrics, evaluates every
    ``Args.eval_freq`` steps, and saves the agent every ``Args.save_freq``
    steps.

    :param deps: forwarded to ``Args._update`` as its first argument
    :param kwargs: forwarded to ``Args._update`` as keyword arguments
    :return: None
    """
    from ml_logger import logger
    from dmc_gen.config import Args

    Args._update(deps, **kwargs)
    logger.log_params(Args=vars(Args))

    utils.set_seed_everywhere(Args.seed)
    wrappers.VideoWrapper.prefix = wrappers.ColorWrapper.prefix = DMCGEN_DATA

    # Initialize environments
    image_size = 84 if Args.algo == 'sac' else 100
    env = wrappers.make_env(
        domain_name=Args.domain,
        task_name=Args.task,
        seed=Args.seed,
        episode_length=Args.episode_length,
        action_repeat=Args.action_repeat,
        image_size=image_size,
    )
    # The test environment uses a different seed and the configured eval mode.
    test_env = wrappers.make_env(domain_name=Args.domain,
                                 task_name=Args.task,
                                 seed=Args.seed + 42,
                                 episode_length=Args.episode_length,
                                 action_repeat=Args.action_repeat,
                                 image_size=image_size,
                                 mode=Args.eval_mode)

    # Prepare agent
    # The agent's observation shape is fixed at 84x84 regardless of image_size.
    cropped_obs_shape = (3 * Args.frame_stack, 84, 84)
    agent = make_agent(algo=Args.algo,
                       obs_shape=cropped_obs_shape,
                       act_shape=env.action_space.shape,
                       args=Args).to(Args.device)

    if Args.load_checkpoint:
        print('Loading from checkpoint:', Args.load_checkpoint)
        logger.load_module(agent,
                           path="models/*.pkl",
                           wd=Args.load_checkpoint,
                           map_location=Args.device)

    replay_buffer = utils.ReplayBuffer(obs_shape=env.observation_space.shape,
                                       action_shape=env.action_space.shape,
                                       capacity=Args.train_steps,
                                       batch_size=Args.batch_size)

    # `done` starts True so the first iteration takes the reset branch below.
    episode, episode_reward, episode_step, done = 0, 0, 0, True
    logger.start('train')
    for step in range(Args.start_step, Args.train_steps + 1):
        if done:
            # Skipped on the very first iteration, where no episode has run yet.
            if step > Args.start_step:
                logger.store_metrics({'dt_epoch': logger.split('train')})
                logger.log_metrics_summary(dict(step=step),
                                           default_stats='mean')

            # Evaluate agent periodically
            if step % Args.eval_freq == 0:
                logger.store_metrics(episode=episode)
                with logger.Prefix(metrics="eval/"):
                    evaluate(env,
                             agent,
                             Args.eval_episodes,
                             save_video=f"videos/{step:08d}_train.mp4")
                with logger.Prefix(metrics="test/"):
                    evaluate(test_env,
                             agent,
                             Args.eval_episodes,
                             save_video=f"videos/{step:08d}_test.mp4")
                logger.log_metrics_summary(dict(step=step),
                                           default_stats='mean')

            # Save agent periodically
            if step > Args.start_step and step % Args.save_freq == 0:
                with logger.Sync():
                    logger.save_module(agent, f"models/{step:06d}.pkl")
                # With save_last, remove the checkpoint named for the step
                # one save interval earlier.
                if Args.save_last:
                    logger.remove(f"models/{step - Args.save_freq:06d}.pkl")
                # torch.save(agent, os.path.join(model_dir, f'{step}.pt'))

            logger.store_metrics(episode_reward=episode_reward,
                                 episode=episode + 1,
                                 prefix="train/")

            # Start the next episode and reset the per-episode counters.
            obs = env.reset()
            episode_reward, episode_step, done = 0, 0, False
            episode += 1

        # Sample action for data collection
        # Random actions before Args.init_steps, the agent's own afterwards.
        if step < Args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.Eval(agent):
                action = agent.sample_action(obs)

        # Run training update
        # Exactly at init_steps, run init_steps updates at once; afterwards
        # one update per environment step.
        if step >= Args.init_steps:
            num_updates = Args.init_steps if step == Args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, step)

        # Take step
        next_obs, reward, done, _ = env.step(action)
        # Store 0 instead of float(done) when this step reaches
        # env._max_episode_steps (presumably so a time-limit cutoff is not
        # recorded as a terminal transition -- confirm).
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
            done)
        replay_buffer.add(obs, action, reward, next_obs, done_bool)
        episode_reward += reward
        obs = next_obs

        episode_step += 1

    logger.print(
        f'Completed training for {Args.domain}_{Args.task}/{Args.algo}/{Args.seed}'
    )
示例#11
0
def test_launch(**_Args):
    """Apply ``_Args`` to ``Args``, configure ml_logger from it, and print a test line.

    :param _Args: overrides passed to ``Args.update`` as a single dict
    """
    from ml_logger import logger

    Args.update(_Args)

    logger.configure(
        log_directory=Args.log_dir,
        prefix=Args.log_prefix,
    )
    logger.print("yo!!! diz is vorking!")
示例#12
0
def test_capture_error():
    """Raise inside ``logger.capture_error()`` and print a success line after it.

    The final print is reached only if the context manager suppresses the
    ``RuntimeError`` raised in its body.
    """
    failure = RuntimeError("this should not fail")
    with logger.capture_error():
        raise failure

    logger.print("works!", color="green")
示例#13
0
                      act_space=tasks.envs.action_space)
        summary = tf.summary.FileWriter(config.RUN.log_directory,
                                        tf.get_default_graph())
        summary.flush()
        trainer = Trainer()
        U.initialize()
        trainer.train(tasks=tasks, maml=maml, test_tasks=test_tasks)
        # logger.clear_callback()

    tf.reset_default_graph()


def launch(**_G):
    """Merge ``_G`` into ``config.G`` and set the run's log location.

    ``config.RUN.log_dir`` is set to a fixed server URL and
    ``config.RUN.log_prefix`` to ``ge_maml/`` followed by today's date.

    :param _G: overrides passed to ``config.G.update`` as a single dict
    """
    from datetime import datetime

    stamp = datetime.now()
    config.G.update(_G)

    run_cfg = config.RUN
    run_cfg.log_dir = "http://54.71.92.65:8081"
    run_cfg.log_prefix = f"ge_maml/{stamp:%Y-%m-%d}"


if __name__ == "__main__":
    # Script entry point: run the experiment; on failure print the traceback
    # and U.ALREADY_INITIALIZED through the logger, then re-raise.
    import traceback

    try:
        run_e_maml()
    except Exception as err:
        trace_text = traceback.format_exc()
        logger.print(trace_text)
        logger.print(U.ALREADY_INITIALIZED)
        raise err