コード例 #1
0
ファイル: 1.train.py プロジェクト: liusida/thesis-bodies
import common.wrapper as wrapper
import common.gym_interface as gym_interface
import common.callbacks as callbacks
from common.activation_fn import MyThreshold

if __name__ == "__main__":

    args = common.args
    print(args)

    # Seed all RNGs globally before training for reproducibility. One global
    # seed is enough while every env runs in this process (DummyVecEnv); with
    # SubprocVecEnv each subprocess would have to be seeded on its own.
    set_random_seed(common.seed)

    saved_model_filename = common.build_model_filename(args)

    # Load the hyperparameter set named "PPO".
    hyperparams = common.load_hyperparameters(conf_name="PPO")
    print(hyperparams)

    # Wrappers applied to every env so that all of them expose the same
    # observation space and action space.
    default_wrapper = []
    # Zero-padding alternative (currently disabled):
    #   default_wrapper.append(wrapper.WalkerWrapper)

    if args.realign_method != "":
        default_wrapper.append(wrapper.ReAlignedWrapper)

    # Explicit check rather than `assert`: assertions are removed under
    # `python -O`, which would let an empty body list through silently.
    if len(args.train_bodies) == 0:
        raise ValueError("No body to train.")
    if args.with_bodyinfo:
        default_wrapper.append(wrapper.BodyinfoWrapper)

    print("Making train environments...")
コード例 #2
0
from common.cnspns import CNSPNSPPO, CNSPNSPolicy

if __name__ == "__main__":

    args = common.args
    print(args)

    # Reuse the training bodies as the test bodies, so test_bodies no longer
    # has to be passed on the command line. Both attributes now refer to the
    # same object, so mutating one also mutates the other.
    args.test_bodies = args.train_bodies # omit test_bodies in command from now on.
    # Initialize weights from the file named by args.model_filename
    # (presumably a previously trained model — TODO confirm).
    args.initialize_weights_from = args.model_filename

    # Seed all RNGs globally for reproducibility. One global seed is enough
    # while every env runs in this process (DummyVecEnv); with SubprocVecEnv
    # each subprocess would have to be seeded on its own.
    set_random_seed(common.seed)

    saved_model_filename = common.build_model_filename(args)

    # The hyperparameter set is selected by args.rl_hyperparameter.
    hyperparams = common.load_hyperparameters(conf_name=args.rl_hyperparameter)
    print(hyperparams)

    # Wrappers applied to every env so that all of them expose the same
    # observation space and action space.
    default_wrapper = []
    # Zero-padding alternative (currently disabled):
    #   default_wrapper.append(wrapper.WalkerWrapper)

    if args.topology_wrapper == "same":
        body_type = 0
        for body in args.train_bodies + args.test_bodies:
            if body_type == 0:
                body_type = body//100
            else:
                assert body_type == body//100, "Training on different body types."
        if args.realign_method != "":
コード例 #3
0
ファイル: 1.train.py プロジェクト: liusida/thesis-bodies
import common.wrapper as wrapper
import common.gym_interface as gym_interface
import common.callbacks as callbacks
from common.activation_fn import MyThreshold

if __name__ == "__main__":

    args = common.args
    print(args)

    # Seed all RNGs globally before training for reproducibility. One global
    # seed is enough while every env runs in this process (DummyVecEnv); with
    # SubprocVecEnv each subprocess would have to be seeded on its own.
    set_random_seed(common.seed)

    saved_model_filename = common.build_model_filename(args)

    # Called without conf_name, so whatever default configuration
    # load_hyperparameters defines is used.
    hyperparams = common.load_hyperparameters()
    print(hyperparams)

    # Wrappers applied to every env so that all of them expose the same
    # observation space and action space; WalkerWrapper is always applied.
    default_wrapper = [wrapper.WalkerWrapper]

    # Explicit check rather than `assert`: assertions are removed under
    # `python -O`, which would let an empty body list through silently.
    if len(args.train_bodies) == 0:
        raise ValueError("No body to train.")
    if args.with_bodyinfo:
        default_wrapper.append(wrapper.BodyinfoWrapper)
    venv = DummyVecEnv([
        gym_interface.make_env(
            rank=i,
            seed=common.seed,
            wrappers=default_wrapper,
            render=args.render,
            robot_body=args.train_bodies[i % len(args.train_bodies)])
コード例 #4
0
ファイル: 30.1.train.py プロジェクト: liusida/thesis-bodies
import common.wrapper as wrapper
import common.gym_interface as gym_interface
import common.callbacks as callbacks
from common.activation_fn import MyThreshold
from common.pns import PNSPPO, PNSMlpPolicy

if __name__ == "__main__":

    args = common.args
    print(args)

    # Disabled override: force observation normalization for the "Robo"
    # configuration (the original note says Robo needs normalization).
    # args.vec_normalize = True # Robo need normalization.

    # NOTE(review): no set_random_seed call is visible before setup in this
    # script; confirm seeding happens elsewhere if reproducibility matters.

    saved_model_filename = common.build_model_filename(args)

    # Load the hyperparameter set named "Robo".
    hyperparams = common.load_hyperparameters(conf_name="Robo")
    # Disabled override: take learning_rate from the command-line args instead
    # of the loaded configuration.
    # hyperparams["learning_rate"] = common.args.learning_rate

    print(hyperparams)

    # Wrappers applied to every env so that all of them expose the same
    # observation space and action space.
    default_wrapper = []
    # Zero-padding alternative (currently disabled):
    #   default_wrapper.append(wrapper.WalkerWrapper)

    if args.topology_wrapper == "same":
        body_type = 0
        for body in args.train_bodies + args.test_bodies:
            if body_type == 0:
                body_type = body // 100