def main(goal_idx=0, args=args):
    """Load the experiment configuration and run ``experiment`` for one task.

    Starts from ``default_config`` and, when ``args.config`` names a JSON
    file, overlays that file's values onto it via ``deep_update_dict``. The
    merged variant is handed to ``experiment``, which constructs, seeds and
    resets the environment itself.

    Args:
        goal_idx: Task/goal index forwarded to ``experiment``.
        args: Parsed command-line arguments; only ``args.config`` is read
            here. NOTE: the default is bound once, when this function is
            defined, not at call time.
    """
    import copy

    # Work on a private copy: deep_update_dict(fr, to) presumably writes into
    # ``to`` in place (TODO confirm), which would otherwise mutate the shared
    # module-level default_config on every call.
    variant = copy.deepcopy(default_config)
    if args.config:
        with open(args.config, encoding="utf-8") as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)
    # experiment() takes the variant (it has no ``env`` parameter) and builds,
    # seeds and resets the environment itself, so no env is constructed here.
    experiment(variant, goal_idx=goal_idx)
def experiment(variant, cfg=cfg, goal_idx=0, seed=0, eval=False):
    """Build the task environment, then train or evaluate a workspace on it.

    Args:
        variant: Experiment config mapping; ``'env_name'`` selects the entry
            in ``ENVS`` and ``'env_params'`` supplies its keyword arguments.
        cfg: Workspace config; ``cfg.device`` and ``cfg.gpu_id`` are read
            here and the whole object is forwarded to ``Workspace``. The
            default is bound once, at definition time.
        goal_idx: Task index passed to ``env.reset_task`` and ``Workspace``.
        seed: Environment seed; ``None`` skips seeding.
        eval: When truthy, run ``run_evaluate()`` instead of ``run()``.
    """
    env_name = variant['env_name']
    env_factory = ENVS[env_name]
    env = NormalizedBoxEnv(env_factory(**variant['env_params']))
    if seed is not None:
        env.seed(seed)
    env.reset_task(goal_idx)

    # Restrict visible GPUs to the configured one before the training
    # module (and therefore torch) is imported below.
    if "cuda" in cfg.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_id)

    # NOTE: for new environment variable to be effective, torch should be
    # imported after assignment -- hence this function-local import.
    from rlkit.torch.sac.pytorch_sac.train import Workspace

    workspace = Workspace(cfg=cfg, env=env, env_name=env_name, goal_idx=goal_idx)
    if not eval:
        workspace.run()
        return
    print('evaluate:')
    workspace.run_evaluate()