Example #1
import os

import dl


def load_policy(logdir):
    """Load a trained ResidualSAC policy and env metadata from a log directory."""
    config = os.path.join(logdir, 'logs/config.gin')
    dl.load_config(config)
    alg = ResidualSAC(logdir)  # ResidualSAC is defined elsewhere in the project
    pi = alg.pi
    pi.eval()  # switch the policy network to evaluation mode
    # Read simulation parameters from the underlying (unwrapped) env.
    frameskip = alg.env.unwrapped.envs[0].frameskip
    max_torque = alg.env.unwrapped.envs[0].max_torque
    return pi, alg.device, frameskip, max_torque
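A minimal sketch of how the returned policy might be used, assuming a hypothetical log directory and a placeholder observation shape (neither comes from the original example):

import torch

pi, device, frameskip, max_torque = load_policy('/path/to/logdir')  # hypothetical path
ob = torch.zeros(1, 40, device=device)  # placeholder shape; the real observation is env-specific
with torch.no_grad():
    out = pi(ob)  # the return type depends on the dl policy implementation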
Example #2
        def test(self):
            """Test."""
            config = (
                "FeedForwardNet.activation_fn = @F.relu\n"
                "FeedForwardNet.activate_last = False\n"
                "FeedForwardNet.units = [64,64,1]\n"
                "FeedForwardNet.in_shape = 2304\n"
                "Conv2dNet.activation_fn = @F.relu\n"
                "Conv2dNet.activate_last = True\n"
                "Conv2dNet.in_channels = 3\n"
                "Conv2dNet.convs = [(16,3), (32,3,2), (64,3,1,1)]\n"
            )

            with open('./test.gin', 'w') as f:
                f.write(config)
            load_config('./test.gin')

            class Net(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.conv = Conv2dNet()
                    self.ff = FeedForwardNet()

                def forward(self, x):
                    x = self.conv(x)
                    # Flatten the conv features: 64 channels * 6 * 6 spatial = 2304,
                    # matching FeedForwardNet.in_shape above.
                    x = x.view(-1, 64 * 6 * 6)
                    return self.ff(x)

            net = Net()
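            # Each Conv2dNet.convs tuple is (out_channels, kernel_size[, stride[, padding]]);
            # the assertions below check that the gin config above was applied.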
            assert net.conv.conv2d0.kernel_size == (3, 3)
            assert net.conv.conv2d1.kernel_size == (3, 3)
            assert net.conv.conv2d2.kernel_size == (3, 3)

            assert net.conv.conv2d0.stride == (1, 1)
            assert net.conv.conv2d1.stride == (2, 2)
            assert net.conv.conv2d2.stride == (1, 1)

            assert net.conv.conv2d2.padding == (1, 1)

            assert net.ff.fc0.in_features == 2304
            assert net.ff.fc1.in_features == 64
            assert net.ff.fc2.in_features == 64
            assert net.ff.fc2.out_features == 1
            os.remove('./test.gin')
Example #3
def _init_env_and_policy(difficulty, policy):
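    """Build the evaluation env for the given difficulty and, for 'ppo', load its policy."""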
    if policy == 'ppo':
        if difficulty == 4:
            expdir = f'../../../models/mpfc_level_{difficulty}'
            is_level_4 = True
        else:
            expdir = f'../../../models/fc_level_{difficulty}'
            is_level_4 = False
        bindings = [
            f'make_pybullet_env.reward_fn="task{difficulty}_competition_reward"',
            'make_pybullet_env.termination_fn="{}"'.format(
                'stay_close_to_goal_level_4'
                if is_level_4 else 'stay_close_to_goal'),
            f'make_pybullet_env.initializer="task{difficulty}_init"',
            'make_pybullet_env.visualization=True',
            'make_pybullet_env.monitor=True',
        ]
        from rrc_simulation.code.utils import set_seed
        set_seed(0)
        dl.load_config(expdir + '/config.gin', bindings)
        ppo = ResidualPPO2(expdir, nenv=1)
        ppo.load()
        env = ppo.env
        set_env_to_eval_mode(env)
        return env, ppo

    else:
        eval_config = {
            'action_space':
                'torque_and_position' if policy == 'mpfc' else 'torque',
            'frameskip': 3,
            'residual': True,
            'reward_fn': f'task{difficulty}_competition_reward',
            'termination_fn': 'no_termination',
            'initializer': f'task{difficulty}_init',
            'monitor': True,
            'rank': 0
        }

        from rrc_simulation.code.utils import set_seed
        set_seed(0)
        env = make_training_env(visualization=False, **eval_config)
        return env, None
Example #4
def _load_env_and_policy(logdir, t=None):
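    """Load the env and policy from logdir, optionally at checkpoint timestep t."""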

    gin_bindings = [
        "make_training_env.sim=True",
        "make_training_env.visualization=True",
        "make_training_env.monitor=True",
    ]

    config = os.path.join(logdir, 'config.gin')
    dl.load_config(config, gin_bindings)
    alg = ResidualSAC(logdir)
    alg.load(t)
    env = alg.env
    pi = alg.pi
    dl.rl.set_env_to_eval_mode(env)
    pi.eval()
    init_ob = alg.data_manager._ob
    if t is None:
        try:
            # Default to the most recent checkpoint timestep.
            t = max(alg.ckptr.ckpts())
        except Exception:
            t = 0
    return env, pi, alg.device, init_ob, t
Example #5
import argparse
import dl

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Agent.')
    parser.add_argument('--expdir', type=str, help='expdir', required=True)
    parser.add_argument('--gin_config',
                        type=str,
                        help='gin config',
                        required=True)
    parser.add_argument('-b',
                        '--gin_bindings',
                        nargs='+',
                        help='gin bindings to overwrite config')
    args = parser.parse_args()
    dl.load_config(args.gin_config, args.gin_bindings)
    dl.train(args.expdir)
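For reference, a hypothetical invocation of this script; the paths and the binding are placeholders, not taken from the source:

# python train.py --expdir logs/my_run --gin_config configs/train.gin \
#     -b "train.maxt=1000000"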
Example #6
import argparse
import dl
import yaml

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Agent.')
    parser.add_argument('config', type=str, help='config')
    args = parser.parse_args()
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)  # safe_load avoids arbitrary object construction

    gin_bindings = []
    for k, v in config['gin_bindings'].items():
        # Quote plain strings for gin; leave configurable references ('@...')
        # and non-string values unquoted.
        if isinstance(v, str) and not v.startswith('@'):
            gin_bindings.append(f'{k}="{v}"')
        else:
            gin_bindings.append(f'{k}={v}')
    dl.load_config(config['base_config'], gin_bindings)
    dl.train(config['logdir'])
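The script above expects a YAML file shaped roughly as follows; every key and value here is illustrative, not taken from the source. Note that string values are quoted for gin unless they reference a configurable with '@':

# config.yaml (hypothetical):
#
# base_config: configs/base.gin
# logdir: logs/my_run
# gin_bindings:
#   train.maxt: 1000000
#   Policy.activation_fn: "@F.relu"   # passed through to gin unquoted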
Example #7
"""Render trained agents."""
import dl
import argparse
from dl.rl import rl_record, misc, PPO, VecFrameStack
import residual_shared_autonomy.drone_sim  # imported for its env registration side effects
import os

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Script.')
    parser.add_argument('logdir', type=str, help='logdir')
    args = parser.parse_args()

    dl.load_config(os.path.join(args.logdir, 'config.gin'))

    t = PPO(args.logdir, nenv=1)
    t.load()
    t.pi.eval()
    env = t.env
    misc.set_env_to_eval_mode(env)
    os.makedirs(os.path.join(t.logdir, 'video'), exist_ok=True)
    outfile = os.path.join(t.logdir, 'video',
                           t.ckptr.format.format(t.t) + '.mp4')
    rl_record(env, t.pi, 10, outfile, t.device)  # record rollouts of the policy to video
    t.close()
Example #8
"""Main script for training models."""
import dl
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Script.')
    parser.add_argument('logdir', type=str, help='logdir')
    parser.add_argument('config', type=str, help='gin config')
    parser.add_argument('-b',
                        '--bindings',
                        nargs='+',
                        help='gin bindings to overwrite config')
    args = parser.parse_args()

    dl.load_config(args.config, args.bindings)
    dl.train(args.logdir)
Example #9
import argparse
import os

import dl
# The two import locations below are assumptions; the original excerpt omits
# the imports for ConstrainedResidualPPO and LunarLanderJoystickActor.
from residual_shared_autonomy import ConstrainedResidualPPO
from residual_shared_autonomy.drone_sim import DroneJoystickActor
from residual_shared_autonomy.lunar_lander import LunarLanderJoystickActor

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Script.')
    parser.add_argument('logdir', type=str, help='logdir')
    parser.add_argument('--drone',
                        action='store_true',
                        help='control drone env')
    parser.add_argument('--reacher',
                        action='store_true',
                        help='control lunar reacher env')
    args = parser.parse_args()

    if args.drone:
        dl.load_config(os.path.join(args.logdir, 'config.gin'),
                       ['make_env.env_id="DroneReacherBot-v0"'])
        trainer = ConstrainedResidualPPO(args.logdir,
                                         nenv=1,
                                         base_actor_cls=DroneJoystickActor)
    elif args.reacher:
        dl.load_config(os.path.join(args.logdir, 'config.gin'),
                       ['make_env.env_id="LunarLanderReacher-v2"'])
        trainer = ConstrainedResidualPPO(
            args.logdir, nenv=1, base_actor_cls=LunarLanderJoystickActor)
    else:
        dl.load_config(os.path.join(args.logdir, 'config.gin'),
                       ['make_env.env_id="LunarLanderRandomContinuous-v2"'])
        trainer = ConstrainedResidualPPO(
            args.logdir, nenv=1, base_actor_cls=LunarLanderJoystickActor)
    trainer.load()
    trainer.evaluate()