Code Example #1
def torch_upload():
    from ml_logger import logger
    import numpy as np

    logger.configure(root_dir="http://54.71.92.65:9080", prefix="geyang/ml_logger-debug/test-1",
                     register_experiment=True)
    logger.log_params(args={})

    with logger.Sync():
        import os
        import torch
        from pycurl import Curl
        from tempfile import NamedTemporaryFile

        logger.remove('upload/example.pt')

        with NamedTemporaryFile(delete=True) as f:
            torch.save(np.ones([10_000_000]), f)
            # torch.save(np.ones([1000_000]), f)
            f.flush()  # make sure the buffered tail of the file is on disk before curl reads it by name
            logger.print(f.name)

            # POST the saved file to the logging server as a multipart form upload
            c = Curl()
            c.setopt(c.URL, logger.root_dir)
            # proxy = os.environ.get('HTTP_PROXY')
            # c.setopt(c.PROXY, proxy)
            # logger.print('proxy:', proxy)
            c.setopt(c.TIMEOUT, 100000)
            c.setopt(c.HTTPPOST, [
                ('file', (
                    c.FORM_FILE, f.name,
                    c.FORM_FILENAME, logger.prefix + '/upload/example.pt',
                    c.FORM_CONTENTTYPE, 'text/plain',
                )),
            ])
            c.perform()
            c.close()

        logger.print('done')


        # logger.remove(".")
        # a = np.ones([1, 1, 100_000_000 // 4])
        # logger.print(f"the size of the tensor is {a.size}")
        # data = dict(key="ok", large=a)
        # logger.torch_save(data, f"save/data-{logger.now('%H.%M.%S')}.pkl")
    logger.print('done')
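
For completeness, here is a minimal sketch of reading the uploaded array back through the same logger. It assumes the server stores the multipart upload under the posted filename; logger.torch_load is the same call used in Code Example #4 below.

def torch_download():
    from ml_logger import logger

    # same server and prefix as the upload above
    logger.configure(root_dir="http://54.71.92.65:9080",
                     prefix="geyang/ml_logger-debug/test-1")

    # assumes the file posted via pycurl is now stored under the prefix
    data = logger.torch_load('upload/example.pt')
    logger.print(data.shape)  # the array saved above has shape (10_000_000,)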
Code Example #2
File: __init__.py  Project: geyang/dmc_gen
def instr(fn, *ARGS, __file=False, __silent=False, **KWARGS):
    """
    thunk for configuring the logger. The reason why this is not a decorator is

    :param fn: function to be called
    :param *ARGS: position arguments for the call
    :param __file__: console mode, by-pass file related logging
    :param __silent: do not print
    :param **KWARGS: keyword arguments for the call
    :return: a thunk that can be called without parameters
    """
    from ml_logger import logger

    if __file:
        caller_script = pJoin(os.getcwd(), __file)
    else:
        launch_module = inspect.getmodule(inspect.stack()[1][0])
        __file = launch_module.__file__
        caller_script = abspath(__file)

    # note: this also works for scripts in the `plan2vec` module, because we truncate at a fixed depth.
    script_path = logger.truncate(caller_script,
                                  depth=len(__file__.split('/')) - 1)
    file_stem = logger.stem(script_path)
    file_name = basename(file_stem)

    RUN(file_name=file_name, file_stem=file_stem, now=logger.now())

    PREFIX = RUN.PREFIX

    # todo: there should be a better way to log these.
    # todo: we shouldn't need to log to the same directory, and the directory for the run shouldn't be fixed.
    logger.configure(
        root_dir=RUN.server,
        prefix=PREFIX,
        asynchronous=False,  # use sync logger
        max_workers=4,
        register_experiment=False)
    if RUN.restart:
        with logger.Sync():
            logger.remove(".")
    logger.upload_file(caller_script)
    # The tension is between creation and run: code snapshots are shared, but each run needs to be unique.
    _ = dict()
    if ARGS:
        _['args'] = ARGS
    if KWARGS:
        _['kwargs'] = KWARGS

    logger.log_params(run=logger.run_info(status="created",
                                          script_path=script_path),
                      revision=logger.rev_info(),
                      fn=logger.fn_info(fn),
                      **_,
                      silent=__silent)

    logger.print(
        'Taking diff. If this step takes too long, check whether your '
        'uncommitted changes are too large.',
        color="green")
    logger.diff()
    if RUN.readme:
        logger.log_text(RUN.readme, "README.md", dedent=True)

    import jaynes  # now set the job name to prefix
    if jaynes.RUN.config and jaynes.RUN.mode != "local":
        runner_class, runner_args = jaynes.RUN.config['runner']
        if 'name' in runner_args:  # ssh mode does not have 'name'.
            runner_args['name'] = pJoin(file_name, RUN.JOB_NAME)
        del logger, jaynes, runner_args, runner_class
        if not __file:
            cprint('Set up job name', "green")

    def thunk(*args, **kwargs):
        import traceback
        from ml_logger import logger

        assert not (args and ARGS), \
            f"cannot use positional arguments at both thunk creation and call time.\n" \
            f"args: {args}\n" \
            f"ARGS: {ARGS}\n"

        logger.configure(root_dir=RUN.server,
                         prefix=PREFIX,
                         register_experiment=False,
                         max_workers=10)
        logger.log_params(host=dict(hostname=logger.hostname),
                          run=dict(status="running",
                                   startTime=logger.now(),
                                   job_id=logger.job_id))

        import time
        try:
            _KWARGS = {**KWARGS}
            _KWARGS.update(**kwargs)

            results = fn(*(args or ARGS), **_KWARGS)

            logger.log_line("========== execution is complete ==========")
            logger.log_params(
                run=dict(status="completed", completeTime=logger.now()))
            logger.flush()
            time.sleep(3)
        except Exception as e:
            tb = traceback.format_exc()
            with logger.SyncContext():  # make sure uploads finish before termination.
                logger.print(tb, color="red")
                logger.log_text(tb, filename="traceback.err")
                logger.log_params(
                    run=dict(status="error", exitTime=logger.now()))
                logger.flush()
            time.sleep(3)
            raise e

        return results

    return thunk
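
A minimal usage sketch of instr follows. Here train is the function from Code Example #3; the import path and the jaynes launch calls are assumptions suggested by the import jaynes above, not something this listing shows.

from dmc_gen.train import train   # hypothetical import path for the train() in Code Example #3

thunk = instr(train, seed=100)    # configure the logger and freeze the call
thunk()                           # run in-process, or hand the thunk to a launcher instead:

# import jaynes
# jaynes.config(mode="local")
# jaynes.run(thunk)
# jaynes.listen()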
Code Example #3
def train(deps=None, **kwargs):
    from ml_logger import logger
    from dmc_gen.config import Args

    Args._update(deps, **kwargs)
    logger.log_params(Args=vars(Args))

    utils.set_seed_everywhere(Args.seed)
    wrappers.VideoWrapper.prefix = wrappers.ColorWrapper.prefix = DMCGEN_DATA

    # Initialize environments
    image_size = 84 if Args.algo == 'sac' else 100
    env = wrappers.make_env(
        domain_name=Args.domain,
        task_name=Args.task,
        seed=Args.seed,
        episode_length=Args.episode_length,
        action_repeat=Args.action_repeat,
        image_size=image_size,
    )
    test_env = wrappers.make_env(domain_name=Args.domain,
                                 task_name=Args.task,
                                 seed=Args.seed + 42,
                                 episode_length=Args.episode_length,
                                 action_repeat=Args.action_repeat,
                                 image_size=image_size,
                                 mode=Args.eval_mode)

    # Prepare agent
    cropped_obs_shape = (3 * Args.frame_stack, 84, 84)
    agent = make_agent(algo=Args.algo,
                       obs_shape=cropped_obs_shape,
                       act_shape=env.action_space.shape,
                       args=Args).to(Args.device)

    if Args.load_checkpoint:
        print('Loading from checkpoint:', Args.load_checkpoint)
        logger.load_module(agent,
                           path="models/*.pkl",
                           wd=Args.load_checkpoint,
                           map_location=Args.device)

    replay_buffer = utils.ReplayBuffer(obs_shape=env.observation_space.shape,
                                       action_shape=env.action_space.shape,
                                       capacity=Args.train_steps,
                                       batch_size=Args.batch_size)

    episode, episode_reward, episode_step, done = 0, 0, 0, True
    logger.start('train')
    for step in range(Args.start_step, Args.train_steps + 1):
        if done:
            if step > Args.start_step:
                logger.store_metrics({'dt_epoch': logger.split('train')})
                logger.log_metrics_summary(dict(step=step),
                                           default_stats='mean')

            # Evaluate agent periodically
            if step % Args.eval_freq == 0:
                logger.store_metrics(episode=episode)
                with logger.Prefix(metrics="eval/"):
                    evaluate(env,
                             agent,
                             Args.eval_episodes,
                             save_video=f"videos/{step:08d}_train.mp4")
                with logger.Prefix(metrics="test/"):
                    evaluate(test_env,
                             agent,
                             Args.eval_episodes,
                             save_video=f"videos/{step:08d}_test.mp4")
                logger.log_metrics_summary(dict(step=step),
                                           default_stats='mean')

            # Save agent periodically
            if step > Args.start_step and step % Args.save_freq == 0:
                with logger.Sync():
                    logger.save_module(agent, f"models/{step:06d}.pkl")
                if Args.save_last:
                    logger.remove(f"models/{step - Args.save_freq:06d}.pkl")
                # torch.save(agent, os.path.join(model_dir, f'{step}.pt'))

            logger.store_metrics(episode_reward=episode_reward,
                                 episode=episode + 1,
                                 prefix="train/")

            obs = env.reset()
            episode_reward, episode_step, done = 0, 0, False
            episode += 1

        # Sample action for data collection
        if step < Args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.Eval(agent):
                action = agent.sample_action(obs)

        # Run training update
        if step >= Args.init_steps:
            num_updates = Args.init_steps if step == Args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, step)

        # Take step
        next_obs, reward, done, _ = env.step(action)
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
        replay_buffer.add(obs, action, reward, next_obs, done_bool)
        episode_reward += reward
        obs = next_obs

        episode_step += 1

    logger.print(
        f'Completed training for {Args.domain}_{Args.task}/{Args.algo}/{Args.seed}'
    )
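
The evaluate helper called in the loop above is not part of this listing. The sketch below shows one plausible shape for it, assuming the agent and environment expose the same calls used in the training loop (utils.Eval, agent.sample_action, gym-style env.step) plus an rgb_array render mode, and that logger.save_video accepts a list of frames (an assumption; only store_metrics appears in the listing).

def evaluate(env, agent, num_episodes, save_video=None):
    from ml_logger import logger

    frames = []
    for i in range(num_episodes):
        obs, done, total_reward = env.reset(), False, 0
        while not done:
            with utils.Eval(agent):                      # same eval-mode context as the training loop
                action = agent.sample_action(obs)
            obs, reward, done, _ = env.step(action)
            total_reward += reward
            if save_video and i == 0:
                frames.append(env.render('rgb_array'))   # assumes a gym-style render API
        logger.store_metrics(episode_reward=total_reward)

    if save_video and frames:
        logger.save_video(frames, save_video)            # assumed ml_logger call, not shown in the listing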
Code Example #4
def test_torch_load_sync(setup):
    # relies on the module-level `logger` import and the test_torch_save test defined in the same file
    with logger.Sync():
        test_torch_save(setup)
        module = logger.torch_load("modules/test_torch_save.pkl")
    print(module)
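
The test_torch_save test referenced above is not included in the listing. One plausible shape for it, using logger.torch_save (which appears, commented out, in Code Example #1):

def test_torch_save(setup):
    import torch.nn as nn

    # save a small module under the path that test_torch_load_sync reads back
    model = nn.Linear(10, 2)
    logger.torch_save(model, "modules/test_torch_save.pkl")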