Example #1
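# Train PPO on BreakoutNoFrameskip-v4 (DeepMind Atari wrappers) with a
# recurrent LSTM actor-critic: 16 parallel envs, 128 steps per rollout,
# 4 minibatches, 4 epochs, clip epsilon 0.1, and a linearly decayed
# 2.5e-4 learning rate, for 10M steps.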
def main():
    env_name = 'BreakoutNoFrameskip-v4'
    env = wrap_deepmind(make_atari(env_name))
    output_size = env.action_space.n

    name = 'breakout2'
    with tf.Session() as sess:
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            nenvs = 16
            minibatches = 4
            nsteps = 128
            def policy_fn(obs, nenvs):
                #return models.nature_cnn()(obs)
                return models.atari_lstm(nenvs, 512)(obs)
                #return models.mlp()(obs)
            network = models.ACnet(env.observation_space.shape, policy_fn, nenvs, nsteps, minibatches, actiontype.Discrete, output_size, recurrent=True)

            model = PPO(sess, network, epochs=4, epsilon=0.1,\
                learning_rate=lambda f : 2.5e-4 * (1-f), name=name)
            
        train(sess, model, env_name, 10e6, atari=True)
        #run_only(sess, model, env, render=True)
        env.close()
Example #2
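# Evaluate a trained Habitat PointNav PPO checkpoint on the "val" split:
# restore the policy, step parallel VectorEnv workers until
# --count-test-episodes episodes finish, and print the average episode
# reward, success, and SPL (read here from the "roomnavmetric" info key).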
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--sim-gpu-id", type=int, required=True)
    parser.add_argument("--pth-gpu-id", type=int, required=True)
    parser.add_argument("--num-processes", type=int, required=True)
    parser.add_argument("--hidden-size", type=int, default=512)
    parser.add_argument("--count-test-episodes", type=int, default=100)
    parser.add_argument(
        "--sensors",
        type=str,
        default="RGB_SENSOR,DEPTH_SENSOR",
        help="comma separated string containing different"
        "sensors to use, currently 'RGB_SENSOR' and"
        "'DEPTH_SENSOR' are supported",
    )
    parser.add_argument(
        "--task-config",
        type=str,
        default="configs/tasks/pointnav.yaml",
        help="path to config yaml containing information about task",
    )
    args = parser.parse_args()

    device = torch.device("cuda:{}".format(args.pth_gpu_id))

    env_configs = []
    baseline_configs = []

    for _ in range(args.num_processes):
        config_env = get_config(config_paths=args.task_config)
        config_env.defrost()
        config_env.DATASET.SPLIT = "val"

        agent_sensors = args.sensors.strip().split(",")
        for sensor in agent_sensors:
            assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
        config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
        config_env.freeze()
        env_configs.append(config_env)

        config_baseline = cfg_baseline()
        baseline_configs.append(config_baseline)

    assert len(baseline_configs) > 0, "empty list of datasets"

    envs = habitat.VectorEnv(
        make_env_fn=make_env_fn,
        env_fn_args=tuple(
            tuple(zip(env_configs, baseline_configs,
                      range(args.num_processes)))),
    )

    ckpt = torch.load(args.model_path, map_location=device)

    actor_critic = Policy(
        observation_space=envs.observation_spaces[0],
        action_space=envs.action_spaces[0],
        hidden_size=512,
        goal_sensor_uuid=env_configs[0].TASK.GOAL_SENSOR_UUID,
    )
    actor_critic.to(device)

    ppo = PPO(
        actor_critic=actor_critic,
        clip_param=0.1,
        ppo_epoch=4,
        num_mini_batch=32,
        value_loss_coef=0.5,
        entropy_coef=0.01,
        lr=2.5e-4,
        eps=1e-5,
        max_grad_norm=0.5,
    )

    ppo.load_state_dict(ckpt["state_dict"])

    actor_critic = ppo.actor_critic

    observations = envs.reset()
    batch = batch_obs(observations)
    for sensor in batch:
        batch[sensor] = batch[sensor].to(device)

    episode_rewards = torch.zeros(envs.num_envs, 1, device=device)
    episode_spls = torch.zeros(envs.num_envs, 1, device=device)
    episode_success = torch.zeros(envs.num_envs, 1, device=device)
    episode_counts = torch.zeros(envs.num_envs, 1, device=device)
    current_episode_reward = torch.zeros(envs.num_envs, 1, device=device)

    test_recurrent_hidden_states = torch.zeros(args.num_processes,
                                               args.hidden_size,
                                               device=device)
    not_done_masks = torch.zeros(args.num_processes, 1, device=device)

    while episode_counts.sum() < args.count_test_episodes:
        with torch.no_grad():
            _, actions, _, test_recurrent_hidden_states = actor_critic.act(
                batch,
                test_recurrent_hidden_states,
                not_done_masks,
                deterministic=False,
            )

        outputs = envs.step([a[0].item() for a in actions])

        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
        batch = batch_obs(observations)
        for sensor in batch:
            batch[sensor] = batch[sensor].to(device)

        not_done_masks = torch.tensor(
            [[0.0] if done else [1.0] for done in dones],
            dtype=torch.float,
            device=device,
        )

        for i in range(not_done_masks.shape[0]):
            if not_done_masks[i].item() == 0:
                episode_spls[i] += infos[i]["roomnavmetric"]
                if infos[i]["roomnavmetric"] > 0:
                    episode_success[i] += 1

        rewards = torch.tensor(rewards, dtype=torch.float,
                               device=device).unsqueeze(1)
        current_episode_reward += rewards
        episode_rewards += (1 - not_done_masks) * current_episode_reward
        episode_counts += 1 - not_done_masks
        current_episode_reward *= not_done_masks

    episode_reward_mean = (episode_rewards / episode_counts).mean().item()
    episode_spl_mean = (episode_spls / episode_counts).mean().item()
    episode_success_mean = (episode_success / episode_counts).mean().item()

    print("Average episode reward: {:.6f}".format(episode_reward_mean))
    print("Average episode success: {:.6f}".format(episode_success_mean))
    print("Average episode spl: {:.6f}".format(episode_spl_mean))
Example #3
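# Main PPO training loop for Habitat: construct vectorized environments,
# collect --num-steps transitions per update into RolloutStorage, compute
# returns (optionally with GAE), update the agent, log windowed episode
# rewards, and periodically save "ckpt.N.pth" checkpoints.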
def main():
    parser = ppo_args()
    args = parser.parse_args()

    random.seed(args.seed)

    device = torch.device("cuda:{}".format(args.pth_gpu_id))

    logger.add_filehandler(args.log_file)

    if not os.path.isdir(args.checkpoint_folder):
        os.makedirs(args.checkpoint_folder)

    for p in sorted(list(vars(args))):
        logger.info("{}: {}".format(p, getattr(args, p)))

    envs = construct_envs(args)

    actor_critic = Policy(
        observation_space=envs.observation_spaces[0],
        action_space=envs.action_spaces[0],
        hidden_size=args.hidden_size,
    )
    actor_critic.to(device)

    agent = PPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm,
    )

    logger.info("agent number of parameters: {}".format(
        sum(param.numel() for param in agent.parameters())))

    observations = envs.reset()

    batch = batch_obs(observations)

    rollouts = RolloutStorage(
        args.num_steps,
        envs.num_envs,
        envs.observation_spaces[0],
        envs.action_spaces[0],
        args.hidden_size,
    )
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])
    rollouts.to(device)

    episode_rewards = torch.zeros(envs.num_envs, 1)
    episode_counts = torch.zeros(envs.num_envs, 1)
    current_episode_reward = torch.zeros(envs.num_envs, 1)
    window_episode_reward = deque()
    window_episode_counts = deque()

    t_start = time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0

    for update in range(args.num_updates):
        if args.use_linear_lr_decay:
            update_linear_schedule(agent.optimizer, update, args.num_updates,
                                   args.lr)

        agent.clip_param = args.clip_param * (1 - update / args.num_updates)

        for step in range(args.num_steps):
            t_sample_action = time()
            # sample actions
            with torch.no_grad():
                step_observation = {
                    k: v[step]
                    for k, v in rollouts.observations.items()
                }

                (
                    values,
                    actions,
                    actions_log_probs,
                    recurrent_hidden_states,
                ) = actor_critic.act(
                    step_observation,
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step],
                )
            pth_time += time() - t_sample_action

            t_step_env = time()

            outputs = envs.step([a[0].item() for a in actions])
            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]

            env_time += time() - t_step_env

            t_update_stats = time()
            batch = batch_obs(observations)
            rewards = torch.tensor(rewards, dtype=torch.float)
            rewards = rewards.unsqueeze(1)

            masks = torch.tensor([[0.0] if done else [1.0] for done in dones],
                                 dtype=torch.float)

            current_episode_reward += rewards
            episode_rewards += (1 - masks) * current_episode_reward
            episode_counts += 1 - masks
            current_episode_reward *= masks

            rollouts.insert(
                batch,
                recurrent_hidden_states,
                actions,
                actions_log_probs,
                values,
                rewards,
                masks,
            )

            count_steps += envs.num_envs
            pth_time += time() - t_update_stats

        if len(window_episode_reward) == args.reward_window_size:
            window_episode_reward.popleft()
            window_episode_counts.popleft()
        window_episode_reward.append(episode_rewards.clone())
        window_episode_counts.append(episode_counts.clone())

        t_update_model = time()
        with torch.no_grad():
            last_observation = {
                k: v[-1]
                for k, v in rollouts.observations.items()
            }
            next_value = actor_critic.get_value(
                last_observation,
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1],
            ).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()
        pth_time += time() - t_update_model

        # log stats
        if update > 0 and update % args.log_interval == 0:
            logger.info("update: {}\tfps: {:.3f}\t".format(
                update, count_steps / (time() - t_start)))

            logger.info("update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

            window_rewards = (window_episode_reward[-1] -
                              window_episode_reward[0]).sum()
            window_counts = (window_episode_counts[-1] -
                             window_episode_counts[0]).sum()

            if window_counts > 0:
                logger.info("Average window size {} reward: {:3f}".format(
                    len(window_episode_reward),
                    (window_rewards / window_counts).item(),
                ))
            else:
                logger.info("No episodes finish in current window")

        # checkpoint model
        if update % args.checkpoint_interval == 0:
            checkpoint = {"state_dict": agent.state_dict()}
            torch.save(
                checkpoint,
                os.path.join(
                    args.checkpoint_folder,
                    "ckpt.{}.pth".format(count_checkpoints),
                ),
            )
            count_checkpoints += 1
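Example #3 also assumes an update_linear_schedule helper for the learning-rate decay. A minimal sketch, mirroring the linear decay already applied to agent.clip_param above (an assumption, not necessarily the project's own helper), could be:

def update_linear_schedule(optimizer, update, total_num_updates, initial_lr):
    # Linearly anneal the optimizer's learning rate from initial_lr toward 0.
    lr = initial_lr * (1 - update / float(total_num_updates))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr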
Example #4
def main():

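    # Deploy a trained PointNav policy over ROS: subscribe to a combined
    # depth + pointgoal message, run the restored actor-critic whenever a
    # rotation completes or a timeout elapses, and publish Twist velocity
    # commands until a stop action is produced.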
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--sim-gpu-id", type=int, required=True)
    parser.add_argument("--pth-gpu-id", type=int, required=True)
    parser.add_argument("--num-processes", type=int, required=True)
    parser.add_argument("--hidden-size", type=int, default=512)
    parser.add_argument("--count-test-episodes", type=int, default=100)
    parser.add_argument(
        "--sensors",
        type=str,
        default="DEPTH_SENSOR",
        help="comma separated string containing different"
        "sensors to use, currently 'RGB_SENSOR' and"
        "'DEPTH_SENSOR' are supported",
    )
    parser.add_argument(
        "--task-config",
        type=str,
        default="configs/tasks/pointnav.yaml",
        help="path to config yaml containing information about task",
    )

    cmd_line_inputs = [
        "--model-path",
        "/home/bruce/NSERC_2019/habitat-api/data/checkpoints/depth.pth",
        "--sim-gpu-id",
        "0",
        "--pth-gpu-id",
        "0",
        "--num-processes",
        "1",
        "--count-test-episodes",
        "100",
        "--task-config",
        "configs/tasks/pointnav.yaml",
    ]
    args = parser.parse_args(cmd_line_inputs)

    device = torch.device("cuda:{}".format(args.pth_gpu_id))

    env_configs = []
    baseline_configs = []

    for _ in range(args.num_processes):
        config_env = get_config(config_paths=args.task_config)
        config_env.defrost()
        config_env.DATASET.SPLIT = "val"

        agent_sensors = args.sensors.strip().split(",")
        for sensor in agent_sensors:
            assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
        config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
        config_env.freeze()
        env_configs.append(config_env)

        config_baseline = cfg_baseline()
        baseline_configs.append(config_baseline)

    assert len(baseline_configs) > 0, "empty list of datasets"

    envs = habitat.VectorEnv(
        make_env_fn=make_env_fn,
        env_fn_args=tuple(
            tuple(zip(env_configs, baseline_configs, range(args.num_processes)))
        ),
    )

    ckpt = torch.load(args.model_path, map_location=device)

    actor_critic = Policy(
        observation_space=envs.observation_spaces[0],
        action_space=envs.action_spaces[0],
        hidden_size=512,
        goal_sensor_uuid="pointgoal",
    )
    actor_critic.to(device)

    ppo = PPO(
        actor_critic=actor_critic,
        clip_param=0.1,
        ppo_epoch=4,
        num_mini_batch=32,
        value_loss_coef=0.5,
        entropy_coef=0.01,
        lr=2.5e-4,
        eps=1e-5,
        max_grad_norm=0.5,
    )

    ppo.load_state_dict(ckpt["state_dict"])

    actor_critic = ppo.actor_critic

    observations = envs.reset()
    batch = batch_obs(observations)
    for sensor in batch:
        batch[sensor] = batch[sensor].to(device)

    test_recurrent_hidden_states = torch.zeros(
        args.num_processes, args.hidden_size, device=device
    )
    not_done_masks = torch.zeros(args.num_processes, 1, device=device)

    def transform_callback(data):
        nonlocal actor_critic
        nonlocal batch
        nonlocal not_done_masks
        nonlocal test_recurrent_hidden_states
        global flag
        global t_prev_update
        global observation

        if flag == 2:
            observation["depth"] = np.reshape(data.data[0:-2], (256, 256, 1))
            observation["pointgoal"] = data.data[-2:]
            flag = 1
            return

        pointgoal_received = data.data[-2:]
        translate_amount = 0.25  # meters
        rotate_amount = 0.174533  # radians

        isrotated = (
            rotate_amount * 0.95
            <= abs(pointgoal_received[1] - observation["pointgoal"][1])
            <= rotate_amount * 1.05
        )
        istimeup = (time.time() - t_prev_update) >= 4

        # print('istranslated is '+ str(istranslated))
        # print('isrotated is '+ str(isrotated))
        # print('istimeup is '+ str(istimeup))

        if isrotated or istimeup:
            vel_msg = Twist()
            vel_msg.linear.x = 0
            vel_msg.linear.y = 0
            vel_msg.linear.z = 0
            vel_msg.angular.x = 0
            vel_msg.angular.y = 0
            vel_msg.angular.z = 0
            pub_vel.publish(vel_msg)
            time.sleep(0.2)
            print("entered update step")

            # cv2.imshow("Depth", observation['depth'])
            # cv2.waitKey(100)

            observation["depth"] = np.reshape(data.data[0:-2], (256, 256, 1))
            observation["pointgoal"] = data.data[-2:]

            batch = batch_obs([observation])
            for sensor in batch:
                batch[sensor] = batch[sensor].to(device)
            if flag == 1:
                not_done_masks = torch.tensor([0.0], dtype=torch.float, device=device)
                flag = 0
            else:
                not_done_masks = torch.tensor([1.0], dtype=torch.float, device=device)

            _, actions, _, test_recurrent_hidden_states = actor_critic.act(
                batch, test_recurrent_hidden_states, not_done_masks, deterministic=True
            )

            action_id = actions.item()
            print(
                "observation received to produce action_id is "
                + str(observation["pointgoal"])
            )
            print("action_id from net is " + str(actions.item()))

            t_prev_update = time.time()
            vel_msg = Twist()
            vel_msg.linear.x = 0
            vel_msg.linear.y = 0
            vel_msg.linear.z = 0
            vel_msg.angular.x = 0
            vel_msg.angular.y = 0
            vel_msg.angular.z = 0
            if action_id == 0:
                vel_msg.linear.x = 0.25 / 4
                pub_vel.publish(vel_msg)
            elif action_id == 1:
                vel_msg.angular.z = 10 / 180 * 3.1415926
                pub_vel.publish(vel_msg)
            elif action_id == 2:
                vel_msg.angular.z = -10 / 180 * 3.1415926
                pub_vel.publish(vel_msg)
            else:
                pub_vel.publish(vel_msg)
                sub.unregister()
                print("NN finished navigation task")

    sub = rospy.Subscriber(
        "depth_and_pointgoal", numpy_msg(Floats), transform_callback, queue_size=1
    )
    rospy.spin()
Example #5
def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
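    # Evaluate one checkpoint on the "val" split, optionally rendering
    # top-down-map/collision videos per episode, and report average reward,
    # success, and SPL to the logger and the TensorBoard writer.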
    env_configs = []
    baseline_configs = []
    device = torch.device("cuda", args.pth_gpu_id)

    for _ in range(args.num_processes):
        config_env = get_config(config_paths=args.task_config)
        config_env.defrost()
        config_env.DATASET.SPLIT = "val"

        agent_sensors = args.sensors.strip().split(",")
        for sensor in agent_sensors:
            assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
        config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
        if args.video_option:
            config_env.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config_env.TASK.MEASUREMENTS.append("COLLISIONS")
        config_env.freeze()
        env_configs.append(config_env)

        config_baseline = cfg_baseline()
        baseline_configs.append(config_baseline)

    assert len(baseline_configs) > 0, "empty list of datasets"

    envs = habitat.VectorEnv(
        make_env_fn=make_env_fn,
        env_fn_args=tuple(
            tuple(
                zip(env_configs, baseline_configs, range(args.num_processes))
            )
        ),
    )

    ckpt = torch.load(checkpoint_path, map_location=device)

    actor_critic = Policy(
        observation_space=envs.observation_spaces[0],
        action_space=envs.action_spaces[0],
        hidden_size=512,
        goal_sensor_uuid=env_configs[0].TASK.GOAL_SENSOR_UUID,
    )
    actor_critic.to(device)

    ppo = PPO(
        actor_critic=actor_critic,
        clip_param=0.1,
        ppo_epoch=4,
        num_mini_batch=32,
        value_loss_coef=0.5,
        entropy_coef=0.01,
        lr=2.5e-4,
        eps=1e-5,
        max_grad_norm=0.5,
    )

    ppo.load_state_dict(ckpt["state_dict"])

    actor_critic = ppo.actor_critic

    observations = envs.reset()
    batch = batch_obs(observations)
    for sensor in batch:
        batch[sensor] = batch[sensor].to(device)

    episode_rewards = torch.zeros(envs.num_envs, 1, device=device)
    episode_spls = torch.zeros(envs.num_envs, 1, device=device)
    episode_success = torch.zeros(envs.num_envs, 1, device=device)
    episode_counts = torch.zeros(envs.num_envs, 1, device=device)
    current_episode_reward = torch.zeros(envs.num_envs, 1, device=device)

    test_recurrent_hidden_states = torch.zeros(
        args.num_processes, args.hidden_size, device=device
    )
    not_done_masks = torch.zeros(args.num_processes, 1, device=device)
    stats_episodes = set()

    rgb_frames = None
    if args.video_option:
        rgb_frames = [[]] * args.num_processes
        os.makedirs(args.video_dir, exist_ok=True)

    while episode_counts.sum() < args.count_test_episodes:
        current_episodes = envs.current_episodes()

        with torch.no_grad():
            _, actions, _, test_recurrent_hidden_states = actor_critic.act(
                batch,
                test_recurrent_hidden_states,
                not_done_masks,
                deterministic=False,
            )

        outputs = envs.step([a[0].item() for a in actions])

        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
        batch = batch_obs(observations)
        for sensor in batch:
            batch[sensor] = batch[sensor].to(device)

        not_done_masks = torch.tensor(
            [[0.0] if done else [1.0] for done in dones],
            dtype=torch.float,
            device=device,
        )

        for i in range(not_done_masks.shape[0]):
            if not_done_masks[i].item() == 0:
                episode_spls[i] += infos[i]["spl"]
                if infos[i]["spl"] > 0:
                    episode_success[i] += 1

        rewards = torch.tensor(
            rewards, dtype=torch.float, device=device
        ).unsqueeze(1)
        current_episode_reward += rewards
        episode_rewards += (1 - not_done_masks) * current_episode_reward
        episode_counts += 1 - not_done_masks
        current_episode_reward *= not_done_masks

        next_episodes = envs.current_episodes()
        envs_to_pause = []
        n_envs = envs.num_envs
        for i in range(n_envs):
            if next_episodes[i].episode_id in stats_episodes:
                envs_to_pause.append(i)

            # episode ended
            if not_done_masks[i].item() == 0:
                stats_episodes.add(current_episodes[i].episode_id)
                if args.video_option:
                    generate_video(
                        args,
                        rgb_frames[i],
                        current_episodes[i].episode_id,
                        cur_ckpt_idx,
                        infos[i]["spl"],
                        writer,
                    )
                    rgb_frames[i] = []

            # episode continues
            elif args.video_option:
                frame = observations_to_image(observations[i], infos[i])
                rgb_frames[i].append(frame)

        # stop tracking ended episodes if they exist
        if len(envs_to_pause) > 0:
            state_index = list(range(envs.num_envs))
            for idx in reversed(envs_to_pause):
                state_index.pop(idx)
                envs.pause_at(idx)

            # indexing along the batch dimensions
            test_recurrent_hidden_states = test_recurrent_hidden_states[
                :, state_index
            ]
            not_done_masks = not_done_masks[state_index]
            current_episode_reward = current_episode_reward[state_index]

            for k, v in batch.items():
                batch[k] = v[state_index]

            if args.video_option:
                rgb_frames = [rgb_frames[i] for i in state_index]

    episode_reward_mean = (episode_rewards / episode_counts).mean().item()
    episode_spl_mean = (episode_spls / episode_counts).mean().item()
    episode_success_mean = (episode_success / episode_counts).mean().item()

    logger.info("Average episode reward: {:.6f}".format(episode_reward_mean))
    logger.info("Average episode success: {:.6f}".format(episode_success_mean))
    logger.info("Average episode SPL: {:.6f}".format(episode_spl_mean))

    writer.add_scalars(
        "eval_reward", {"average reward": episode_reward_mean}, cur_ckpt_idx
    )
    writer.add_scalars(
        "eval_SPL", {"average SPL": episode_spl_mean}, cur_ckpt_idx
    )
    writer.add_scalars(
        "eval_success", {"average success": episode_success_mean}, cur_ckpt_idx
    )
Example #6
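# Restore a specific PPO checkpoint ("ckpt.2.pth"), rebuild the Policy and
# PPO objects with the training hyperparameters, load the saved weights,
# and move the first batch of observations to the GPU.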
ckpt = torch.load("/home/bruce/NSERC_2019/habitat-api/data/checkpoints/ckpt.2.pth", map_location=device)

actor_critic = Policy(
    observation_space=envs.observation_spaces[0],
    action_space=envs.action_spaces[0],
    hidden_size=512,
)
actor_critic.to(device)

ppo = PPO(
    actor_critic=actor_critic,
    clip_param=0.1,
    ppo_epoch=4,
    num_mini_batch=32,
    value_loss_coef=0.5,
    entropy_coef=0.01,
    lr=2.5e-4,
    eps=1e-5,
    max_grad_norm=0.5,
)

ppo.load_state_dict(ckpt["state_dict"])

actor_critic = ppo.actor_critic

observations = envs.reset()
batch = batch_obs(observations)
for sensor in batch:
    batch[sensor] = batch[sensor].to(device)
Example #7
def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
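    # Evaluation variant that displays the depth observation with OpenCV at
    # each step, stores per-episode reward/success/SPL in a dict keyed by
    # (scene_id, episode_id), and aggregates the means once enough episodes
    # have finished.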
    env_configs = []
    baseline_configs = []
    device = torch.device("cuda", args.pth_gpu_id)

    for _ in range(args.num_processes):
        config_env = get_config(config_paths=args.task_config)
        config_env.defrost()
        config_env.DATASET.SPLIT = "val"

        agent_sensors = args.sensors.strip().split(",")
        for sensor in agent_sensors:
            assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
        config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
        if args.video_option:
            config_env.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config_env.TASK.MEASUREMENTS.append("COLLISIONS")
        config_env.freeze()
        env_configs.append(config_env)

        config_baseline = cfg_baseline()
        baseline_configs.append(config_baseline)

    assert len(baseline_configs) > 0, "empty list of datasets"

    envs = habitat.VectorEnv(
        make_env_fn=make_env_fn,
        env_fn_args=tuple(
            tuple(zip(env_configs, baseline_configs,
                      range(args.num_processes)))),
    )

    ckpt = torch.load(checkpoint_path, map_location=device)

    actor_critic = Policy(
        observation_space=envs.observation_spaces[0],
        action_space=envs.action_spaces[0],
        hidden_size=512,
        goal_sensor_uuid=env_configs[0].TASK.GOAL_SENSOR_UUID,
    )
    actor_critic.to(device)

    ppo = PPO(
        actor_critic=actor_critic,
        clip_param=0.1,
        ppo_epoch=4,
        num_mini_batch=32,
        value_loss_coef=0.5,
        entropy_coef=0.01,
        lr=2.5e-4,
        eps=1e-5,
        max_grad_norm=0.5,
    )

    ppo.load_state_dict(ckpt["state_dict"])

    actor_critic = ppo.actor_critic

    observations = envs.reset()
    batch = batch_obs(observations)
    for sensor in batch:
        batch[sensor] = batch[sensor].to(device)

    current_episode_reward = torch.zeros(envs.num_envs, 1, device=device)

    test_recurrent_hidden_states = torch.zeros(args.num_processes,
                                               args.hidden_size,
                                               device=device)
    not_done_masks = torch.zeros(args.num_processes, 1, device=device)
    stats_episodes = dict()  # dict of dicts that stores stats per episode

    rgb_frames = None
    if args.video_option:
        rgb_frames = [[]] * args.num_processes
        os.makedirs(args.video_dir, exist_ok=True)

    while len(stats_episodes) < args.count_test_episodes:
        current_episodes = envs.current_episodes()
        # test_recurrent_hidden_states_list.append(test_recurrent_hidden_states)
        # pickle_out = open("hab_recurrent_states.pickle","wb")
        # pickle.dump(test_recurrent_hidden_states_list, pickle_out)
        # pickle_out.close()
        # obs_list.append(observations[0])
        # pickle_out = open("hab_obs_list.pickle","wb")
        # pickle.dump(obs_list, pickle_out)
        # pickle_out.close()

        # mask_list.append(not_done_masks)
        # pickle_out = open("hab_mask_list.pickle","wb")
        # pickle.dump(mask_list, pickle_out)
        # pickle_out.close()

        with torch.no_grad():
            _, actions, _, test_recurrent_hidden_states = actor_critic.act(
                batch,
                test_recurrent_hidden_states,
                not_done_masks,
                deterministic=True,
            )

        print("action_id is " + str(actions.item()))
        print('point goal is ' + str(observations[0]['pointgoal']))

        outputs = envs.step([a[0].item() for a in actions])

        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]

        #for visualizing where robot is going
        #cv2.imshow("RGB", transform_rgb_bgr(observations[0]["rgb"]))
        cv2.imshow("Depth", observations[0]["depth"])
        cv2.waitKey(100)
        time.sleep(0.2)

        batch = batch_obs(observations)
        for sensor in batch:
            batch[sensor] = batch[sensor].to(device)

        not_done_masks = torch.tensor(
            [[0.0] if done else [1.0] for done in dones],
            dtype=torch.float,
            device=device,
        )

        rewards = torch.tensor(rewards, dtype=torch.float,
                               device=device).unsqueeze(1)
        current_episode_reward += rewards
        next_episodes = envs.current_episodes()
        envs_to_pause = []
        n_envs = envs.num_envs
        for i in range(n_envs):
            if (
                    next_episodes[i].scene_id,
                    next_episodes[i].episode_id,
            ) in stats_episodes:
                envs_to_pause.append(i)

            # episode ended
            if not_done_masks[i].item() == 0:
                episode_stats = dict()
                episode_stats["spl"] = infos[i]["spl"]
                episode_stats["success"] = int(infos[i]["spl"] > 0)
                episode_stats["reward"] = current_episode_reward[i].item()
                current_episode_reward[i] = 0
                # use scene_id + episode_id as unique id for storing stats
                stats_episodes[(
                    current_episodes[i].scene_id,
                    current_episodes[i].episode_id,
                )] = episode_stats
                if args.video_option:
                    generate_video(
                        args,
                        rgb_frames[i],
                        current_episodes[i].episode_id,
                        cur_ckpt_idx,
                        infos[i]["spl"],
                        writer,
                    )
                    rgb_frames[i] = []

            # episode continues
            elif args.video_option:
                frame = observations_to_image(observations[i], infos[i])
                rgb_frames[i].append(frame)

        # pausing envs with no new episode
        if len(envs_to_pause) > 0:
            state_index = list(range(envs.num_envs))
            for idx in reversed(envs_to_pause):
                state_index.pop(idx)
                envs.pause_at(idx)

            # indexing along the batch dimensions
            test_recurrent_hidden_states = test_recurrent_hidden_states[
                state_index]
            not_done_masks = not_done_masks[state_index]
            current_episode_reward = current_episode_reward[state_index]

            for k, v in batch.items():
                batch[k] = v[state_index]

            if args.video_option:
                rgb_frames = [rgb_frames[i] for i in state_index]

    aggregated_stats = dict()
    for stat_key in next(iter(stats_episodes.values())).keys():
        aggregated_stats[stat_key] = sum(
            [v[stat_key] for v in stats_episodes.values()])
    num_episodes = len(stats_episodes)

    episode_reward_mean = aggregated_stats["reward"] / num_episodes
    episode_spl_mean = aggregated_stats["spl"] / num_episodes
    episode_success_mean = aggregated_stats["success"] / num_episodes

    logger.info("Average episode reward: {:.6f}".format(episode_reward_mean))
    logger.info("Average episode success: {:.6f}".format(episode_success_mean))
    logger.info("Average episode SPL: {:.6f}".format(episode_spl_mean))

    writer.add_scalars("eval_reward", {"average reward": episode_reward_mean},
                       cur_ckpt_idx)
    writer.add_scalars("eval_SPL", {"average SPL": episode_spl_mean},
                       cur_ckpt_idx)
    writer.add_scalars("eval_success",
                       {"average success": episode_success_mean}, cur_ckpt_idx)
Example #8
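# Run robust control experiments: build an NLDI/PLDI/H-infinity environment,
# compute LQR and robust LQR controllers, train non-robust and projected
# (robust) model-based planners and PPO agents plus a RARL PPO baseline,
# evaluate each policy under nominal and adversarial disturbances, and write
# the results (including a robust MPC baseline where applicable) under
# results/.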
def main():
    parser = argparse.ArgumentParser(
        description='Run robust control experiments.')
    parser.add_argument('--baseLR', type=float, default=1e-3,
                        help='learning rate for non-projected DPS')
    parser.add_argument('--robustLR', type=float, default=1e-4,
                        help='learning rate for projected DPS')
    parser.add_argument('--alpha', type=float, default=0.001,
                        help='exponential stability coefficient')
    parser.add_argument('--gamma', type=float, default=20,
                        help='bound on L2 gain of disturbance-to-output map (for H_inf control)')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='max epochs')
    parser.add_argument('--test_frequency', type=int, default=20,
                        help='frequency of testing during training')
    parser.add_argument('--T', type=float, default=2,
                        help='time horizon in seconds')
    parser.add_argument('--dt', type=float, default=0.01,
                        help='time increment')
    parser.add_argument('--testSetSz', type=int, default=50,
                        help='size of test set')
    parser.add_argument('--holdSetSz', type=int, default=50,
                        help='size of holdout set')
    parser.add_argument('--trainBatchSz', type=int, default=20,
                        help='batch size for training')
    parser.add_argument('--stepType', type=str,
                        choices=['euler', 'RK4', 'scipy'], default='RK4',
                        help='method for taking steps during training')
    parser.add_argument('--testStepType', type=str,
                        choices=['euler', 'RK4', 'scipy'], default='RK4',
                        help='method for taking steps during testing')
    parser.add_argument('--env', type=str,
                        choices=['random_nldi-d0', 'random_nldi-dnonzero', 'random_pldi_env',
                        'random_hinf_env', 'cartpole', 'quadrotor', 'microgrid'],
                        default='random_nldi-d0',
                        help='environment')
    parser.add_argument('--envRandomSeed', type=int, default=10,
                        help='random seed used to construct the environment')
    parser.add_argument('--save', type=str,
                        help='prefix to add to save path')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU id to use')
    parser.add_argument('--evaluate', type=str,
                        help='instead of training, evaluate the models from a given directory'
                             ' (remember to use the same random seed)')
    args = parser.parse_args()

    dt = args.dt
    save_sub = '{}+alpha{}+gamma{}+testSz{}+holdSz{}+trainBatch{}+baselr{}+robustlr{}+T{}+stepType{}+testStepType{}+seed{}+dt{}'.format(
        args.env, args.alpha, args.gamma, args.testSetSz, args.holdSetSz,
        args.trainBatchSz, args.baseLR, args.robustLR, args.T,
        args.stepType, args.testStepType, args.envRandomSeed, dt)
    if args.save is not None:
        save = os.path.join('results', '{}+{}'.format(args.save, save_sub))
    else:
        save = os.path.join('results', save_sub)
    trained_model_dir = os.path.join(save, 'trained_models')
    if not os.path.exists(trained_model_dir):
        os.makedirs(trained_model_dir)
    setproctitle.setproctitle(save_sub)
    
    device = torch.device('cuda:%d' % args.gpu if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
    
    # Setup
    isD0 = (args.env == 'random_nldi-d0') or (args.env == 'quadrotor')  # no u dependence in disturbance bound
    problem_type = 'nldi'
    if 'random_nldi' in args.env:
        env = RandomNLDIEnv(isD0=isD0, random_seed=args.envRandomSeed, device=device)
    elif args.env == 'random_pldi_env':
        env = RandomPLDIEnv(random_seed=args.envRandomSeed, device=device)
        problem_type = 'pldi'
    elif args.env == 'random_hinf_env':
        env = RandomHinfEnv(T=args.T, random_seed=args.envRandomSeed, device=device)
        problem_type = 'hinf'
    elif args.env == 'cartpole':
        env = CartPoleEnv(random_seed=args.envRandomSeed, device=device)
    elif args.env == 'quadrotor':
        env = QuadrotorEnv(random_seed=args.envRandomSeed, device=device)
    elif args.env == 'microgrid':
        env = MicrogridEnv(random_seed=args.envRandomSeed, device=device)
    else:
        raise ValueError('No environment named %s' % args.env)
    evaluate_dir = args.evaluate
    evaluate = evaluate_dir is not None

    # Test and holdout set of states
    torch.manual_seed(17)
    x_test = env.gen_states(num_states=args.testSetSz, device=device)
    x_hold = env.gen_states(num_states=args.holdSetSz, device=device)
    num_episode_steps = int(args.T / dt)

    if problem_type == 'nldi':
        A, B, G, C, D, Q, R = env.get_nldi_linearization()
        state_dim = A.shape[0]
        action_dim = B.shape[1]

        # Get LQR solutions
        Kct, Pct = get_lqr_tensors(A, B, Q, R, args.alpha, device)

        Kr, Sr = get_robust_lqr_sol(*(v.cpu().numpy() for v in (A, B, G, C, D, Q, R)), args.alpha)
        Krt = torch.tensor(Kr, device=device, dtype=TORCH_DTYPE)
        Prt = torch.tensor(np.linalg.inv(Sr), device=device, dtype=TORCH_DTYPE)
        stable_projection = pm.StableNLDIProjection(Prt, A, B, G, C, D, args.alpha, isD0)

        disturb_model = dm.MultiNLDIDisturbModel(x_test.shape[0], C, D, state_dim, action_dim, env.wp)
        disturb_model.to(device=device, dtype=TORCH_DTYPE)

    elif problem_type == 'pldi':
        A, B, Q, R = env.get_pldi_linearization()
        state_dim = A.shape[1]
        action_dim = B.shape[2]

        # Get LQR solutions
        Kct, Pct = get_lqr_tensors(A.mean(0), B.mean(0), Q, R, args.alpha, device)

        Kr, Sr = get_robust_pldi_policy(*(v.cpu().numpy() for v in (A, B, Q, R)), args.alpha)
        Krt = torch.tensor(Kr, device=device, dtype=TORCH_DTYPE)
        Prt = torch.tensor(np.linalg.inv(Sr), device=device, dtype=TORCH_DTYPE)
        stable_projection = pm.StablePLDIProjection(Prt, A, B)

        disturb_model = dm.MultiPLDIDisturbModel(x_test.shape[0], state_dim, action_dim, env.L)
        disturb_model.to(device=device, dtype=TORCH_DTYPE)
    
    elif problem_type == 'hinf':
        A, B, G, Q, R = env.get_hinf_linearization()
        state_dim = A.shape[0]
        action_dim = B.shape[1]

        # Get LQR solutions
        Kct, Pct = get_lqr_tensors(A, B, Q, R, args.alpha, device)

        Kr, Sr, mu = get_robust_hinf_policy(*(v.cpu().numpy() for v in (A, B, G, Q, R)), args.alpha, args.gamma)
        Krt = torch.tensor(Kr, device=device, dtype=TORCH_DTYPE)
        Prt = torch.tensor(np.linalg.inv(Sr), device=device, dtype=TORCH_DTYPE)
        stable_projection = pm.StableHinfProjection(Prt, A, B, G, Q, R, args.alpha, args.gamma, 1/mu)

        disturb_model = dm.MultiHinfDisturbModel(x_test.shape[0], state_dim, action_dim, env.wp, args.T)
        disturb_model.to(device=device, dtype=TORCH_DTYPE)

    else:
        raise ValueError('No problem type named %s' % problem_type)

    adv_disturb_model = dm.MBAdvDisturbModel(env, None, disturb_model, dt, horizon=num_episode_steps//5, update_freq=num_episode_steps//20)
    env.adversarial_disturb_f = adv_disturb_model

    ###########################################################
    # LQR baselines
    ###########################################################

    ### Vanilla LQR (i.e., non-robust, exponentially stable)
    pi_custom_lqr = lambda x: x @ Kct.T
    adv_disturb_model.set_policy(pi_custom_lqr)

    custom_lqr_perf = eval_model(x_test, pi_custom_lqr, env,
                               step_type=args.testStepType, T=args.T, dt=dt)
    write_results(custom_lqr_perf, 'LQR', save)
    custom_lqr_perf = eval_model(x_test, pi_custom_lqr, env,
                                 step_type=args.testStepType, T=args.T, dt=dt, adversarial=True)
    write_results(custom_lqr_perf, 'LQR-adv', save)

    ### Robust LQR
    pi_robust_lqr = lambda x: x @ Krt.T
    adv_disturb_model.set_policy(pi_robust_lqr)

    robust_lqr_perf = eval_model(x_test, pi_robust_lqr, env,
                                 step_type=args.testStepType, T=args.T, dt=dt)
    write_results(robust_lqr_perf, 'Robust LQR', save)
    robust_lqr_perf = eval_model(x_test, pi_robust_lqr, env,
                                 step_type=args.testStepType, T=args.T, dt=dt, adversarial=True)
    write_results(robust_lqr_perf, 'Robust LQR-adv', save)


    ###########################################################
    # Model-based planning methods
    ###########################################################

    ### Non-robust MBP (starting with robust LQR solution)
    pi_mbp = pm.MBPPolicy(Krt, state_dim, action_dim)
    pi_mbp.to(device=device, dtype=TORCH_DTYPE)
    adv_disturb_model.set_policy(pi_mbp)

    if evaluate:
        pi_mbp.load_state_dict(torch.load(os.path.join(evaluate_dir, 'mbp.pt')))
    else:
        pi_mbp_dict, train_losses, hold_losses, test_losses, test_losses_adv, stop_epoch = \
            train(pi_mbp, x_test, x_hold, env,
                  lr=args.baseLR, batch_size=args.trainBatchSz, epochs=args.epochs, T=args.T, dt=dt, step_type=args.stepType,
                  test_frequency=args.test_frequency, save_dir=save, model_name='mbp', device=device)
        save_results(train_losses, hold_losses, test_losses, test_losses_adv, save, 'mbp', pi_mbp_dict, epoch=stop_epoch,
                     is_final=True)
        torch.save(pi_mbp_dict, os.path.join(trained_model_dir, 'mbp.pt'))

    pi_mbp_perf = eval_model(x_test, pi_mbp, env,
                            step_type=args.testStepType, T=args.T, dt=dt)
    write_results(pi_mbp_perf, 'MBP', save)
    pi_mbp_perf = eval_model(x_test, pi_mbp, env,
                            step_type=args.testStepType, T=args.T, dt=dt, adversarial=True)
    write_results(pi_mbp_perf, 'MBP-adv', save)


    ### Robust MBP (starting with robust LQR solution)
    pi_robust_mbp = pm.StablePolicy(pm.MBPPolicy(Krt, state_dim, action_dim), stable_projection)
    pi_robust_mbp.to(device=device, dtype=TORCH_DTYPE)
    adv_disturb_model.set_policy(pi_robust_mbp)

    if evaluate:
        pi_robust_mbp.load_state_dict(torch.load(os.path.join(evaluate_dir, 'robust_mbp.pt')))
    else:
        pi_robust_mbp_dict, train_losses, hold_losses, test_losses, test_losses_adv, stop_epoch = \
            train(pi_robust_mbp, x_test, x_hold, env,
                  lr=args.robustLR, batch_size=args.trainBatchSz, epochs=args.epochs, T=args.T, dt=dt, step_type=args.stepType,
                  test_frequency=args.test_frequency, save_dir=save, model_name='robust_mbp', device=device)
        save_results(train_losses, hold_losses, test_losses, test_losses_adv, save, 'robust_mbp', pi_robust_mbp_dict, epoch=stop_epoch,
                     is_final=True)
        torch.save(pi_robust_mbp_dict, os.path.join(trained_model_dir, 'robust_mbp.pt'))

    pi_robust_mbp_perf = eval_model(x_test, pi_robust_mbp, env,
                                   step_type=args.testStepType, T=args.T, dt=dt)
    write_results(pi_robust_mbp_perf, 'Robust MBP', save)
    pi_robust_mbp_perf = eval_model(x_test, pi_robust_mbp, env,
                                   step_type=args.testStepType, T=args.T, dt=dt, adversarial=True)
    write_results(pi_robust_mbp_perf, 'Robust MBP-adv', save)


    ###########################################################
    # RL methods
    ###########################################################

    if 'random_nldi' in args.env:
        rmax = 1000
    elif args.env == 'random_pldi_env':
        rmax = 10
    elif args.env == 'random_hinf_env':
        rmax = 1000
    elif args.env == 'cartpole':
        rmax = 10
    elif args.env == 'quadrotor':
        rmax = 1000
    elif args.env == 'microgrid':
        rmax = 10
    else:
        raise ValueError('No environment named %s' % args.env)

    rl_args = arguments.get_args()
    linear_controller_K = Krt
    linear_controller_P = Prt
    linear_transform = lambda u, x: u + x @ linear_controller_K.T


    ### Vanilla and robust PPO
    base_ppo_perfs = []
    base_ppo_adv_perfs = []
    robust_ppo_perfs = []
    robust_ppo_adv_perfs = []
    for seed in range(1):
        for robust in [False, True]:
            torch.manual_seed(seed)

            if robust:
                # stable_projection = pm.StableNLDIProjection(linear_controller_P, A, B, G, C, D, args.alpha, isD0=isD0)
                action_transform = lambda u, x: stable_projection.project_action(linear_transform(u, x), x)
            else:
                action_transform = linear_transform

            envs = RLWrapper(env, state_dim, action_dim, gamma=rl_args.gamma,
                             dt=dt, rmax=rmax, step_type='RK4', action_transform=action_transform,
                             num_envs=rl_args.num_processes, device=device)
            eval_envs = RLWrapper(env, state_dim, action_dim, gamma=rl_args.gamma,
                                  dt=dt, rmax=rmax, step_type='RK4', action_transform=action_transform,
                                  num_envs=args.testSetSz, device=device)

            actor_critic = Policy(
                envs.observation_space.shape,
                envs.action_space,
                base_kwargs={'recurrent': False})
            actor_critic.to(device=device, dtype=TORCH_DTYPE)
            agent = PPO(
                actor_critic,
                rl_args.clip_param,
                rl_args.ppo_epoch,
                rl_args.num_mini_batch,
                rl_args.value_loss_coef,
                rl_args.entropy_coef,
                lr=rl_args.lr,
                eps=rl_args.rms_prop_eps,
                max_grad_norm=rl_args.max_grad_norm,
                use_linear_lr_decay=rl_args.use_linear_lr_decay)
            rollouts = RolloutStorage(num_episode_steps, rl_args.num_processes,
                                      envs.observation_space.shape, envs.action_space,
                                      actor_critic.recurrent_hidden_state_size)

            ppo_pi = lambda x: action_transform(actor_critic.act(x, None, None, deterministic=True)[1], x)
            adv_disturb_model.set_policy(ppo_pi)

            if evaluate:
                actor_critic.load_state_dict(torch.load(os.path.join(evaluate_dir,
                                                                     'robust_ppo.pt' if robust else 'ppo.pt')))
            else:
                hold_costs, test_costs, adv_test_costs =\
                    trainer.train(agent, envs, rollouts, device, rl_args,
                                  eval_envs=eval_envs, x_hold=x_hold, x_test=x_test, num_episode_steps=num_episode_steps,
                                  save_dir=os.path.join(save, 'robust_ppo' if robust else 'ppo'),
                                  save_extension='%d' % seed)
                save_results(np.zeros_like(hold_costs), hold_costs, test_costs, adv_test_costs, save,
                             'robust_ppo' if robust else 'ppo', actor_critic.state_dict(),
                             epoch=rl_args.num_env_steps, is_final=True)
                torch.save(actor_critic.state_dict(), os.path.join(trained_model_dir,
                                                                   'robust_ppo.pt' if robust else 'ppo.pt'))

            ppo_perf = eval_model(x_test, ppo_pi, env,
                                  step_type=args.testStepType, T=args.T, dt=dt)
            ppo_adv_perf = eval_model(x_test, ppo_pi, env,
                                      step_type=args.testStepType, T=args.T, dt=dt, adversarial=True)

            if robust:
                robust_ppo_perfs.append(ppo_perf.item())
                robust_ppo_adv_perfs.append(ppo_adv_perf.item())
            else:
                base_ppo_perfs.append(ppo_perf.item())
                base_ppo_adv_perfs.append(ppo_adv_perf.item())

    write_results(base_ppo_perfs, 'PPO', save)
    write_results(robust_ppo_perfs, 'Robust PPO', save)
    write_results(base_ppo_adv_perfs, 'PPO-adv', save)
    write_results(robust_ppo_adv_perfs, 'Robust PPO-adv', save)


    # RARL PPO baseline
    adv_ppo_perfs = []
    adv_ppo_adv_perfs = []
    seed = 0
    torch.manual_seed(seed)

    action_transform = linear_transform

    envs = RLWrapper(env, state_dim, action_dim, gamma=rl_args.gamma,
                     dt=dt, rmax=rmax, step_type='RK4', action_transform=action_transform,
                     num_envs=rl_args.num_processes, device=device, rarl=True)
    eval_envs = RLWrapper(env, state_dim, action_dim, gamma=rl_args.gamma,
                          dt=dt, rmax=rmax, step_type='RK4', action_transform=action_transform,
                          num_envs=args.testSetSz, device=device)

    protagornist_ac = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={'recurrent': False})
    protagornist_ac.to(device=device, dtype=TORCH_DTYPE)
    adversary_ac = Policy(
        envs.observation_space.shape,
        envs.disturb_space,
        base_kwargs={'recurrent': False})
    adversary_ac.to(device=device, dtype=TORCH_DTYPE)
    agent = RARLPPO(
        protagornist_ac,
        adversary_ac,
        rl_args.clip_param,
        rl_args.ppo_epoch,
        rl_args.num_mini_batch,
        rl_args.value_loss_coef,
        rl_args.entropy_coef,
        lr=rl_args.lr,
        eps=rl_args.rms_prop_eps,
        max_grad_norm=rl_args.max_grad_norm,
        use_linear_lr_decay=rl_args.use_linear_lr_decay)
    action_space = spaces.Box(low=0, high=1,
                              shape=(envs.action_space.shape[0]+envs.disturb_space.shape[0],), dtype=NUMPY_DTYPE)
    rollouts = RolloutStorage(num_episode_steps, rl_args.num_processes,
                              envs.observation_space.shape, action_space,
                              protagornist_ac.recurrent_hidden_state_size + adversary_ac.recurrent_hidden_state_size,
                              rarl=True)

    ppo_pi = lambda x: action_transform(protagornist_ac.act(x, None, None, deterministic=True)[1], x)
    adv_disturb_model.set_policy(ppo_pi)

    if evaluate:
        agent.load(evaluate_dir)
    else:
        hold_costs, test_costs, adv_test_costs = \
            trainer.train(agent, envs, rollouts, device, rl_args,
                          eval_envs=eval_envs, x_hold=x_hold, x_test=x_test,
                          num_episode_steps=num_episode_steps,
                          save_dir=os.path.join(save, 'rarl_ppo'),
                          save_extension='%d' % seed)
        save_results(np.zeros_like(hold_costs), hold_costs, test_costs, adv_test_costs, save,
                     'rarl_ppo', protagornist_ac.state_dict(),
                     epoch=rl_args.num_env_steps, is_final=True)
        agent.save(trained_model_dir)
    env.disturb_f.disturbance = None

    ppo_perf = eval_model(x_test, ppo_pi, env,
                          step_type=args.testStepType, T=args.T, dt=dt)
    ppo_adv_perf = eval_model(x_test, ppo_pi, env,
                              step_type=args.testStepType, T=args.T, dt=dt, adversarial=True)

    adv_ppo_perfs.append(ppo_perf.item())
    adv_ppo_adv_perfs.append(ppo_adv_perf.item())

    write_results(adv_ppo_perfs, 'RARL PPO', save)
    write_results(adv_ppo_adv_perfs, 'RARL PPO-adv', save)


    ###########################################################
    # MPC baselines
    ###########################################################

    ### Robust MPC (not implemented for H_infinity settings)
    if problem_type != 'hinf':
        if problem_type == 'nldi':
            robust_mpc_model = rmpc.RobustNLDIMPC(A, B, G, C, D, Q, R, Krt, device)
        else:
            robust_mpc_model = rmpc.RobustPLDIMPC(A, B, Q, R, Krt, device)

        pi_robust_mpc = robust_mpc_model.get_action
        adv_disturb_model.set_policy(pi_robust_mpc)

        robust_mpc_perf = eval_model(x_test, pi_robust_mpc, env,
                                step_type=args.testStepType, T=args.T, dt=dt, adversarial=True)
        write_results(robust_mpc_perf, 'Robust MPC-adv', save)