Exemplo n.º 1
0
def main(goal_idx=0, args=args):
    """Build the task environment from the merged config and start an experiment.

    Starts from ``default_config``. When ``args.config`` is set, the JSON file
    at that path is loaded and merged in through ``deep_update_dict``. The
    environment class registered in ``ENVS`` under ``variant['env_name']`` is
    then instantiated with ``variant['env_params']``, wrapped in
    ``NormalizedBoxEnv``, reset to the requested task and handed to
    ``experiment``.

    Args:
        goal_idx: Task index passed to ``env.reset_task`` and forwarded to
            ``experiment``.
        args: Object exposing a ``config`` attribute: the path of a JSON
            override file, or a falsy value to use the defaults unchanged.
            The default is the module-level ``args`` captured when this
            function is defined.
    """
    variant = default_config
    if args.config:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(args.config, encoding="utf-8") as f:
            exp_params = json.load(f)
        # NOTE(review): whether deep_update_dict mutates default_config in
        # place is not visible from here -- confirm before reusing it.
        variant = deep_update_dict(exp_params, variant)
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    env.reset_task(goal_idx)
    # NOTE(review): `variant` is not forwarded; confirm `experiment` does not
    # need it.
    experiment(env=env, goal_idx=goal_idx)
Exemplo n.º 2
0
def experiment(variant, cfg=cfg, goal_idx=0, seed=0, eval=False):
    """Create the task environment and run (or evaluate) a ``Workspace`` on it.

    Args:
        variant: Mapping providing ``'env_name'`` (a key into ``ENVS``) and
            ``'env_params'`` (keyword arguments for the environment class).
        cfg: Configuration object. ``cfg.device`` and ``cfg.gpu_id`` are read
            here, and the whole object is passed on to ``Workspace``. The
            default is the module-level ``cfg`` captured when this function
            is defined.
        goal_idx: Task index passed to ``env.reset_task`` and to ``Workspace``.
        seed: Value passed to ``env.seed``; ``None`` skips seeding.
        eval: When truthy, print ``'evaluate:'`` and call
            ``Workspace.run_evaluate()`` instead of ``Workspace.run()``.
    """
    env_name = variant['env_name']
    env_cls = ENVS[env_name]
    env = NormalizedBoxEnv(env_cls(**variant['env_params']))
    if seed is not None:
        env.seed(seed)
    env.reset_task(goal_idx)

    if "cuda" in cfg.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_id)

    # Deliberately imported here rather than at module level: per the original
    # author's note, torch should be imported only after CUDA_VISIBLE_DEVICES
    # has been assigned for the new value to be effective.
    from rlkit.torch.sac.pytorch_sac.train import Workspace

    runner = Workspace(
        cfg=cfg,
        env=env,
        env_name=env_name,
        goal_idx=goal_idx,
    )

    if not eval:
        runner.run()
        return

    print('evaluate:')
    runner.run_evaluate()