Exemplo n.º 1
0
def parse_args():
    """Parse the command line described by the module docstring.

    Option names are normalised: surrounding dashes are removed and inner
    dashes become underscores. The ``-h``/``--help`` entries are dropped.
    The four distributed-launch settings (``nb_node``, ``node_rank``,
    ``nb_proc``, ``master_port``) are always converted to ``int``.

    When ``resume`` is set, only its path is normalised and every other option
    is returned unconverted. Otherwise each remaining option is converted to
    its expected type.

    :return: DotDict of parsed arguments.
    """
    from docopt import docopt

    # str.strip takes a set of characters, so this removes every leading and
    # trailing dash (e.g. "--nb-env" -> "nb-env" -> "nb_env").
    normalised = {
        key.strip("-").replace("-", "_"): value
        for key, value in docopt(__doc__).items()
    }
    for help_key in ("h", "help"):
        del normalised[help_key]
    args = DotDict(normalised)

    for launch_key in ("nb_node", "node_rank", "nb_proc", "master_port"):
        setattr(args, launch_key, int(getattr(args, launch_key)))

    if args.resume:
        # When resuming, the remaining options are left as-is.
        args.resume = parse_path(args.resume)
        return args

    if args.config:
        args.config = parse_path(args.config)

    def int_via_float(text):
        # Going through float accepts scientific notation such as "1e6".
        return int(float(text))

    conversions = (
        ("logdir", parse_path),
        ("nb_env", int),
        ("seed", int),
        ("nb_step", int_via_float),
        ("tag", parse_none),
        ("nb_eval_env", int),
        ("summary_freq", int),
        ("lr", float),
        ("epoch_len", int_via_float),
        ("profile", bool),
    )
    for field, convert in conversions:
        setattr(args, field, convert(getattr(args, field)))
    return args
Exemplo n.º 2
0
    def from_resume(mode, args):
        """
        Rebuild the configuration of a previous run so that it can be resumed.

        The static args saved in the resume log directory replace the ones
        passed in. Only ``resume`` is carried over from the current
        invocation. The latest network/optimizer checkpoint paths and the
        latest epoch are looked up through ``LogDirHelper``.

        :param mode: Script name
        :param args: Dict[str, Any], static args; must provide ``resume``
        :return: Tuple ``(args, log_id_path, initial_step_count)``:
            ``args`` is the restored DotDict with ``resume``,
            ``load_network`` and ``load_optim`` set;
            ``log_id_path`` is the path built by ``Init.log_id_dir`` from
            the saved logdir/env and the regenerated log id (not the bare
            log id);
            ``initial_step_count`` is ``log_dir_helper.latest_epoch()``.
        """
        resume = args.resume
        log_dir_helper = LogDirHelper(resume)
        with open(log_dir_helper.args_file_path(), "r") as args_file:
            args = DotDict(json.load(args_file))
        # Replace whatever resume value was saved with the path used now.
        args.resume = resume

        args.load_network = log_dir_helper.latest_network_path()
        args.load_optim = log_dir_helper.latest_optim_path()
        initial_step_count = log_dir_helper.latest_epoch()

        # Prefer the agent name. Fall back to actor_host when agent is
        # falsy (presumably runs that use an actor host; verify).
        name = args.agent or args.actor_host

        # Reuse the original run's timestamp so the log id matches the
        # existing log directory.
        log_id = Init.make_log_id(
            args.tag,
            mode,
            name,
            args.netbody,
            timestamp=log_dir_helper.timestamp(),
        )
        log_id_path = Init.log_id_dir(args.logdir, args.env, log_id)
        return args, log_id_path, initial_step_count
Exemplo n.º 3
0
def parse_args():
    """Parse the command line described by the module docstring.

    Option names are normalised: surrounding dashes are removed and inner
    dashes become underscores. The ``-h``/``--help`` entries are dropped.

    When ``resume`` is set, only its path is normalised and all other
    options are returned unconverted. Otherwise each option is converted to
    its expected type (paths, ints, floats, optional tag, flags).

    :return: DotDict of parsed arguments.
    """
    from docopt import docopt

    # str.strip takes a set of characters: every leading/trailing dash goes.
    normalised = {
        key.strip('-').replace('-', '_'): value
        for key, value in docopt(__doc__).items()
    }
    for help_key in ('h', 'help'):
        del normalised[help_key]
    args = DotDict(normalised)

    # Ignore other args if resuming
    if args.resume:
        args.resume = parse_path(args.resume)
        return args

    if args.config:
        args.config = parse_path(args.config)

    def int_via_float(text):
        # Going through float accepts scientific notation such as "1e6".
        return int(float(text))

    conversions = (
        ('logdir', parse_path),
        ('gpu_id', int),
        ('nb_env', int),
        ('seed', int),
        ('nb_step', int_via_float),
        ('tag', parse_none),
        ('nb_eval_env', int),
        ('summary_freq', int),
        ('lr', float),
        ('warmup', int_via_float),
        ('epoch_len', int_via_float),
        ('profile', bool),
    )
    for field, convert in conversions:
        setattr(args, field, convert(getattr(args, field)))
    return args
Exemplo n.º 4
0
def parse_args():
    """Parse the command line described by the module docstring.

    Option names are normalised: surrounding dashes are removed and inner
    dashes become underscores. The ``-h``/``--help`` entries are dropped.

    When ``resume`` is set, only its path is normalised and all other
    options are returned unconverted. Otherwise training options and
    learner/worker resource options are converted to their expected types,
    and the learner batch count is checked against the worker count.

    :return: DotDict of parsed arguments.
    :raises ValueError: if ``nb_learn_batch`` is greater than ``nb_workers``.
    """
    from docopt import docopt
    args = docopt(__doc__)
    # str.strip takes a set of characters; '-' is equivalent to the original
    # '--' and removes every leading/trailing dash.
    args = {k.strip('-').replace('-', '_'): v for k, v in args.items()}
    del args['h']
    del args['help']
    args = DotDict(args)

    # Ignore other args if resuming
    if args.resume:
        args.resume = parse_path(args.resume)
        return args

    if args.config:
        args.config = parse_path(args.config)

    args.logdir = parse_path(args.logdir)
    args.nb_env = int(args.nb_env)
    args.seed = int(args.seed)
    # int(float(...)) accepts scientific notation such as "1e6".
    args.nb_step = int(float(args.nb_step))
    args.tag = parse_none(args.tag)
    args.summary_freq = int(args.summary_freq)
    args.lr = float(args.lr)
    args.epoch_len = int(float(args.epoch_len))
    args.profile = bool(args.profile)

    args.ray_addr = parse_none(args.ray_addr)
    args.nb_learners = int(args.nb_learners)
    args.nb_workers = int(args.nb_workers)
    args.learner_cpu_alloc = int(args.learner_cpu_alloc)
    # GPU allocations are floats, so fractional GPU shares are accepted.
    args.learner_gpu_alloc = float(args.learner_gpu_alloc)
    args.worker_cpu_alloc = int(args.worker_cpu_alloc)
    args.worker_gpu_alloc = float(args.worker_gpu_alloc)

    args.nb_learn_batch = int(args.nb_learn_batch)
    args.rollout_queue_size = int(args.rollout_queue_size)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.nb_learn_batch > args.nb_workers:
        raise ValueError(
            'nb_learn_batch must be <= nb_workers. '
            'Got nb_learn_batch={}, nb_workers={}'
            .format(args.nb_learn_batch, args.nb_workers)
        )
    return args