示例#1
0
        _w = wrapper_pns.make_same_dim_wrapper(obs_dim=28, action_dim=8)
        default_wrapper.append(_w)

    for rank_idx, test_body in enumerate(args.test_bodies):
        eval_venv = DummyVecEnv([
            gym_interface.make_env(rank=rank_idx,
                                   seed=common.seed,
                                   wrappers=default_wrapper,
                                   force_render=args.render,
                                   robot_body=test_body,
                                   dataset_folder=args.body_folder)
        ])

        if args.vec_normalize:
            eval_venv = VecNormalize.load(
                common.get_vec_pkl_from_model_filename(args.model_filename),
                eval_venv)
        if args.stack_frames > 1:
            eval_venv = VecFrameStack(eval_venv, args.stack_frames)

        eval_venv.seed(common.seed)
        if args.pns:
            model_cls = PNSPPO
            policy_cls = PNSMlpPolicy
        elif args.cnspns:
            model_cls = CNSPNSPPO
            policy_cls = CNSPNSPolicy
        else:
            model_cls = PPO
            policy_cls = "MlpPolicy"
示例#2
0
        default_wrapper.append(_w)

    assert len(args.train_bodies) > 0, "No body to train."
    if args.with_bodyinfo:
        default_wrapper.append(wrapper.BodyinfoWrapper)

    print("Making train environments...")
    venv = DummyVecEnv([gym_interface.make_env(rank=i, seed=common.seed, wrappers=default_wrapper, render=args.render,
                                               robot_body=args.train_bodies[i % len(args.train_bodies)],
                                               dataset_folder=args.body_folder) for i in range(args.num_venvs)])

    normalize_kwargs = {}
    if args.vec_normalize:
        normalize_kwargs["gamma"] = hyperparams["gamma"]
        if len(args.model_filename) > 0:
            venv = VecNormalize.load(common.get_vec_pkl_from_model_filename(args.model_filename), venv)
        else:
            venv = VecNormalize(venv, **normalize_kwargs)

    if args.stack_frames > 1:
        venv = VecFrameStack(venv, args.stack_frames)

    print("Making eval environments...")
    all_callbacks = []
    for rank_idx, test_body in enumerate(args.test_bodies):
        body_info = 0
        eval_venv = DummyVecEnv([gym_interface.make_env(rank=rank_idx, seed=common.seed+1, wrappers=default_wrapper, render=False,
                                                        robot_body=test_body, body_info=body_info,
                                                        dataset_folder=args.body_folder)])
        if args.vec_normalize:
            eval_venv = VecNormalize(eval_venv, norm_reward=False, **normalize_kwargs)